diff options
Diffstat (limited to 'nn/common/operations/Activation.cpp')
-rw-r--r-- | nn/common/operations/Activation.cpp | 43 |
1 files changed, 32 insertions, 11 deletions
diff --git a/nn/common/operations/Activation.cpp b/nn/common/operations/Activation.cpp index 491226e60..f85f6b4bf 100644 --- a/nn/common/operations/Activation.cpp +++ b/nn/common/operations/Activation.cpp @@ -226,11 +226,28 @@ bool validate(OperationType opType, const IOperationValidationContext* context) return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType}); } -bool prepare(IOperationExecutionContext* context) { +bool prepare(OperationType opType, IOperationExecutionContext* context) { Shape input = context->getInputShape(kInputTensor); NN_RET_CHECK_LE(getNumberOfDimensions(input), 4); - Shape output = context->getOutputShape(kOutputTensor); - output.dimensions = input.dimensions; + Shape output = input; + if (input.type == OperandType::TENSOR_QUANT8_ASYMM) { + switch (opType) { + case OperationType::RELU: + case OperationType::RELU1: + case OperationType::RELU6: + break; + case OperationType::LOGISTIC: + output.scale = 1.f / 256; + output.offset = 0; + break; + case OperationType::TANH: + output.scale = 1.f / 128; + output.offset = 128; + break; + default: + NN_RET_CHECK_FAIL() << "Unsupported operation type"; + } + } return context->setOutputShape(kOutputTensor, output); } @@ -326,7 +343,7 @@ bool executeLogistic(IOperationExecutionContext* context) { context->getOutputBuffer<uint8_t>(kOutputTensor), context->getOutputShape(kOutputTensor)); default: - NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; + NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; } } @@ -350,7 +367,7 @@ bool executeTanh(IOperationExecutionContext* context) { context->getOutputBuffer<uint8_t>(kOutputTensor), context->getOutputShape(kOutputTensor)); default: - NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation LOGISTIC"; + NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation TANH"; } } @@ -358,17 +375,21 @@ bool executeTanh(IOperationExecutionContext* context) { using std::placeholders::_1; NN_REGISTER_OPERATION(RELU, "RELU", std::bind(activation::validate, OperationType::RELU, _1), - activation::prepare, activation::executeRelu, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU, _1), + activation::executeRelu, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU1, "RELU1", std::bind(activation::validate, OperationType::RELU1, _1), - activation::prepare, activation::executeRelu1, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU1, _1), + activation::executeRelu1, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(RELU6, "RELU6", std::bind(activation::validate, OperationType::RELU6, _1), - activation::prepare, activation::executeRelu6, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::RELU6, _1), + activation::executeRelu6, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(LOGISTIC, "LOGISTIC", std::bind(activation::validate, OperationType::LOGISTIC, _1), - activation::prepare, activation::executeLogistic, - .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::LOGISTIC, _1), + activation::executeLogistic, .allowZeroSizedInput = true); NN_REGISTER_OPERATION(TANH, "TANH", std::bind(activation::validate, OperationType::TANH, _1), - activation::prepare, activation::executeTanh, .allowZeroSizedInput = true); + std::bind(activation::prepare, OperationType::TANH, _1), + activation::executeTanh, .allowZeroSizedInput = true); } // namespace nn } // namespace android |