/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #define LOG_TAG "FibonacciDriver" #include "FibonacciDriver.h" #include #include "FibonacciExtension.h" #include "HalInterfaces.h" #include "NeuralNetworksExtensions.h" #include "OperationResolver.h" #include "OperationsUtils.h" #include "Utils.h" #include "ValidateHal.h" namespace android { namespace nn { namespace sample_driver { namespace { using namespace hal; const uint8_t kLowBitsType = static_cast(ExtensionTypeEncoding::LOW_BITS_TYPE); const uint32_t kTypeWithinExtensionMask = (1 << kLowBitsType) - 1; namespace fibonacci_op { constexpr char kOperationName[] = "EXAMPLE_FIBONACCI"; constexpr uint32_t kNumInputs = 1; constexpr uint32_t kInputN = 0; constexpr uint32_t kNumOutputs = 1; constexpr uint32_t kOutputTensor = 0; bool getFibonacciExtensionPrefix(const Model& model, uint16_t* prefix) { NN_RET_CHECK_EQ(model.extensionNameToPrefix.size(), 1u); // Assumes no other extensions in use. NN_RET_CHECK_EQ(model.extensionNameToPrefix[0].name, EXAMPLE_FIBONACCI_EXTENSION_NAME); *prefix = model.extensionNameToPrefix[0].prefix; return true; } bool isFibonacciOperation(const Operation& operation, const Model& model) { int32_t operationType = static_cast(operation.type); uint16_t prefix; NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix)); NN_RET_CHECK_EQ(operationType, (prefix << kLowBitsType) | EXAMPLE_FIBONACCI); return true; } bool validate(const Operation& operation, const Model& model) { NN_RET_CHECK(isFibonacciOperation(operation, model)); NN_RET_CHECK_EQ(operation.inputs.size(), kNumInputs); NN_RET_CHECK_EQ(operation.outputs.size(), kNumOutputs); int32_t inputType = static_cast(model.main.operands[operation.inputs[0]].type); int32_t outputType = static_cast(model.main.operands[operation.outputs[0]].type); uint16_t prefix; NN_RET_CHECK(getFibonacciExtensionPrefix(model, &prefix)); NN_RET_CHECK(inputType == ((prefix << kLowBitsType) | EXAMPLE_INT64) || inputType == ANEURALNETWORKS_TENSOR_FLOAT32); NN_RET_CHECK(outputType == ((prefix << kLowBitsType) | EXAMPLE_TENSOR_QUANT64_ASYMM) || outputType == ANEURALNETWORKS_TENSOR_FLOAT32); return true; } bool prepare(IOperationExecutionContext* context) { int64_t n; if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) { n = static_cast(context->getInputValue(kInputN)); } else { n = context->getInputValue(kInputN); } NN_RET_CHECK_GE(n, 1); Shape output = context->getOutputShape(kOutputTensor); output.dimensions = {static_cast(n)}; return context->setOutputShape(kOutputTensor, output); } template bool compute(int32_t n, ScaleT outputScale, ZeroPointT outputZeroPoint, OutputT* output) { // Compute the Fibonacci numbers. if (n >= 1) { output[0] = 1; } if (n >= 2) { output[1] = 1; } if (n >= 3) { for (int32_t i = 2; i < n; ++i) { output[i] = output[i - 1] + output[i - 2]; } } // Quantize output. for (int32_t i = 0; i < n; ++i) { output[i] = output[i] / outputScale + outputZeroPoint; } return true; } bool execute(IOperationExecutionContext* context) { int64_t n; if (context->getInputType(kInputN) == OperandType::TENSOR_FLOAT32) { n = static_cast(context->getInputValue(kInputN)); } else { n = context->getInputValue(kInputN); } if (context->getOutputType(kOutputTensor) == OperandType::TENSOR_FLOAT32) { float* output = context->getOutputBuffer(kOutputTensor); return compute(n, /*scale=*/1.0, /*zeroPoint=*/0, output); } else { uint64_t* output = context->getOutputBuffer(kOutputTensor); Shape outputShape = context->getOutputShape(kOutputTensor); auto outputQuant = reinterpret_cast( outputShape.extraParams.extension().data()); return compute(n, outputQuant->scale, outputQuant->zeroPoint, output); } } } // namespace fibonacci_op } // namespace const OperationRegistration* FibonacciOperationResolver::findOperation( OperationType operationType) const { // .validate is omitted because it's not used by the extension driver. static OperationRegistration operationRegistration(operationType, fibonacci_op::kOperationName, nullptr, fibonacci_op::prepare, fibonacci_op::execute, {}); uint16_t prefix = static_cast(operationType) >> kLowBitsType; uint16_t typeWithinExtension = static_cast(operationType) & kTypeWithinExtensionMask; // Assumes no other extensions in use. return prefix != 0 && typeWithinExtension == EXAMPLE_FIBONACCI ? &operationRegistration : nullptr; } Return FibonacciDriver::getSupportedExtensions(getSupportedExtensions_cb cb) { cb(V1_0::ErrorStatus::NONE, { { .name = EXAMPLE_FIBONACCI_EXTENSION_NAME, .operandTypes = { { .type = EXAMPLE_INT64, .isTensor = false, .byteSize = 8, }, { .type = EXAMPLE_TENSOR_QUANT64_ASYMM, .isTensor = true, .byteSize = 8, }, }, }, }); return Void(); } Return FibonacciDriver::getCapabilities_1_3(getCapabilities_1_3_cb cb) { android::nn::initVLogMask(); VLOG(DRIVER) << "getCapabilities()"; static const PerformanceInfo kPerf = {.execTime = 1.0f, .powerUsage = 1.0f}; Capabilities capabilities = { .relaxedFloat32toFloat16PerformanceScalar = kPerf, .relaxedFloat32toFloat16PerformanceTensor = kPerf, .operandPerformance = nonExtensionOperandPerformance(kPerf), .ifPerformance = kPerf, .whilePerformance = kPerf}; cb(V1_3::ErrorStatus::NONE, capabilities); return Void(); } Return FibonacciDriver::getSupportedOperations_1_3(const V1_3::Model& model, getSupportedOperations_1_3_cb cb) { VLOG(DRIVER) << "getSupportedOperations()"; if (!validateModel(model)) { cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {}); return Void(); } const size_t count = model.main.operations.size(); std::vector supported(count); for (size_t i = 0; i < count; ++i) { const Operation& operation = model.main.operations[i]; if (fibonacci_op::isFibonacciOperation(operation, model)) { if (!fibonacci_op::validate(operation, model)) { cb(V1_3::ErrorStatus::INVALID_ARGUMENT, {}); return Void(); } supported[i] = true; } } cb(V1_3::ErrorStatus::NONE, supported); return Void(); } } // namespace sample_driver } // namespace nn } // namespace android