diff options
author | Michael Butler <butlermichael@google.com> | 2020-10-11 16:39:57 -0700 |
---|---|---|
committer | Michael Butler <butlermichael@google.com> | 2020-10-11 16:39:57 -0700 |
commit | fb920b988804df61db94446c43c470c27d8c768b (patch) | |
tree | e45762b95a88fdd6255a0cdc133f6fe868f92d38 | |
parent | ae59a9ab3650e942a0bed7e79cbb8d9562591a3e (diff) | |
download | ml-fb920b988804df61db94446c43c470c27d8c768b.tar.gz |
Add FusedActivationFunc to canonical types
This CL also makes two minor fixes to the validation code.
Bug: N/A
Test: mma
Change-Id: I81b73f1236f6a43e7c249f8c2ca9144b5291b993
-rw-r--r-- | nn/common/TypeUtils.cpp | 14 | ||||
-rw-r--r-- | nn/common/Validation.cpp | 19 | ||||
-rw-r--r-- | nn/common/include/nnapi/TypeUtils.h | 1 | ||||
-rw-r--r-- | nn/common/include/nnapi/Types.h | 7 | ||||
-rw-r--r-- | nn/common/include/nnapi/Validation.h | 1 |
5 files changed, 40 insertions, 2 deletions
diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp index ffe2912c3..15b377b87 100644 --- a/nn/common/TypeUtils.cpp +++ b/nn/common/TypeUtils.cpp @@ -584,6 +584,20 @@ std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) { return os << "ErrorStatus{" << underlyingType(errorStatus) << "}"; } +std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation) { + switch (activation) { + case FusedActivationFunc::NONE: + return os << "NONE"; + case FusedActivationFunc::RELU: + return os << "RELU"; + case FusedActivationFunc::RELU1: + return os << "RELU1"; + case FusedActivationFunc::RELU6: + return os << "RELU6"; + } + return os << "FusedActivationFunc{" << underlyingType(activation) << "}"; +} + std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) { return os << "OutputShape{.dimensions=" << outputShape.dimensions << ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}"; diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp index d18ba3443..17ba63be9 100644 --- a/nn/common/Validation.cpp +++ b/nn/common/Validation.cpp @@ -243,6 +243,17 @@ Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) { NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus; } +Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) { + switch (activation) { + case FusedActivationFunc::NONE: + case FusedActivationFunc::RELU: + case FusedActivationFunc::RELU1: + case FusedActivationFunc::RELU6: + return Version::ANDROID_OC_MR1; + } + NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation; +} + Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) { return Version::ANDROID_Q; } @@ -431,7 +442,7 @@ Result<Version> validateOperandDimensions(const Operand& operand) { << " but dimensions of rank 0"; const auto size = getNonExtensionSize(operand); NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow"; - NN_VALIDATE_EQ(size.value(), 0u) << "Tensor has at least one unknown dimension"; + NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension"; } // TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before // Android Q? @@ -1085,7 +1096,7 @@ Result<Version> validateRequestForModelImpl(const Request& request, const Model& request.inputs, model.main.inputIndexes, model.main.operands, /*isOutput=*/false))); version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel( - request.inputs, model.main.inputIndexes, + request.outputs, model.main.outputIndexes, model.main.operands, /*isOutput=*/true))); return version; } @@ -2561,6 +2572,10 @@ Result<Version> validate(const ErrorStatus& errorStatus) { return validateErrorStatus(errorStatus); } +Result<Version> validate(const FusedActivationFunc& activation) { + return validateFusedActivationFunc(activation); +} + Result<Version> validate(const OutputShape& outputShape) { return validateOutputShape(outputShape); } diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h index 30f1346f4..77761c40b 100644 --- a/nn/common/include/nnapi/TypeUtils.h +++ b/nn/common/include/nnapi/TypeUtils.h @@ -58,6 +58,7 @@ std::ostream& operator<<(std::ostream& os, const OperationType& operationType); std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime); std::ostream& operator<<(std::ostream& os, const Priority& priority); std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus); +std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation); std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape); std::ostream& operator<<(std::ostream& os, const Timing& timing); std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo); diff --git a/nn/common/include/nnapi/Types.h b/nn/common/include/nnapi/Types.h index c0718dae1..371b4a5c2 100644 --- a/nn/common/include/nnapi/Types.h +++ b/nn/common/include/nnapi/Types.h @@ -116,6 +116,13 @@ enum class ErrorStatus { DEAD_OBJECT = 10000, }; +enum class FusedActivationFunc : int32_t { + NONE = 0, + RELU = 1, + RELU1 = 2, + RELU6 = 3, +}; + using Dimension = uint32_t; using Dimensions = std::vector<Dimension>; diff --git a/nn/common/include/nnapi/Validation.h b/nn/common/include/nnapi/Validation.h index 61a2a273c..09133957b 100644 --- a/nn/common/include/nnapi/Validation.h +++ b/nn/common/include/nnapi/Validation.h @@ -37,6 +37,7 @@ Result<Version> validate(const DeviceType& deviceType); Result<Version> validate(const MeasureTiming& measureTiming); Result<Version> validate(const Priority& priority); Result<Version> validate(const ErrorStatus& errorStatus); +Result<Version> validate(const FusedActivationFunc& activation); Result<Version> validate(const OutputShape& outputShape); Result<Version> validate(const Timing& timing); Result<Version> validate(const Capabilities& capabilities); |