path: root/nn/runtime/test/TestNeuralNetworksWrapper.h
author    Michael Butler <butlermichael@google.com>  2019-08-02 15:04:00 -0700
committer Michael Butler <butlermichael@google.com>  2019-08-16 12:32:59 -0700
commit    008caeab69787b8978e874f8fb3811e2400f5d55 (patch)
tree      27050dbbc1cc757a3570e69813ca7b2dc51fa900 /nn/runtime/test/TestNeuralNetworksWrapper.h
parent    907c72d4c0c483f06824c91b6a706ac0d9634f22 (diff)
download  ml-008caeab69787b8978e874f8fb3811e2400f5d55.tar.gz
Cleanup NNAPI runtime Memory objects
Prior to this CL, runtime Memory* objects were default constructed and set to a value later. Rebinding the value led to multiple bugs in the past and made the Memory* objects prone to data races. This CL addresses these issues by determining all values in the Memory* objects' factory methods and making the Memory* objects immutable after construction.

This CL also untangles some of the inheritance hierarchy. Prior to this CL, both MemoryFd and MemoryAHWB inherited from Memory, which was effectively ashmem-based memory. This CL separates MemoryAshmem into its own, unique class type and introduces a common base Memory class. The reorganization also uncovered that getPointer was only used in the runtime on ashmem-based memory, so the unused getPointer methods are removed.

Finally, this CL improves the documentation of NeuralNetworks.h in two ways:
1) Fixes the typo "ANeuralNetworksMemory_createFromAHardwarBuffer"
2) Documents the missing lifetime constraints of ANeuralNetworksMemory_createFromAHardwareBuffer

Fixes: 138852228
Fixes: 69632863
Fixes: 69633035
Fixes: 129572123
Fixes: 132323765
Fixes: 139213289
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest CtsNNAPITestCases
Change-Id: I49a2356d6b8fc38e501d8e37ba8a8893f3e91395
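For context, here is a minimal usage sketch of the test wrapper classes introduced in the diff below. It is illustrative only and not part of the commit: the android::nn::test_wrapper namespace, the ASharedMemory-based fixture, and the operand shapes are assumptions about the test environment.

// Illustrative sketch only -- not part of this change.
#include <android/sharedmem.h>  // ASharedMemory_create (assumed available in the test build)
#include <sys/mman.h>           // PROT_READ | PROT_WRITE
#include <unistd.h>             // close

#include "TestNeuralNetworksWrapper.h"

using namespace android::nn::test_wrapper;

bool buildAddModel() {
    // Back a Memory wrapper with a shared-memory fd. Creation either succeeds
    // or the object reports failure through isValid(); the underlying handle
    // is never rebound after construction.
    const size_t size = 4 * sizeof(float);
    int fd = ASharedMemory_create("operand_a", size);
    if (fd < 0) return false;
    Memory memory(size, PROT_READ | PROT_WRITE, fd, /*offset=*/0);
    close(fd);  // assumption: the runtime keeps its own reference to the mapping
    if (!memory.isValid()) return false;

    // Build a one-operation model, c = a + b, where operand a is a constant
    // backed by the Memory above and operand b is supplied at execution time.
    Model model;
    OperandType floatTensor(Type::TENSOR_FLOAT32, {4});
    OperandType scalarInt(Type::INT32, {});
    uint32_t a = model.addOperand(&floatTensor);
    uint32_t b = model.addOperand(&floatTensor);
    uint32_t act = model.addOperand(&scalarInt);
    uint32_t c = model.addOperand(&floatTensor);

    int32_t noActivation = ANEURALNETWORKS_FUSED_NONE;
    model.setOperandValueFromMemory(a, &memory, /*offset=*/0, size);
    model.setOperandValue(act, &noActivation, sizeof(noActivation));
    model.addOperation(ANEURALNETWORKS_ADD, {a, b, act}, {c});
    model.identifyInputsAndOutputs({b}, {c});
    return model.finish() == Result::NO_ERROR && model.isValid();
}

The point of the sketch is that the Memory handle is fully determined at construction and only inspected afterwards through get() and isValid(), matching the immutable-after-construction design described above.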
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r--  nn/runtime/test/TestNeuralNetworksWrapper.h | 158
1 file changed, 156 insertions(+), 2 deletions(-)
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h
index e37817058..8bc8a9e7c 100644
--- a/nn/runtime/test/TestNeuralNetworksWrapper.h
+++ b/nn/runtime/test/TestNeuralNetworksWrapper.h
@@ -38,13 +38,167 @@ using wrapper::ExecutePreference;
using wrapper::ExtensionModel;
using wrapper::ExtensionOperandParams;
using wrapper::ExtensionOperandType;
-using wrapper::Memory;
-using wrapper::Model;
using wrapper::OperandType;
using wrapper::Result;
using wrapper::SymmPerChannelQuantParams;
using wrapper::Type;
+class Memory {
+ public:
+ // Takes ownership of an ANeuralNetworksMemory
+ Memory(ANeuralNetworksMemory* memory) : mMemory(memory) {}
+
+ Memory(size_t size, int protect, int fd, size_t offset) {
+ mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) ==
+ ANEURALNETWORKS_NO_ERROR;
+ }
+
+ Memory(AHardwareBuffer* buffer) {
+ mValid = ANeuralNetworksMemory_createFromAHardwareBuffer(buffer, &mMemory) ==
+ ANEURALNETWORKS_NO_ERROR;
+ }
+
+ ~Memory() { ANeuralNetworksMemory_free(mMemory); }
+
+ // Disallow copy semantics to ensure the runtime object can only be freed
+ // once. Copy semantics could be enabled if some sort of reference counting
+ // or deep-copy system for runtime objects is added later.
+ Memory(const Memory&) = delete;
+ Memory& operator=(const Memory&) = delete;
+
+ // Move semantics to remove access to the runtime object from the wrapper
+ // object that is being moved. This ensures the runtime object will be
+ // freed only once.
+ Memory(Memory&& other) { *this = std::move(other); }
+ Memory& operator=(Memory&& other) {
+ if (this != &other) {
+ ANeuralNetworksMemory_free(mMemory);
+ mMemory = other.mMemory;
+ mValid = other.mValid;
+ other.mMemory = nullptr;
+ other.mValid = false;
+ }
+ return *this;
+ }
+
+ ANeuralNetworksMemory* get() const { return mMemory; }
+ bool isValid() const { return mValid; }
+
+ private:
+ ANeuralNetworksMemory* mMemory = nullptr;
+ bool mValid = true;
+};
+
+class Model {
+ public:
+ Model() {
+ // TODO handle the value returned by this call
+ ANeuralNetworksModel_create(&mModel);
+ }
+ ~Model() { ANeuralNetworksModel_free(mModel); }
+
+ // Disallow copy semantics to ensure the runtime object can only be freed
+ // once. Copy semantics could be enabled if some sort of reference counting
+ // or deep-copy system for runtime objects is added later.
+ Model(const Model&) = delete;
+ Model& operator=(const Model&) = delete;
+
+ // Move semantics to remove access to the runtime object from the wrapper
+ // object that is being moved. This ensures the runtime object will be
+ // freed only once.
+ Model(Model&& other) { *this = std::move(other); }
+ Model& operator=(Model&& other) {
+ if (this != &other) {
+ ANeuralNetworksModel_free(mModel);
+ mModel = other.mModel;
+ mNextOperandId = other.mNextOperandId;
+ mValid = other.mValid;
+ other.mModel = nullptr;
+ other.mNextOperandId = 0;
+ other.mValid = false;
+ }
+ return *this;
+ }
+
+ Result finish() {
+ if (mValid) {
+ auto result = static_cast<Result>(ANeuralNetworksModel_finish(mModel));
+ if (result != Result::NO_ERROR) {
+ mValid = false;
+ }
+ return result;
+ } else {
+ return Result::BAD_STATE;
+ }
+ }
+
+ uint32_t addOperand(const OperandType* type) {
+ if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
+ ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ if (type->channelQuant) {
+ if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(
+ mModel, mNextOperandId, &type->channelQuant.value().params) !=
+ ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ }
+ return mNextOperandId++;
+ }
+
+ void setOperandValue(uint32_t index, const void* buffer, size_t length) {
+ if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) !=
+ ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ }
+
+ void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset,
+ size_t length) {
+ if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset,
+ length) != ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ }
+
+ void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs,
+ const std::vector<uint32_t>& outputs) {
+ if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()),
+ inputs.data(), static_cast<uint32_t>(outputs.size()),
+ outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ }
+ void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs,
+ const std::vector<uint32_t>& outputs) {
+ if (ANeuralNetworksModel_identifyInputsAndOutputs(
+ mModel, static_cast<uint32_t>(inputs.size()), inputs.data(),
+ static_cast<uint32_t>(outputs.size()),
+ outputs.data()) != ANEURALNETWORKS_NO_ERROR) {
+ mValid = false;
+ }
+ }
+
+ void relaxComputationFloat32toFloat16(bool isRelax) {
+ if (ANeuralNetworksModel_relaxComputationFloat32toFloat16(mModel, isRelax) ==
+ ANEURALNETWORKS_NO_ERROR) {
+ mRelaxed = isRelax;
+ }
+ }
+
+ ANeuralNetworksModel* getHandle() const { return mModel; }
+ bool isValid() const { return mValid; }
+ bool isRelaxed() const { return mRelaxed; }
+
+ protected:
+ ANeuralNetworksModel* mModel = nullptr;
+ // We keep track of the operand ID as a convenience to the caller.
+ uint32_t mNextOperandId = 0;
+ bool mValid = true;
+ bool mRelaxed = false;
+};
+
class Compilation {
public:
Compilation(const Model* model) {