diff options
author | Michael Butler <butlermichael@google.com> | 2019-08-02 15:04:00 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2019-08-16 12:32:59 -0700 |
commit | 008caeab69787b8978e874f8fb3811e2400f5d55 (patch) | |
tree | 27050dbbc1cc757a3570e69813ca7b2dc51fa900 /nn/runtime/test/TestNeuralNetworksWrapper.h | |
parent | 907c72d4c0c483f06824c91b6a706ac0d9634f22 (diff) | |
download | ml-008caeab69787b8978e874f8fb3811e2400f5d55.tar.gz |
Cleanup NNAPI runtime Memory objects
Prior to this CL, runtime Memory* objects were default constructed and
set to a value later. Rebinding the value led to multiple bugs in the
past and made the Memory* objects prone to data races. This CL addresses
these issues by determining all values in the Memory* objects' factory
methods, and making the Memory* objects immutable after construction.
This CL also untangles some of the inheritance hierarchy. Prior to this
CL, all MemoryFd and MemoryAHWB inherited from Memory, which was
effectively ashmem-based memory. This CL separates MemoryAshmem into its
own, unique class type, and has a common base Memory class. This
reorganization also uncovered that getPointer was only used in the
runtime on ashmem-based memory, and removes the unused getPointer
methods.
Finally, this CL improves documentation of NeuralNetworks.h in two ways:
1) Fixes the typo "ANeuralNetworksMemory_createFromAHardwarBuffer"
2) Documents the missing lifetime constraints of
ANeuralNetworksMemory_createFromAHardwareBuffer
Fixes: 138852228
Fixes: 69632863
Fixes: 69633035
Fixes: 129572123
Fixes: 132323765
Fixes: 139213289
Test: mma
Test: atest NeuralNetworksTest_static
Test: atest CtsNNAPITestCases
Change-Id: I49a2356d6b8fc38e501d8e37ba8a8893f3e91395
Diffstat (limited to 'nn/runtime/test/TestNeuralNetworksWrapper.h')
-rw-r--r-- | nn/runtime/test/TestNeuralNetworksWrapper.h | 158 |
1 files changed, 156 insertions, 2 deletions
diff --git a/nn/runtime/test/TestNeuralNetworksWrapper.h b/nn/runtime/test/TestNeuralNetworksWrapper.h index e37817058..8bc8a9e7c 100644 --- a/nn/runtime/test/TestNeuralNetworksWrapper.h +++ b/nn/runtime/test/TestNeuralNetworksWrapper.h @@ -38,13 +38,167 @@ using wrapper::ExecutePreference; using wrapper::ExtensionModel; using wrapper::ExtensionOperandParams; using wrapper::ExtensionOperandType; -using wrapper::Memory; -using wrapper::Model; using wrapper::OperandType; using wrapper::Result; using wrapper::SymmPerChannelQuantParams; using wrapper::Type; +class Memory { + public: + // Takes ownership of a ANeuralNetworksMemory + Memory(ANeuralNetworksMemory* memory) : mMemory(memory) {} + + Memory(size_t size, int protect, int fd, size_t offset) { + mValid = ANeuralNetworksMemory_createFromFd(size, protect, fd, offset, &mMemory) == + ANEURALNETWORKS_NO_ERROR; + } + + Memory(AHardwareBuffer* buffer) { + mValid = ANeuralNetworksMemory_createFromAHardwareBuffer(buffer, &mMemory) == + ANEURALNETWORKS_NO_ERROR; + } + + ~Memory() { ANeuralNetworksMemory_free(mMemory); } + + // Disallow copy semantics to ensure the runtime object can only be freed + // once. Copy semantics could be enabled if some sort of reference counting + // or deep-copy system for runtime objects is added later. + Memory(const Memory&) = delete; + Memory& operator=(const Memory&) = delete; + + // Move semantics to remove access to the runtime object from the wrapper + // object that is being moved. This ensures the runtime object will be + // freed only once. + Memory(Memory&& other) { *this = std::move(other); } + Memory& operator=(Memory&& other) { + if (this != &other) { + ANeuralNetworksMemory_free(mMemory); + mMemory = other.mMemory; + mValid = other.mValid; + other.mMemory = nullptr; + other.mValid = false; + } + return *this; + } + + ANeuralNetworksMemory* get() const { return mMemory; } + bool isValid() const { return mValid; } + + private: + ANeuralNetworksMemory* mMemory = nullptr; + bool mValid = true; +}; + +class Model { + public: + Model() { + // TODO handle the value returned by this call + ANeuralNetworksModel_create(&mModel); + } + ~Model() { ANeuralNetworksModel_free(mModel); } + + // Disallow copy semantics to ensure the runtime object can only be freed + // once. Copy semantics could be enabled if some sort of reference counting + // or deep-copy system for runtime objects is added later. + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + + // Move semantics to remove access to the runtime object from the wrapper + // object that is being moved. This ensures the runtime object will be + // freed only once. + Model(Model&& other) { *this = std::move(other); } + Model& operator=(Model&& other) { + if (this != &other) { + ANeuralNetworksModel_free(mModel); + mModel = other.mModel; + mNextOperandId = other.mNextOperandId; + mValid = other.mValid; + other.mModel = nullptr; + other.mNextOperandId = 0; + other.mValid = false; + } + return *this; + } + + Result finish() { + if (mValid) { + auto result = static_cast<Result>(ANeuralNetworksModel_finish(mModel)); + if (result != Result::NO_ERROR) { + mValid = false; + } + return result; + } else { + return Result::BAD_STATE; + } + } + + uint32_t addOperand(const OperandType* type) { + if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != + ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + if (type->channelQuant) { + if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( + mModel, mNextOperandId, &type->channelQuant.value().params) != + ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + } + return mNextOperandId++; + } + + void setOperandValue(uint32_t index, const void* buffer, size_t length) { + if (ANeuralNetworksModel_setOperandValue(mModel, index, buffer, length) != + ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + } + + void setOperandValueFromMemory(uint32_t index, const Memory* memory, uint32_t offset, + size_t length) { + if (ANeuralNetworksModel_setOperandValueFromMemory(mModel, index, memory->get(), offset, + length) != ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + } + + void addOperation(ANeuralNetworksOperationType type, const std::vector<uint32_t>& inputs, + const std::vector<uint32_t>& outputs) { + if (ANeuralNetworksModel_addOperation(mModel, type, static_cast<uint32_t>(inputs.size()), + inputs.data(), static_cast<uint32_t>(outputs.size()), + outputs.data()) != ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + } + void identifyInputsAndOutputs(const std::vector<uint32_t>& inputs, + const std::vector<uint32_t>& outputs) { + if (ANeuralNetworksModel_identifyInputsAndOutputs( + mModel, static_cast<uint32_t>(inputs.size()), inputs.data(), + static_cast<uint32_t>(outputs.size()), + outputs.data()) != ANEURALNETWORKS_NO_ERROR) { + mValid = false; + } + } + + void relaxComputationFloat32toFloat16(bool isRelax) { + if (ANeuralNetworksModel_relaxComputationFloat32toFloat16(mModel, isRelax) == + ANEURALNETWORKS_NO_ERROR) { + mRelaxed = isRelax; + } + } + + ANeuralNetworksModel* getHandle() const { return mModel; } + bool isValid() const { return mValid; } + bool isRelaxed() const { return mRelaxed; } + + protected: + ANeuralNetworksModel* mModel = nullptr; + // We keep track of the operand ID as a convenience to the caller. + uint32_t mNextOperandId = 0; + bool mValid = true; + bool mRelaxed = false; +}; + class Compilation { public: Compilation(const Model* model) { |