diff options
author | Michael Butler <butlermichael@google.com> | 2020-11-05 17:11:22 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2020-11-05 17:11:22 +0000 |
commit | 67d7e29ad3430e716a530799b3fb23ead678601e (patch) | |
tree | 7453b5e0852551022c3ba140674e60661facb938 | |
parent | 9d8007c5342dcdd324af06af8ca18f6ea2974af1 (diff) | |
parent | 2bd2cae2be76017277a7aae6b5fe4bc7fc1ff67c (diff) | |
download | ml-67d7e29ad3430e716a530799b3fb23ead678601e.tar.gz |
Merge "Update NNAPI canonical validation"
-rw-r--r-- | nn/common/IndexedShapeWrapper.cpp | 4 | ||||
-rw-r--r-- | nn/common/TypeUtils.cpp | 16 | ||||
-rw-r--r-- | nn/common/Validation.cpp | 49 | ||||
-rw-r--r-- | nn/common/include/CpuOperationUtils.h | 3 | ||||
-rw-r--r-- | nn/common/include/OperationResolver.h | 14 | ||||
-rw-r--r-- | nn/common/include/OperationsUtils.h | 3 | ||||
-rw-r--r-- | nn/common/include/Utils.h | 90 | ||||
-rw-r--r-- | nn/common/include/ValidateHal.h | 10 | ||||
-rw-r--r-- | nn/common/include/nnapi/TypeUtils.h | 92 | ||||
-rw-r--r-- | nn/common/operations/Broadcast.cpp | 2 | ||||
-rw-r--r-- | nn/common/operations/Reshape.cpp | 8 |
11 files changed, 157 insertions, 134 deletions
diff --git a/nn/common/IndexedShapeWrapper.cpp b/nn/common/IndexedShapeWrapper.cpp index e90665986..8101c016f 100644 --- a/nn/common/IndexedShapeWrapper.cpp +++ b/nn/common/IndexedShapeWrapper.cpp @@ -18,6 +18,10 @@ #include "IndexedShapeWrapper.h" +#include <vector> + +#include "Utils.h" + namespace android { namespace nn { diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp index f5bbe07f0..ad17d9444 100644 --- a/nn/common/TypeUtils.cpp +++ b/nn/common/TypeUtils.cpp @@ -816,6 +816,22 @@ std::ostream& operator<<(std::ostream& os, const Version& version) { return os << "Version{" << underlyingType(version) << "}"; } +std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) { + switch (halVersion) { + case HalVersion::UNKNOWN: + return os << "UNKNOWN HAL version"; + case HalVersion::V1_0: + return os << "HAL version 1.0"; + case HalVersion::V1_1: + return os << "HAL version 1.1"; + case HalVersion::V1_2: + return os << "HAL version 1.2"; + case HalVersion::V1_3: + return os << "HAL version 1.3"; + } + return os << "HalVersion{" << underlyingType(halVersion) << "}"; +} + bool operator==(const Timing& a, const Timing& b) { return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver; } diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp index 9127bec7c..68f778d22 100644 --- a/nn/common/Validation.cpp +++ b/nn/common/Validation.cpp @@ -35,6 +35,7 @@ #include "ControlFlow.h" #include "OperandTypes.h" +#include "OperationResolver.h" #include "OperationTypes.h" #include "Result.h" #include "TypeUtils.h" @@ -1174,17 +1175,13 @@ Result<Version> validateMemoryDescImpl( return Version::ANDROID_R; } -// TODO: Enable this block of code once canonical types are integrated in the rest of the NNAPI -// codebase. -#if 0 class OperationValidationContext : public IOperationValidationContext { DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext); public: - OperationValidationContext(const char* operationName, const std::vector<uint32_t>& - inputIndexes, + OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes, const std::vector<uint32_t>& outputIndexes, - const std::vector<Operand>& operands, Version version) + const std::vector<Operand>& operands, HalVersion version) : operationName(operationName), inputIndexes(inputIndexes), outputIndexes(outputIndexes), @@ -1192,12 +1189,12 @@ class OperationValidationContext : public IOperationValidationContext { version(version) {} const char* getOperationName() const override; - Version getVersion() const override; + HalVersion getHalVersion() const override; uint32_t getNumInputs() const override; OperandType getInputType(uint32_t index) const override; Shape getInputShape(uint32_t index) const override; - const Operand::ExtraParams getInputExtraParams(uint32_t index) const override; + const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override; uint32_t getNumOutputs() const override; OperandType getOutputType(uint32_t index) const override; @@ -1211,14 +1208,14 @@ class OperationValidationContext : public IOperationValidationContext { const std::vector<uint32_t>& inputIndexes; const std::vector<uint32_t>& outputIndexes; const std::vector<Operand>& operands; - Version version; + HalVersion version; }; const char* OperationValidationContext::getOperationName() const { return operationName; } -Version OperationValidationContext::getVersion() const { +HalVersion OperationValidationContext::getHalVersion() const { return version; } @@ -1252,8 +1249,7 @@ Shape OperationValidationContext::getInputShape(uint32_t index) const { operand->extraParams}; } -const Operand::ExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const -{ +const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const { return getInputOperand(index)->extraParams; } @@ -1266,7 +1262,6 @@ Shape OperationValidationContext::getOutputShape(uint32_t index) const { return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint, operand->extraParams}; } -#endif // TODO(b/169345292): reduce the duplicate validation here @@ -2517,9 +2512,6 @@ Result<Version> validateOperationImpl(const Operation& operation, return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs); } default: { - // TODO: Enable this block of code once canonical types are integrated in the rest of - // the NNAPI codebase. -#if 0 const OperationRegistration* operationRegistration = BuiltinOperationResolver::get()->findOperation( static_cast<OperationType>(opType)); @@ -2528,16 +2520,23 @@ Result<Version> validateOperationImpl(const Operation& operation, // TODO: return ErrorStatus::UNEXPECTED_NULL NN_VALIDATE(operationRegistration->validate != nullptr) << "Incomplete operation registration: " << opType; - OperationValidationContext context(operationRegistration->name, inputIndexes, - outputIndexes, operands); - auto result = operationRegistration->validate(&context); - if (!result.has_value()) { - return NN_ERROR() << "Validation failed for operation " << opType << ": " - << std::move(result).error(); + + constexpr HalVersion kHalVersions[] = {HalVersion::V1_0, HalVersion::V1_1, + HalVersion::V1_2, HalVersion::V1_3}; + constexpr Version kVersions[] = {Version::ANDROID_OC_MR1, Version::ANDROID_P, + Version::ANDROID_Q, Version::ANDROID_R}; + static_assert(std::size(kHalVersions) == std::size(kVersions)); + + for (size_t i = 0; i < std::size(kHalVersions); ++i) { + OperationValidationContext context(operationRegistration->name, inputIndexes, + outputIndexes, operands, kHalVersions[i]); + auto valid = operationRegistration->validate(&context); + if (valid) { + return kVersions[i]; + } } - return result; -#endif - NN_VALIDATE_FAIL() << "Validation for " << opType << " is not yet implemented"; + + return NN_ERROR() << "Validation failed for operation " << opType; } } } diff --git a/nn/common/include/CpuOperationUtils.h b/nn/common/include/CpuOperationUtils.h index 879952932..ff58ff11c 100644 --- a/nn/common/include/CpuOperationUtils.h +++ b/nn/common/include/CpuOperationUtils.h @@ -17,6 +17,7 @@ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_OPERATION_UTILS_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_OPERATION_UTILS_H +#include <android-base/logging.h> #include <tensorflow/lite/kernels/internal/types.h> #include <algorithm> @@ -32,7 +33,7 @@ namespace nn { // The implementations in tflite/kernels/internal/ take a Dims<4> object // even if the original tensors were not 4D. inline tflite::Dims<4> convertShapeToDims(const Shape& shape) { - nnAssert(shape.dimensions.size() <= 4); + CHECK_LE(shape.dimensions.size(), 4u); tflite::Dims<4> dims; // The dimensions are reversed in Dims<4>. diff --git a/nn/common/include/OperationResolver.h b/nn/common/include/OperationResolver.h index 700513d13..d2c066cd3 100644 --- a/nn/common/include/OperationResolver.h +++ b/nn/common/include/OperationResolver.h @@ -17,7 +17,10 @@ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H -#include "HalInterfaces.h" +#include <android-base/macros.h> + +#include <utility> + #include "OperationsUtils.h" namespace android { @@ -53,9 +56,9 @@ struct OperationRegistration { std::function<bool(IOperationExecutionContext*)> execute, Flag flags) : type(type), name(name), - validate(validate), - prepare(prepare), - execute(execute), + validate(std::move(validate)), + prepare(std::move(prepare)), + execute(std::move(execute)), flags(flags) {} }; @@ -88,6 +91,9 @@ class BuiltinOperationResolver : public IOperationResolver { const OperationRegistration* findOperation(OperationType operationType) const override; + // The number of operation types (OperationCode) defined in NeuralNetworks.h. + static constexpr int kNumberOfOperationTypes = 102; + private: BuiltinOperationResolver(); diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h index 9b0a9bdaa..492583f24 100644 --- a/nn/common/include/OperationsUtils.h +++ b/nn/common/include/OperationsUtils.h @@ -21,8 +21,7 @@ #include <cstdint> #include <vector> -#include "HalInterfaces.h" -#include "Utils.h" +#include "nnapi/TypeUtils.h" #include "nnapi/Types.h" namespace android { diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h index 1d4c6811c..cdaf91172 100644 --- a/nn/common/include/Utils.h +++ b/nn/common/include/Utils.h @@ -27,6 +27,7 @@ #include "HalInterfaces.h" #include "NeuralNetworks.h" +#include "OperationResolver.h" #include "ValidateHal.h" #include "nnapi/TypeUtils.h" #include "nnapi/Types.h" @@ -39,6 +40,7 @@ const int kNumberOfDataTypes = 16; // The number of operation types (OperationCode) defined in NeuralNetworks.h. const int kNumberOfOperationTypes = 102; +static_assert(kNumberOfOperationTypes == BuiltinOperationResolver::kNumberOfOperationTypes); // The number of execution preferences defined in NeuralNetworks.h. const int kNumberOfPreferences = 3; @@ -86,57 +88,6 @@ void initVLogMask(); } \ } while (0) -// The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in -// system/libbase/include/android-base/logging.h -// -// The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL) -// and return false instead of aborting. - -// Logs an error and returns false. Append context using << after. For example: -// -// NN_RET_CHECK_FAIL() << "Something went wrong"; -// -// The containing function must return a bool. -#define NN_RET_CHECK_FAIL() \ - return ::android::nn::FalseyErrorStream() \ - << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): " - -// Logs an error and returns false if condition is false. Extra logging can be appended using << -// after. For example: -// -// NN_RET_CHECK(false) << "Something went wrong"; -// -// The containing function must return a bool. -#define NN_RET_CHECK(condition) \ - while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " " - -// Helper for NN_CHECK_xx(x, y) macros. -#define NN_RET_CHECK_OP(LHS, RHS, OP) \ - for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \ - UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \ - /* empty */) \ - NN_RET_CHECK_FAIL() \ - << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \ - << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \ - << ", " << #RHS << " = " \ - << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \ - << ") " - -// Logs an error and returns false if a condition between x and y does not hold. Extra logging can -// be appended using << after. For example: -// -// NN_RET_CHECK_EQ(a, b) << "Something went wrong"; -// -// The values must implement the appropriate comparison operator as well as -// `operator<<(std::ostream&, ...)`. -// The containing function must return a bool. -#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==) -#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=) -#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=) -#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <) -#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=) -#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >) - // Make an TimeoutDuration from a duration in nanoseconds. If the value exceeds // the max duration, return the maximum expressible duration. TimeoutDuration makeTimeoutDuration(uint64_t nanoseconds); @@ -180,28 +131,6 @@ OptionalTimePoint makeTimePoint(const std::optional<Deadline>& deadline); // correct instance, using the correct LOG_TAG namespace { -// A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false). -// Used to implement stream logging in NN_RET_CHECK. -class FalseyErrorStream { - DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream); - - public: - FalseyErrorStream() {} - - template <typename T> - FalseyErrorStream& operator<<(const T& value) { - mBuffer << value; - return *this; - } - - ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); } - - operator bool() const { return false; } - - private: - std::ostringstream mBuffer; -}; - template <HalVersion version> struct VersionedType {}; @@ -373,21 +302,6 @@ std::string toString(const std::pair<A, B>& pair) { return oss.str(); } -inline std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) { - switch (halVersion) { - case HalVersion::UNKNOWN: - return os << "UNKNOWN HAL version"; - case HalVersion::V1_0: - return os << "HAL version 1.0"; - case HalVersion::V1_1: - return os << "HAL version 1.1"; - case HalVersion::V1_2: - return os << "HAL version 1.2"; - case HalVersion::V1_3: - return os << "HAL version 1.3"; - } -} - inline bool validCode(uint32_t codeCount, uint32_t codeCountOEM, uint32_t code) { return (code < codeCount) || (code >= kOEMCodeBase && (code - kOEMCodeBase) < codeCountOEM); } diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h index c501fc011..57ba0792d 100644 --- a/nn/common/include/ValidateHal.h +++ b/nn/common/include/ValidateHal.h @@ -21,19 +21,11 @@ #include <tuple> #include "HalInterfaces.h" +#include "nnapi/TypeUtils.h" namespace android { namespace nn { -enum class HalVersion : int32_t { - UNKNOWN, - V1_0, - V1_1, - V1_2, - V1_3, - LATEST = V1_3, -}; - enum class IOType { INPUT, OUTPUT }; using PreparedModelRole = std::tuple<const V1_3::IPreparedModel*, IOType, uint32_t>; diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h index 77761c40b..9dc67cf46 100644 --- a/nn/common/include/nnapi/TypeUtils.h +++ b/nn/common/include/nnapi/TypeUtils.h @@ -17,6 +17,9 @@ #ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H #define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H +#include <android-base/logging.h> +#include <android-base/macros.h> + #include <ostream> #include <utility> #include <vector> @@ -28,6 +31,15 @@ namespace android::nn { +enum class HalVersion : int32_t { + UNKNOWN, + V1_0, + V1_1, + V1_2, + V1_3, + LATEST = V1_3, +}; + bool isExtension(OperandType type); bool isExtension(OperationType type); @@ -93,6 +105,7 @@ std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTime std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration); std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration); std::ostream& operator<<(std::ostream& os, const Version& version); +std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion); bool operator==(const Timing& a, const Timing& b); bool operator!=(const Timing& a, const Timing& b); @@ -112,6 +125,85 @@ bool operator!=(const Operand& a, const Operand& b); bool operator==(const Operation& a, const Operation& b); bool operator!=(const Operation& a, const Operation& b); +// The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in +// system/libbase/include/android-base/logging.h +// +// The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL) +// and return false instead of aborting. + +// Logs an error and returns false. Append context using << after. For example: +// +// NN_RET_CHECK_FAIL() << "Something went wrong"; +// +// The containing function must return a bool. +#define NN_RET_CHECK_FAIL() \ + return ::android::nn::FalseyErrorStream() \ + << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): " + +// Logs an error and returns false if condition is false. Extra logging can be appended using << +// after. For example: +// +// NN_RET_CHECK(false) << "Something went wrong"; +// +// The containing function must return a bool. +#define NN_RET_CHECK(condition) \ + while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " " + +// Helper for NN_CHECK_xx(x, y) macros. +#define NN_RET_CHECK_OP(LHS, RHS, OP) \ + for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \ + UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \ + /* empty */) \ + NN_RET_CHECK_FAIL() \ + << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \ + << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \ + << ", " << #RHS << " = " \ + << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \ + << ") " + +// Logs an error and returns false if a condition between x and y does not hold. Extra logging can +// be appended using << after. For example: +// +// NN_RET_CHECK_EQ(a, b) << "Something went wrong"; +// +// The values must implement the appropriate comparison operator as well as +// `operator<<(std::ostream&, ...)`. +// The containing function must return a bool. +#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==) +#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=) +#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=) +#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <) +#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=) +#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >) + +// Ensure that every user of FalseyErrorStream is linked to the +// correct instance, using the correct LOG_TAG +namespace { + +// A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false). +// Used to implement stream logging in NN_RET_CHECK. +class FalseyErrorStream { + DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream); + + public: + FalseyErrorStream() {} + + template <typename T> + FalseyErrorStream& operator<<(const T& value) { + mBuffer << value; + return *this; + } + + ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); } + + operator bool() const { return false; } + + private: + std::ostringstream mBuffer; +}; + +} // namespace + } // namespace android::nn #endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp index 67bb914bd..cf75f3a87 100644 --- a/nn/common/operations/Broadcast.cpp +++ b/nn/common/operations/Broadcast.cpp @@ -206,7 +206,7 @@ bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& sha bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData, const Shape& bShape, int32_t activation, int32_t* outputData, const Shape& outputShape, int32_t func(int32_t, int32_t)) { - NN_RET_CHECK_EQ(activation, ANEURALNETWORKS_FUSED_NONE); + NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE); IndexedShapeWrapper aShapeIndexed(aShape); IndexedShapeWrapper bShapeIndexed(bShape); IndexedShapeWrapper outputShapeIndexed(outputShape); diff --git a/nn/common/operations/Reshape.cpp b/nn/common/operations/Reshape.cpp index 48c293e7a..76effb8be 100644 --- a/nn/common/operations/Reshape.cpp +++ b/nn/common/operations/Reshape.cpp @@ -18,15 +18,15 @@ #define LOG_TAG "Operations" +#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h> +#include <tensorflow/lite/kernels/internal/reference/reference_ops.h> + #include <vector> #include "CpuOperationUtils.h" #include "Operations.h" - -#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h> -#include <tensorflow/lite/kernels/internal/reference/reference_ops.h> - #include "Tracing.h" +#include "Utils.h" namespace android { namespace nn { |