summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestNeuralNetworksWrapper.h
diff options
context:
space:
mode:
authorXusong Wang <xusongw@google.com>2020-04-09 14:52:23 -0700
committerXusong Wang <xusongw@google.com>2020-04-17 17:38:50 -0700
commit69db15d05d4f7a3c168185520cdcdb60f1d09fab (patch)
treefacba5b5cab875aebd0795a458a856008901eb19 /nn/runtime/test/TestNeuralNetworksWrapper.h
parent9746e462ca0f4c2239feb28314e8857974322fda (diff)
downloadml-69db15d05d4f7a3c168185520cdcdb60f1d09fab.tar.gz
Add NNT_static internal tests for device memory allocation.
These tests use a customized IDevice to test the device memory allocation and fallback logic. This CL also allows the runtime to dispatch device memory allocation with dynamic shape to drivers. Additionally, this CL fixes a bug that a failed device memory allocation will return BAD_DATA -- it should return OP_FAILED instead. Bug: 152209365 Test: NNT_static Change-Id: I1facb2dad345958c3b9b1bab4a9564085c382c4a
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r--nn/runtime/test/TestNeuralNetworksWrapper.h24
1 files changed, 19 insertions, 5 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h
index 6df16e217..ae40121c7 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.h
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.h
@@ -23,6 +23,7 @@
#include <math.h>
#include <algorithm>
+#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -242,6 +243,21 @@ class Model {
class Compilation {
public:
+ // On success, createForDevice(s) will return Result::NO_ERROR and the created compilation;
+ // otherwise, it will return the error code and Compilation object wrapping a nullptr handle.
+ static std::pair<Result, Compilation> createForDevice(const Model* model,
+ const ANeuralNetworksDevice* device) {
+ return createForDevices(model, {device});
+ }
+ static std::pair<Result, Compilation> createForDevices(
+ const Model* model, const std::vector<const ANeuralNetworksDevice*>& devices) {
+ ANeuralNetworksCompilation* compilation = nullptr;
+ const Result result = static_cast<Result>(ANeuralNetworksCompilation_createForDevices(
+ model->getHandle(), devices.empty() ? nullptr : devices.data(), devices.size(),
+ &compilation));
+ return {result, Compilation(compilation)};
+ }
+
Compilation(const Model* model) {
int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
if (result != 0) {
@@ -272,11 +288,6 @@ class Compilation {
return *this;
}
- Result createForDevice(const Model* model, const ANeuralNetworksDevice* device) {
- return static_cast<Result>(ANeuralNetworksCompilation_createForDevices(
- model->getHandle(), &device, 1, &mCompilation));
- }
-
Result setPreference(ExecutePreference preference) {
return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
mCompilation, static_cast<int32_t>(preference)));
@@ -300,6 +311,9 @@ class Compilation {
ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
protected:
+ // Takes the ownership of ANeuralNetworksCompilation.
+ Compilation(ANeuralNetworksCompilation* compilation) : mCompilation(compilation) {}
+
ANeuralNetworksCompilation* mCompilation = nullptr;
};