diff options
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r-- | nn/runtime/test/TestNeuralNetworksWrapper.h | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h index 6df16e217..ae40121c7 100644 --- a/nn/runtime/test/TestNeuralNetworksWrapper.h +++ b/nn/runtime/test/TestNeuralNetworksWrapper.h @@ -23,6 +23,7 @@ #include <math.h> #include <algorithm> +#include <memory> #include <optional> #include <string> #include <utility> @@ -242,6 +243,21 @@ class Model { class Compilation { public: + // On success, createForDevice(s) will return Result::NO_ERROR and the created compilation; + // otherwise, it will return the error code and Compilation object wrapping a nullptr handle. + static std::pair<Result, Compilation> createForDevice(const Model* model, + const ANeuralNetworksDevice* device) { + return createForDevices(model, {device}); + } + static std::pair<Result, Compilation> createForDevices( + const Model* model, const std::vector<const ANeuralNetworksDevice*>& devices) { + ANeuralNetworksCompilation* compilation = nullptr; + const Result result = static_cast<Result>(ANeuralNetworksCompilation_createForDevices( + model->getHandle(), devices.empty() ? nullptr : devices.data(), devices.size(), + &compilation)); + return {result, Compilation(compilation)}; + } + Compilation(const Model* model) { int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation); if (result != 0) { @@ -272,11 +288,6 @@ class Compilation { return *this; } - Result createForDevice(const Model* model, const ANeuralNetworksDevice* device) { - return static_cast<Result>(ANeuralNetworksCompilation_createForDevices( - model->getHandle(), &device, 1, &mCompilation)); - } - Result setPreference(ExecutePreference preference) { return static_cast<Result>(ANeuralNetworksCompilation_setPreference( mCompilation, static_cast<int32_t>(preference))); @@ -300,6 +311,9 @@ class Compilation { ANeuralNetworksCompilation* getHandle() const { return mCompilation; } protected: + // Takes the ownership of ANeuralNetworksCompilation. + Compilation(ANeuralNetworksCompilation* compilation) : mCompilation(compilation) {} + ANeuralNetworksCompilation* mCompilation = nullptr; }; |