summaryrefslogtreecommitdiff
path: root/nn/common/operations/Activation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'nn/common/operations/Activation.cpp')
-rw-r--r--nn/common/operations/Activation.cpp43
1 files changed, 32 insertions, 11 deletions
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp
index 491226e60..f85f6b4bf 100644
--- a/nn/common/operations/Activation.cpp
+++ b/nn/common/operations/Activation.cpp
@@ -226,11 +226,28 @@ bool validate(OperationType opType, const IOperationValidationContext* context)
return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
}
-bool prepare(IOperationExecutionContext* context) {
+bool prepare(OperationType opType, IOperationExecutionContext* context) {
Shape input = context->getInputShape(kInputTensor);
NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
- Shape output = context->getOutputShape(kOutputTensor);
- output.dimensions = input.dimensions;
+ Shape output = input;
+ if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
+ switch (opType) {
+ case OperationType::RELU:
+ case OperationType::RELU1:
+ case OperationType::RELU6:
+ break;
+ case OperationType::LOGISTIC:
+ output.scale = 1.f / 256;
+ output.offset = 0;
+ break;
+ case OperationType::TANH:
+ output.scale = 1.f / 128;
+ output.offset = 128;
+ break;
+ default:
+ NN_RET_CHECK_FAIL() << "Unsupported operation type";
+ }
+ }
return context->setOutputShape(kOutputTensor, output);
}
@@ -326,7 +343,7 @@ bool executeLogistic(IOperationExecutionContext* context) {
context->getOutputBuffer<uint8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
default:
- NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
+ NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
}
}
@@ -350,7 +367,7 @@ bool executeTanh(IOperationExecutionContext* context) {
context->getOutputBuffer<uint8_t>(kOutputTensor),
context->getOutputShape(kOutputTensor));
default:
- NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC";
+ NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH";
}
}
@@ -358,17 +375,21 @@ bool executeTanh(IOperationExecutionContext* context) {
using std::placeholders::_1;
NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1),
- activation::prepare, activation::executeRelu, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU, _1),
+ activation::executeRelu, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1),
- activation::prepare, activation::executeRelu1, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU1, _1),
+ activation::executeRelu1, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1),
- activation::prepare, activation::executeRelu6, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::RELU6, _1),
+ activation::executeRelu6, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC",
std::bind(activation::validate, OperationType::LOGISTIC, _1),
- activation::prepare, activation::executeLogistic,
- .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::LOGISTIC, _1),
+ activation::executeLogistic, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1),
- activation::prepare, activation::executeTanh, .allowZeroSizedInput = true);
+ std::bind(activation::prepare, OperationType::TANH, _1),
+ activation::executeTanh, .allowZeroSizedInput = true);
} // namespace nn
} // namespace android