summaryrefslogtreecommitdiff
path: root/nn/runtime/test/TestNeuralNetworksWrapper.h
diff options
context:
space:
mode:
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r--nn/runtime/test/TestNeuralNetworksWrapper.h24
1 files changed, 19 insertions, 5 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h
index 6df16e217..ae40121c7 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.h
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.h
@@ -23,6 +23,7 @@
#include <math.h>
#include <algorithm>
+#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -242,6 +243,21 @@ class Model {
class Compilation {
public:
+ // On success, createForDevice(s) will return Result::NO_ERROR and the created compilation;
+ // otherwise, it will return the error code and Compilation object wrapping a nullptr handle.
+ static std::pair<Result, Compilation> createForDevice(const Model* model,
+ const ANeuralNetworksDevice* device) {
+ return createForDevices(model, {device});
+ }
+ static std::pair<Result, Compilation> createForDevices(
+ const Model* model, const std::vector<const ANeuralNetworksDevice*>& devices) {
+ ANeuralNetworksCompilation* compilation = nullptr;
+ const Result result = static_cast<Result>(ANeuralNetworksCompilation_createForDevices(
+ model->getHandle(), devices.empty() ? nullptr : devices.data(), devices.size(),
+ &compilation));
+ return {result, Compilation(compilation)};
+ }
+
Compilation(const Model* model) {
int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation);
if (result != 0) {
@@ -272,11 +288,6 @@ class Compilation {
return *this;
}
- Result createForDevice(const Model* model, const ANeuralNetworksDevice* device) {
- return static_cast<Result>(ANeuralNetworksCompilation_createForDevices(
- model->getHandle(), &device, 1, &mCompilation));
- }
-
Result setPreference(ExecutePreference preference) {
return static_cast<Result>(ANeuralNetworksCompilation_setPreference(
mCompilation, static_cast<int32_t>(preference)));
@@ -300,6 +311,9 @@ class Compilation {
ANeuralNetworksCompilation* getHandle() const { return mCompilation; }
protected:
+ // Takes the ownership of ANeuralNetworksCompilation.
+ Compilation(ANeuralNetworksCompilation* compilation) : mCompilation(compilation) {}
+
ANeuralNetworksCompilation* mCompilation = nullptr;
};