summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2020-11-05 17:11:22 +0000
committerGerrit Code Review <noreply-gerritcodereview@google.com>2020-11-05 17:11:22 +0000
commit67d7e29ad3430e716a530799b3fb23ead678601e (patch)
tree7453b5e0852551022c3ba140674e60661facb938
parent9d8007c5342dcdd324af06af8ca18f6ea2974af1 (diff)
parent2bd2cae2be76017277a7aae6b5fe4bc7fc1ff67c (diff)
downloadml-67d7e29ad3430e716a530799b3fb23ead678601e.tar.gz
Merge "Update NNAPI canonical validation"
-rw-r--r--nn/common/IndexedShapeWrapper.cpp4
-rw-r--r--nn/common/TypeUtils.cpp16
-rw-r--r--nn/common/Validation.cpp49
-rw-r--r--nn/common/include/CpuOperationUtils.h3
-rw-r--r--nn/common/include/OperationResolver.h14
-rw-r--r--nn/common/include/OperationsUtils.h3
-rw-r--r--nn/common/include/Utils.h90
-rw-r--r--nn/common/include/ValidateHal.h10
-rw-r--r--nn/common/include/nnapi/TypeUtils.h92
-rw-r--r--nn/common/operations/Broadcast.cpp2
-rw-r--r--nn/common/operations/Reshape.cpp8
11 files changed, 157 insertions, 134 deletions
diff --git a/nn/common/IndexedShapeWrapper.cpp b/nn/common/IndexedShapeWrapper.cpp
index e90665986..8101c016f 100644
--- a/nn/common/IndexedShapeWrapper.cpp
+++ b/nn/common/IndexedShapeWrapper.cpp
@@ -18,6 +18,10 @@
#include "IndexedShapeWrapper.h"
+#include <vector>
+
+#include "Utils.h"
+
namespace android {
namespace nn {
diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp
index f5bbe07f0..ad17d9444 100644
--- a/nn/common/TypeUtils.cpp
+++ b/nn/common/TypeUtils.cpp
@@ -816,6 +816,22 @@ std::ostream& operator<<(std::ostream& os, const Version& version) {
return os << "Version{" << underlyingType(version) << "}";
}
+std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) {
+ switch (halVersion) {
+ case HalVersion::UNKNOWN:
+ return os << "UNKNOWN HAL version";
+ case HalVersion::V1_0:
+ return os << "HAL version 1.0";
+ case HalVersion::V1_1:
+ return os << "HAL version 1.1";
+ case HalVersion::V1_2:
+ return os << "HAL version 1.2";
+ case HalVersion::V1_3:
+ return os << "HAL version 1.3";
+ }
+ return os << "HalVersion{" << underlyingType(halVersion) << "}";
+}
+
bool operator==(const Timing& a, const Timing& b) {
return a.timeOnDevice == b.timeOnDevice && a.timeInDriver == b.timeInDriver;
}
diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp
index 9127bec7c..68f778d22 100644
--- a/nn/common/Validation.cpp
+++ b/nn/common/Validation.cpp
@@ -35,6 +35,7 @@
#include "ControlFlow.h"
#include "OperandTypes.h"
+#include "OperationResolver.h"
#include "OperationTypes.h"
#include "Result.h"
#include "TypeUtils.h"
@@ -1174,17 +1175,13 @@ Result<Version> validateMemoryDescImpl(
return Version::ANDROID_R;
}
-// TODO: Enable this block of code once canonical types are integrated in the rest of the NNAPI
-// codebase.
-#if 0
class OperationValidationContext : public IOperationValidationContext {
DISALLOW_IMPLICIT_CONSTRUCTORS(OperationValidationContext);
public:
- OperationValidationContext(const char* operationName, const std::vector<uint32_t>&
- inputIndexes,
+ OperationValidationContext(const char* operationName, const std::vector<uint32_t>& inputIndexes,
const std::vector<uint32_t>& outputIndexes,
- const std::vector<Operand>& operands, Version version)
+ const std::vector<Operand>& operands, HalVersion version)
: operationName(operationName),
inputIndexes(inputIndexes),
outputIndexes(outputIndexes),
@@ -1192,12 +1189,12 @@ class OperationValidationContext : public IOperationValidationContext {
version(version) {}
const char* getOperationName() const override;
- Version getVersion() const override;
+ HalVersion getHalVersion() const override;
uint32_t getNumInputs() const override;
OperandType getInputType(uint32_t index) const override;
Shape getInputShape(uint32_t index) const override;
- const Operand::ExtraParams getInputExtraParams(uint32_t index) const override;
+ const Operand::ExtraParams& getInputExtraParams(uint32_t index) const override;
uint32_t getNumOutputs() const override;
OperandType getOutputType(uint32_t index) const override;
@@ -1211,14 +1208,14 @@ class OperationValidationContext : public IOperationValidationContext {
const std::vector<uint32_t>& inputIndexes;
const std::vector<uint32_t>& outputIndexes;
const std::vector<Operand>& operands;
- Version version;
+ HalVersion version;
};
const char* OperationValidationContext::getOperationName() const {
return operationName;
}
-Version OperationValidationContext::getVersion() const {
+HalVersion OperationValidationContext::getHalVersion() const {
return version;
}
@@ -1252,8 +1249,7 @@ Shape OperationValidationContext::getInputShape(uint32_t index) const {
operand->extraParams};
}
-const Operand::ExtraParams OperationValidationContext::getInputExtraParams(uint32_t index) const
-{
+const Operand::ExtraParams& OperationValidationContext::getInputExtraParams(uint32_t index) const {
return getInputOperand(index)->extraParams;
}
@@ -1266,7 +1262,6 @@ Shape OperationValidationContext::getOutputShape(uint32_t index) const {
return {operand->type, operand->dimensions, operand->scale, operand->zeroPoint,
operand->extraParams};
}
-#endif
// TODO(b/169345292): reduce the duplicate validation here
@@ -2517,9 +2512,6 @@ Result<Version> validateOperationImpl(const Operation& operation,
return validateWhileOperation(inputIndexes, outputIndexes, operands, subgraphs);
}
default: {
- // TODO: Enable this block of code once canonical types are integrated in the rest of
- // the NNAPI codebase.
-#if 0
const OperationRegistration* operationRegistration =
BuiltinOperationResolver::get()->findOperation(
static_cast<OperationType>(opType));
@@ -2528,16 +2520,23 @@ Result<Version> validateOperationImpl(const Operation& operation,
// TODO: return ErrorStatus::UNEXPECTED_NULL
NN_VALIDATE(operationRegistration->validate != nullptr)
<< "Incomplete operation registration: " << opType;
- OperationValidationContext context(operationRegistration->name, inputIndexes,
- outputIndexes, operands);
- auto result = operationRegistration->validate(&context);
- if (!result.has_value()) {
- return NN_ERROR() << "Validation failed for operation " << opType << ": "
- << std::move(result).error();
+
+ constexpr HalVersion kHalVersions[] = {HalVersion::V1_0, HalVersion::V1_1,
+ HalVersion::V1_2, HalVersion::V1_3};
+ constexpr Version kVersions[] = {Version::ANDROID_OC_MR1, Version::ANDROID_P,
+ Version::ANDROID_Q, Version::ANDROID_R};
+ static_assert(std::size(kHalVersions) == std::size(kVersions));
+
+ for (size_t i = 0; i < std::size(kHalVersions); ++i) {
+ OperationValidationContext context(operationRegistration->name, inputIndexes,
+ outputIndexes, operands, kHalVersions[i]);
+ auto valid = operationRegistration->validate(&context);
+ if (valid) {
+ return kVersions[i];
+ }
}
- return result;
-#endif
- NN_VALIDATE_FAIL() << "Validation for " << opType << " is not yet implemented";
+
+ return NN_ERROR() << "Validation failed for operation " << opType;
}
}
}
diff --git a/nn/common/include/CpuOperationUtils.h b/nn/common/include/CpuOperationUtils.h
index 879952932..ff58ff11c 100644
--- a/nn/common/include/CpuOperationUtils.h
+++ b/nn/common/include/CpuOperationUtils.h
@@ -17,6 +17,7 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_OPERATION_UTILS_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_CPU_OPERATION_UTILS_H
+#include <android-base/logging.h>
#include <tensorflow/lite/kernels/internal/types.h>
#include <algorithm>
@@ -32,7 +33,7 @@ namespace nn {
// The implementations in tflite/kernels/internal/ take a Dims<4> object
// even if the original tensors were not 4D.
inline tflite::Dims<4> convertShapeToDims(const Shape& shape) {
- nnAssert(shape.dimensions.size() <= 4);
+ CHECK_LE(shape.dimensions.size(), 4u);
tflite::Dims<4> dims;
// The dimensions are reversed in Dims<4>.
diff --git a/nn/common/include/OperationResolver.h b/nn/common/include/OperationResolver.h
index 700513d13..d2c066cd3 100644
--- a/nn/common/include/OperationResolver.h
+++ b/nn/common/include/OperationResolver.h
@@ -17,7 +17,10 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_OPERATION_RESOLVER_H
-#include "HalInterfaces.h"
+#include <android-base/macros.h>
+
+#include <utility>
+
#include "OperationsUtils.h"
namespace android {
@@ -53,9 +56,9 @@ struct OperationRegistration {
std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
: type(type),
name(name),
- validate(validate),
- prepare(prepare),
- execute(execute),
+ validate(std::move(validate)),
+ prepare(std::move(prepare)),
+ execute(std::move(execute)),
flags(flags) {}
};
@@ -88,6 +91,9 @@ class BuiltinOperationResolver : public IOperationResolver {
const OperationRegistration* findOperation(OperationType operationType) const override;
+ // The number of operation types (OperationCode) defined in NeuralNetworks.h.
+ static constexpr int kNumberOfOperationTypes = 102;
+
private:
BuiltinOperationResolver();
diff --git a/nn/common/include/OperationsUtils.h b/nn/common/include/OperationsUtils.h
index 9b0a9bdaa..492583f24 100644
--- a/nn/common/include/OperationsUtils.h
+++ b/nn/common/include/OperationsUtils.h
@@ -21,8 +21,7 @@
#include <cstdint>
#include <vector>
-#include "HalInterfaces.h"
-#include "Utils.h"
+#include "nnapi/TypeUtils.h"
#include "nnapi/Types.h"
namespace android {
diff --git a/nn/common/include/Utils.h b/nn/common/include/Utils.h
index 1d4c6811c..cdaf91172 100644
--- a/nn/common/include/Utils.h
+++ b/nn/common/include/Utils.h
@@ -27,6 +27,7 @@
#include "HalInterfaces.h"
#include "NeuralNetworks.h"
+#include "OperationResolver.h"
#include "ValidateHal.h"
#include "nnapi/TypeUtils.h"
#include "nnapi/Types.h"
@@ -39,6 +40,7 @@ const int kNumberOfDataTypes = 16;
// The number of operation types (OperationCode) defined in NeuralNetworks.h.
const int kNumberOfOperationTypes = 102;
+static_assert(kNumberOfOperationTypes == BuiltinOperationResolver::kNumberOfOperationTypes);
// The number of execution preferences defined in NeuralNetworks.h.
const int kNumberOfPreferences = 3;
@@ -86,57 +88,6 @@ void initVLogMask();
} \
} while (0)
-// The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
-// system/libbase/include/android-base/logging.h
-//
-// The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL)
-// and return false instead of aborting.
-
-// Logs an error and returns false. Append context using << after. For example:
-//
-// NN_RET_CHECK_FAIL() << "Something went wrong";
-//
-// The containing function must return a bool.
-#define NN_RET_CHECK_FAIL() \
- return ::android::nn::FalseyErrorStream() \
- << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
-
-// Logs an error and returns false if condition is false. Extra logging can be appended using <<
-// after. For example:
-//
-// NN_RET_CHECK(false) << "Something went wrong";
-//
-// The containing function must return a bool.
-#define NN_RET_CHECK(condition) \
- while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
-
-// Helper for NN_CHECK_xx(x, y) macros.
-#define NN_RET_CHECK_OP(LHS, RHS, OP) \
- for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \
- UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \
- /* empty */) \
- NN_RET_CHECK_FAIL() \
- << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \
- << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
- << ", " << #RHS << " = " \
- << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
- << ") "
-
-// Logs an error and returns false if a condition between x and y does not hold. Extra logging can
-// be appended using << after. For example:
-//
-// NN_RET_CHECK_EQ(a, b) << "Something went wrong";
-//
-// The values must implement the appropriate comparison operator as well as
-// `operator<<(std::ostream&, ...)`.
-// The containing function must return a bool.
-#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
-#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
-#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
-#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
-#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
-#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
-
// Make an TimeoutDuration from a duration in nanoseconds. If the value exceeds
// the max duration, return the maximum expressible duration.
TimeoutDuration makeTimeoutDuration(uint64_t nanoseconds);
@@ -180,28 +131,6 @@ OptionalTimePoint makeTimePoint(const std::optional<Deadline>& deadline);
// correct instance, using the correct LOG_TAG
namespace {
-// A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false).
-// Used to implement stream logging in NN_RET_CHECK.
-class FalseyErrorStream {
- DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream);
-
- public:
- FalseyErrorStream() {}
-
- template <typename T>
- FalseyErrorStream& operator<<(const T& value) {
- mBuffer << value;
- return *this;
- }
-
- ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); }
-
- operator bool() const { return false; }
-
- private:
- std::ostringstream mBuffer;
-};
-
template <HalVersion version>
struct VersionedType {};
@@ -373,21 +302,6 @@ std::string toString(const std::pair<A, B>& pair) {
return oss.str();
}
-inline std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion) {
- switch (halVersion) {
- case HalVersion::UNKNOWN:
- return os << "UNKNOWN HAL version";
- case HalVersion::V1_0:
- return os << "HAL version 1.0";
- case HalVersion::V1_1:
- return os << "HAL version 1.1";
- case HalVersion::V1_2:
- return os << "HAL version 1.2";
- case HalVersion::V1_3:
- return os << "HAL version 1.3";
- }
-}
-
inline bool validCode(uint32_t codeCount, uint32_t codeCountOEM, uint32_t code) {
return (code < codeCount) || (code >= kOEMCodeBase && (code - kOEMCodeBase) < codeCountOEM);
}
diff --git a/nn/common/include/ValidateHal.h b/nn/common/include/ValidateHal.h
index c501fc011..57ba0792d 100644
--- a/nn/common/include/ValidateHal.h
+++ b/nn/common/include/ValidateHal.h
@@ -21,19 +21,11 @@
#include <tuple>
#include "HalInterfaces.h"
+#include "nnapi/TypeUtils.h"
namespace android {
namespace nn {
-enum class HalVersion : int32_t {
- UNKNOWN,
- V1_0,
- V1_1,
- V1_2,
- V1_3,
- LATEST = V1_3,
-};
-
enum class IOType { INPUT, OUTPUT };
using PreparedModelRole = std::tuple<const V1_3::IPreparedModel*, IOType, uint32_t>;
diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h
index 77761c40b..9dc67cf46 100644
--- a/nn/common/include/nnapi/TypeUtils.h
+++ b/nn/common/include/nnapi/TypeUtils.h
@@ -17,6 +17,9 @@
#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
+#include <android-base/logging.h>
+#include <android-base/macros.h>
+
#include <ostream>
#include <utility>
#include <vector>
@@ -28,6 +31,15 @@
namespace android::nn {
+enum class HalVersion : int32_t {
+ UNKNOWN,
+ V1_0,
+ V1_1,
+ V1_2,
+ V1_3,
+ LATEST = V1_3,
+};
+
bool isExtension(OperandType type);
bool isExtension(OperationType type);
@@ -93,6 +105,7 @@ std::ostream& operator<<(std::ostream& os, const OptionalTimePoint& optionalTime
std::ostream& operator<<(std::ostream& os, const TimeoutDuration& timeoutDuration);
std::ostream& operator<<(std::ostream& os, const OptionalTimeoutDuration& optionalTimeoutDuration);
std::ostream& operator<<(std::ostream& os, const Version& version);
+std::ostream& operator<<(std::ostream& os, const HalVersion& halVersion);
bool operator==(const Timing& a, const Timing& b);
bool operator!=(const Timing& a, const Timing& b);
@@ -112,6 +125,85 @@ bool operator!=(const Operand& a, const Operand& b);
bool operator==(const Operation& a, const Operation& b);
bool operator!=(const Operation& a, const Operation& b);
+// The NN_RET_CHECK family of macros defined below is similar to the CHECK family defined in
+// system/libbase/include/android-base/logging.h
+//
+// The difference is that NN_RET_CHECK macros use LOG(ERROR) instead of LOG(FATAL)
+// and return false instead of aborting.
+
+// Logs an error and returns false. Append context using << after. For example:
+//
+// NN_RET_CHECK_FAIL() << "Something went wrong";
+//
+// The containing function must return a bool.
+#define NN_RET_CHECK_FAIL() \
+ return ::android::nn::FalseyErrorStream() \
+ << "NN_RET_CHECK failed (" << __FILE__ << ":" << __LINE__ << "): "
+
+// Logs an error and returns false if condition is false. Extra logging can be appended using <<
+// after. For example:
+//
+// NN_RET_CHECK(false) << "Something went wrong";
+//
+// The containing function must return a bool.
+#define NN_RET_CHECK(condition) \
+ while (UNLIKELY(!(condition))) NN_RET_CHECK_FAIL() << #condition << " "
+
+// Helper for NN_CHECK_xx(x, y) macros.
+#define NN_RET_CHECK_OP(LHS, RHS, OP) \
+ for (auto _values = ::android::base::MakeEagerEvaluator(LHS, RHS); \
+ UNLIKELY(!(_values.lhs.v OP _values.rhs.v)); \
+ /* empty */) \
+ NN_RET_CHECK_FAIL() \
+ << #LHS << " " << #OP << " " << #RHS << " (" << #LHS << " = " \
+ << ::android::base::LogNullGuard<decltype(_values.lhs.v)>::Guard(_values.lhs.v) \
+ << ", " << #RHS << " = " \
+ << ::android::base::LogNullGuard<decltype(_values.rhs.v)>::Guard(_values.rhs.v) \
+ << ") "
+
+// Logs an error and returns false if a condition between x and y does not hold. Extra logging can
+// be appended using << after. For example:
+//
+// NN_RET_CHECK_EQ(a, b) << "Something went wrong";
+//
+// The values must implement the appropriate comparison operator as well as
+// `operator<<(std::ostream&, ...)`.
+// The containing function must return a bool.
+#define NN_RET_CHECK_EQ(x, y) NN_RET_CHECK_OP(x, y, ==)
+#define NN_RET_CHECK_NE(x, y) NN_RET_CHECK_OP(x, y, !=)
+#define NN_RET_CHECK_LE(x, y) NN_RET_CHECK_OP(x, y, <=)
+#define NN_RET_CHECK_LT(x, y) NN_RET_CHECK_OP(x, y, <)
+#define NN_RET_CHECK_GE(x, y) NN_RET_CHECK_OP(x, y, >=)
+#define NN_RET_CHECK_GT(x, y) NN_RET_CHECK_OP(x, y, >)
+
+// Ensure that every user of FalseyErrorStream is linked to the
+// correct instance, using the correct LOG_TAG
+namespace {
+
+// A wrapper around LOG(ERROR) that can be implicitly converted to bool (always evaluates to false).
+// Used to implement stream logging in NN_RET_CHECK.
+class FalseyErrorStream {
+ DISALLOW_COPY_AND_ASSIGN(FalseyErrorStream);
+
+ public:
+ FalseyErrorStream() {}
+
+ template <typename T>
+ FalseyErrorStream& operator<<(const T& value) {
+ mBuffer << value;
+ return *this;
+ }
+
+ ~FalseyErrorStream() { LOG(ERROR) << mBuffer.str(); }
+
+ operator bool() const { return false; }
+
+ private:
+ std::ostringstream mBuffer;
+};
+
+} // namespace
+
} // namespace android::nn
#endif // ANDROID_FRAMEWORKS_ML_NN_COMMON_NNAPI_TYPE_UTILS_H
diff --git a/nn/common/operations/Broadcast.cpp b/nn/common/operations/Broadcast.cpp
index 67bb914bd..cf75f3a87 100644
--- a/nn/common/operations/Broadcast.cpp
+++ b/nn/common/operations/Broadcast.cpp
@@ -206,7 +206,7 @@ bool addQuant8(const T* in1, const Shape& shape1, const T* in2, const Shape& sha
bool executeInt32(const int32_t* aData, const Shape& aShape, const int32_t* bData,
const Shape& bShape, int32_t activation, int32_t* outputData,
const Shape& outputShape, int32_t func(int32_t, int32_t)) {
- NN_RET_CHECK_EQ(activation, ANEURALNETWORKS_FUSED_NONE);
+ NN_RET_CHECK_EQ(static_cast<FusedActivationFunc>(activation), FusedActivationFunc::NONE);
IndexedShapeWrapper aShapeIndexed(aShape);
IndexedShapeWrapper bShapeIndexed(bShape);
IndexedShapeWrapper outputShapeIndexed(outputShape);
diff --git a/nn/common/operations/Reshape.cpp b/nn/common/operations/Reshape.cpp
index 48c293e7a..76effb8be 100644
--- a/nn/common/operations/Reshape.cpp
+++ b/nn/common/operations/Reshape.cpp
@@ -18,15 +18,15 @@
#define LOG_TAG "Operations"
+#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
+#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
+
#include <vector>
#include "CpuOperationUtils.h"
#include "Operations.h"
-
-#include <tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h>
-#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
-
#include "Tracing.h"
+#include "Utils.h"
namespace android {
namespace nn {