summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Butler <butlermichael@google.com>2020-10-13 17:21:48 +0000
committerGerrit Code Review <noreply-gerritcodereview@google.com>2020-10-13 17:21:48 +0000
commit7dfee7e7b712f0e002d1dac7781123df846e057a (patch)
treee45762b95a88fdd6255a0cdc133f6fe868f92d38
parentae59a9ab3650e942a0bed7e79cbb8d9562591a3e (diff)
parentfb920b988804df61db94446c43c470c27d8c768b (diff)
downloadml-7dfee7e7b712f0e002d1dac7781123df846e057a.tar.gz
Merge "Add FusedActivationFunc to canonical types"
-rw-r--r--nn/common/TypeUtils.cpp14
-rw-r--r--nn/common/Validation.cpp19
-rw-r--r--nn/common/include/nnapi/TypeUtils.h1
-rw-r--r--nn/common/include/nnapi/Types.h7
-rw-r--r--nn/common/include/nnapi/Validation.h1
5 files changed, 40 insertions, 2 deletions
diff --git a/nn/common/TypeUtils.cpp b/nn/common/TypeUtils.cpp
index ffe2912c3..15b377b87 100644
--- a/nn/common/TypeUtils.cpp
+++ b/nn/common/TypeUtils.cpp
@@ -584,6 +584,20 @@ std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus) {
return os << "ErrorStatus{" << underlyingType(errorStatus) << "}";
}
+std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation) {
+ switch (activation) {
+ case FusedActivationFunc::NONE:
+ return os << "NONE";
+ case FusedActivationFunc::RELU:
+ return os << "RELU";
+ case FusedActivationFunc::RELU1:
+ return os << "RELU1";
+ case FusedActivationFunc::RELU6:
+ return os << "RELU6";
+ }
+ return os << "FusedActivationFunc{" << underlyingType(activation) << "}";
+}
+
std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape) {
return os << "OutputShape{.dimensions=" << outputShape.dimensions
<< ", .isSufficient=" << (outputShape.isSufficient ? "true" : "false") << "}";
diff --git a/nn/common/Validation.cpp b/nn/common/Validation.cpp
index d18ba3443..17ba63be9 100644
--- a/nn/common/Validation.cpp
+++ b/nn/common/Validation.cpp
@@ -243,6 +243,17 @@ Result<Version> validateErrorStatus(const ErrorStatus& errorStatus) {
NN_VALIDATE_FAIL() << "Invalid ErrorStatus " << errorStatus;
}
+Result<Version> validateFusedActivationFunc(const FusedActivationFunc& activation) {
+ switch (activation) {
+ case FusedActivationFunc::NONE:
+ case FusedActivationFunc::RELU:
+ case FusedActivationFunc::RELU1:
+ case FusedActivationFunc::RELU6:
+ return Version::ANDROID_OC_MR1;
+ }
+ NN_VALIDATE_FAIL() << "Invalid FusedActivationFunc " << activation;
+}
+
Result<Version> validateOutputShape(const OutputShape& /*outputShape*/) {
return Version::ANDROID_Q;
}
@@ -431,7 +442,7 @@ Result<Version> validateOperandDimensions(const Operand& operand) {
<< " but dimensions of rank 0";
const auto size = getNonExtensionSize(operand);
NN_VALIDATE(size.has_value()) << "Tensor dimensions overflow";
- NN_VALIDATE_EQ(size.value(), 0u) << "Tensor has at least one unknown dimension";
+ NN_VALIDATE_NE(size.value(), 0u) << "Tensor has at least one unknown dimension";
}
// TODO(b/165152547): aren't NO_VALUE arguments allowed to be .empty() even before
// Android Q?
@@ -1085,7 +1096,7 @@ Result<Version> validateRequestForModelImpl(const Request& request, const Model&
request.inputs, model.main.inputIndexes,
model.main.operands, /*isOutput=*/false)));
version = combineVersions(version, NN_TRY(validateRequestArgumentsForModel(
- request.inputs, model.main.inputIndexes,
+ request.outputs, model.main.outputIndexes,
model.main.operands, /*isOutput=*/true)));
return version;
}
@@ -2561,6 +2572,10 @@ Result<Version> validate(const ErrorStatus& errorStatus) {
return validateErrorStatus(errorStatus);
}
+Result<Version> validate(const FusedActivationFunc& activation) {
+ return validateFusedActivationFunc(activation);
+}
+
Result<Version> validate(const OutputShape& outputShape) {
return validateOutputShape(outputShape);
}
diff --git a/nn/common/include/nnapi/TypeUtils.h b/nn/common/include/nnapi/TypeUtils.h
index 30f1346f4..77761c40b 100644
--- a/nn/common/include/nnapi/TypeUtils.h
+++ b/nn/common/include/nnapi/TypeUtils.h
@@ -58,6 +58,7 @@ std::ostream& operator<<(std::ostream& os, const OperationType& operationType);
std::ostream& operator<<(std::ostream& os, const Request::Argument::LifeTime& lifetime);
std::ostream& operator<<(std::ostream& os, const Priority& priority);
std::ostream& operator<<(std::ostream& os, const ErrorStatus& errorStatus);
+std::ostream& operator<<(std::ostream& os, const FusedActivationFunc& activation);
std::ostream& operator<<(std::ostream& os, const OutputShape& outputShape);
std::ostream& operator<<(std::ostream& os, const Timing& timing);
std::ostream& operator<<(std::ostream& os, const Capabilities::PerformanceInfo& performanceInfo);
diff --git a/nn/common/include/nnapi/Types.h b/nn/common/include/nnapi/Types.h
index c0718dae1..371b4a5c2 100644
--- a/nn/common/include/nnapi/Types.h
+++ b/nn/common/include/nnapi/Types.h
@@ -116,6 +116,13 @@ enum class ErrorStatus {
DEAD_OBJECT = 10000,
};
+enum class FusedActivationFunc : int32_t {
+ NONE = 0,
+ RELU = 1,
+ RELU1 = 2,
+ RELU6 = 3,
+};
+
using Dimension = uint32_t;
using Dimensions = std::vector<Dimension>;
diff --git a/nn/common/include/nnapi/Validation.h b/nn/common/include/nnapi/Validation.h
index 61a2a273c..09133957b 100644
--- a/nn/common/include/nnapi/Validation.h
+++ b/nn/common/include/nnapi/Validation.h
@@ -37,6 +37,7 @@ Result<Version> validate(const DeviceType& deviceType);
Result<Version> validate(const MeasureTiming& measureTiming);
Result<Version> validate(const Priority& priority);
Result<Version> validate(const ErrorStatus& errorStatus);
+Result<Version> validate(const FusedActivationFunc& activation);
Result<Version> validate(const OutputShape& outputShape);
Result<Version> validate(const Timing& timing);
Result<Version> validate(const Capabilities& capabilities);