summaryrefslogtreecommitdiff
path: root/nn/runtime/Manager.cpp
diff options
context:
space:
mode:
authorSlava Shklyaev <slavash@google.com>2019-01-23 16:09:51 +0000
committerSlava Shklyaev <slavash@google.com>2019-01-28 22:11:11 +0000
commite86a07b8843c8c2c4daba891841fe34aa867bc58 (patch)
tree3faa35f793b712696e7b142014c805e26662fbae /nn/runtime/Manager.cpp
parentb1a600bb8377b903c8eed5ce7a2c1e0093e13f8f (diff)
downloadml-e86a07b8843c8c2c4daba891841fe34aa867bc58.tar.gz
Add Extensions API
Please see the commit message of change Ia9b99015eec7a48bbf969cbe503862271f09adca for motivation. Bug: 118604960 Bug: 118606929 Test: NeuralNetworksTest_static Change-Id: I2703b963f040a846889554888ddd984eac6b6c08
Diffstat (limited to 'nn/runtime/Manager.cpp')
-rw-r--r--nn/runtime/Manager.cpp80
1 files changed, 64 insertions, 16 deletions
diff --git a/nn/runtime/Manager.cpp b/nn/runtime/Manager.cpp
index 3d73465fc..e3439a5a7 100644
--- a/nn/runtime/Manager.cpp
+++ b/nn/runtime/Manager.cpp
@@ -36,6 +36,40 @@ using ::android::hardware::neuralnetworks::V1_2::implementation::PreparedModelCa
namespace android {
namespace nn {
+uint32_t Device::getSizeOfData(const Operand& operand,
+ const std::map<std::string, uint16_t>& extensionNameToPrefix) const {
+ if (!isExtensionOperandType(operand.type)) {
+ return sizeOfData(operand);
+ }
+
+ // A slow naive implementation.
+ // TODO(b/123178734): Speed it up.
+ uint32_t operandType = static_cast<uint32_t>(operand.type);
+ uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+ uint16_t prefix = operandType >> kLowBitsType;
+ uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
+ for (const Extension& extension : getSupportedExtensions()) {
+ if (extensionNameToPrefix.at(extension.name) != prefix) {
+ continue;
+ }
+ for (auto& extensionOperandType : extension.operandTypes) {
+ if (extensionOperandType.type == typeWithinExtension) {
+ uint32_t numElements = 1;
+ if (extensionOperandType.isTensor) {
+ for (auto dimension : operand.dimensions) {
+ numElements *= dimension;
+ }
+ }
+ return numElements * extensionOperandType.byteSize;
+ }
+ }
+ }
+
+ CHECK(false) << "Cannot determine the size of extension operand type "
+ << toString(operand.type);
+ return 0;
+}
+
// A Device with actual underlying driver
class DriverDevice : public Device {
DISALLOW_IMPLICIT_CONSTRUCTORS(DriverDevice);
@@ -51,7 +85,9 @@ class DriverDevice : public Device {
VersionedIDevice* getInterface() override { return &mInterface; }
int64_t getFeatureLevel() override { return mInterface.getFeatureLevel(); }
int32_t getType() const override { return mInterface.getType(); }
- void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) override;
+ hidl_vec<Extension> getSupportedExtensions() const override;
+ void getSupportedOperations(const Model& hidlModel,
+ hidl_vec<bool>* supportedOperations) override;
PerformanceInfo getFloat32Performance() const override { return mFloat32Performance; }
PerformanceInfo getQuantized8Performance() const override { return mQuantized8Performance; }
PerformanceInfo getRelaxedFloat32toFloat16Performance() const override {
@@ -68,6 +104,7 @@ class DriverDevice : public Device {
PerformanceInfo mFloat32Performance;
PerformanceInfo mQuantized8Performance;
PerformanceInfo mRelaxedFloat32toFloat16Performance;
+ hidl_vec<Extension> mSupportedExtensions;
#ifdef NN_DEBUGGABLE
// For debugging: behavior of IDevice::getSupportedOperations for SampleDriver.
@@ -90,12 +127,14 @@ bool DriverDevice::initialize() {
: 0;
#endif // NN_DEBUGGABLE
+ bool success = true;
ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
+
Capabilities capabilities;
std::tie(status, capabilities) = mInterface.getCapabilities();
-
if (status != ErrorStatus::NONE) {
LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status);
+ success = false;
} else {
VLOG(MANAGER) << "Capab " << capabilities.float32Performance.execTime;
VLOG(MANAGER) << "Capab " << capabilities.quantized8Performance.execTime;
@@ -105,14 +144,24 @@ bool DriverDevice::initialize() {
mRelaxedFloat32toFloat16Performance = capabilities.relaxedFloat32toFloat16Performance;
}
- auto result = mInterface.getVersionString();
+ std::tie(status, mVersionString) = mInterface.getVersionString();
// TODO(miaowang): add a validation test case for in case of error.
- if (result.first != ErrorStatus::NONE) {
+ if (status != ErrorStatus::NONE) {
LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status);
- } else {
- mVersionString = result.second;
+ success = false;
}
- return status == ErrorStatus::NONE;
+
+ std::tie(status, mSupportedExtensions) = mInterface.getSupportedExtensions();
+ if (status != ErrorStatus::NONE) {
+ LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status);
+ success = false;
+ }
+
+ return success;
+}
+
+hidl_vec<Extension> DriverDevice::getSupportedExtensions() const {
+ return mSupportedExtensions;
}
void DriverDevice::getSupportedOperations(const Model& hidlModel,
@@ -232,7 +281,9 @@ class CpuDevice : public Device {
VersionedIDevice* getInterface() override { return nullptr; }
int64_t getFeatureLevel() override { return kFeatureLevel; }
int32_t getType() const override { return ANEURALNETWORKS_DEVICE_CPU; }
- void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) override;
+ hidl_vec<Extension> getSupportedExtensions() const override { return {/* No extensions. */}; }
+ void getSupportedOperations(const Model& hidlModel,
+ hidl_vec<bool>* supportedOperations) override;
PerformanceInfo getFloat32Performance() const override { return kPerformance; }
PerformanceInfo getQuantized8Performance() const override { return kPerformance; }
PerformanceInfo getRelaxedFloat32toFloat16Performance() const override { return kPerformance; }
@@ -250,19 +301,16 @@ class CpuDevice : public Device {
const PerformanceInfo kPerformance = {.execTime = 1.0f, .powerUsage = 1.0f};
};
-void CpuDevice::getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) {
+void CpuDevice::getSupportedOperations(const Model& hidlModel,
+ hidl_vec<bool>* supportedOperations) {
const size_t count = hidlModel.operations.size();
- hidl_vec<bool> supportedOperations(count);
+ hidl_vec<bool> result(count);
for (size_t i = 0; i < count; i++) {
// TODO(b/119870033): Decide whether and how post-P operations would be supported on CPU.
// CPU fallback should support all the operations except for OEM_OPERATION
- if (hidlModel.operations[i].type == OperationType::OEM_OPERATION) {
- supportedOperations[i] = false;
- } else {
- supportedOperations[i] = true;
- }
+ result[i] = hidlModel.operations[i].type != OperationType::OEM_OPERATION;
}
- *supported = std::move(supportedOperations);
+ *supportedOperations = std::move(result);
}
int CpuDevice::prepareModel(const Model& hidlModel, ExecutionPreference executionPreference,