From 0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 Mon Sep 17 00:00:00 2001 From: Lev Proleev Date: Wed, 20 May 2020 21:33:12 +0100 Subject: Fix sample driver segfault in BIDIRECTIONAL_SEQUENCE_LSTM and LSTM The segfault could happen if a model provided to the sample driver contained BIDIRECTIONAL_SEQUENCE_LSTM or LSTM with no inputs. The CL moves input and output count checks to the beginning of validation logic of these operations. Fix: 156306557 Test: NNTest_static + NNAPI_BSLstmFailure from ag/11514800 Change-Id: Ic7b0d8bd4ca954f03cfe2a4d2deca9b1e0022cee --- nn/common/Utils.cpp | 45 +++++++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 22 deletions(-) (limited to 'nn') diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index 81e5cf1e1..fedc8cb30 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -1083,6 +1083,20 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes); } case ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM: { + const uint32_t kNumOutputs = 2; + const uint32_t kNumOutputsMerged = 1; + const uint32_t kNumOutputsWithState = 6; + const uint32_t kNumOutputsMergedWithState = 5; + if (inputCount != 61 || + (outputCount != kNumOutputs && outputCount != kNumOutputsMerged && + outputCount != kNumOutputsWithState && + outputCount != kNumOutputsMergedWithState)) { + LOG(ERROR) << "Invalid number of input operands (" << inputCount + << ", expected 61) or output operands (" << outputCount + << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType); + return ANEURALNETWORKS_BAD_DATA; + } + std::vector inExpectedTypes; auto inputType = operands[inputIndexes[0]].type; if (inputType != OperandType::TENSOR_FLOAT32 && @@ -1109,20 +1123,6 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, inExpectedTypes.push_back(inputType); } - const uint32_t kNumOutputs = 2; - const uint32_t kNumOutputsMerged = 1; - const uint32_t kNumOutputsWithState = 6; - const uint32_t kNumOutputsMergedWithState = 5; - - if (inputCount != 61 || - (outputCount != kNumOutputs && outputCount != kNumOutputsMerged && - outputCount != kNumOutputsWithState && - outputCount != kNumOutputsMergedWithState)) { - LOG(ERROR) << "Invalid number of input operands (" << inputCount - << ", expected 61) or output operands (" << outputCount - << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType); - return ANEURALNETWORKS_BAD_DATA; - } HalVersion minSupportedHalVersion = HalVersion::V1_2; if (outputCount == kNumOutputsWithState || outputCount == kNumOutputsMergedWithState) { minSupportedHalVersion = HalVersion::V1_3; @@ -1135,6 +1135,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, return status; } case ANEURALNETWORKS_LSTM: { + if ((inputCount != 23 && inputCount != 27) || outputCount != 4) { + LOG(ERROR) << "Invalid number of input operands (" << inputCount + << ", expected 23 or 27) or output operands (" << outputCount + << ", expected 4) for operation " << getOperationName(opType); + return ANEURALNETWORKS_BAD_DATA; + } std::vector inExpectedTypes; std::vector outExpectedTypes; auto inputType = operands[inputIndexes[0]].type; @@ -1160,18 +1166,13 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } outExpectedTypes = {inputType, inputType, inputType, inputType}; - if (inputCount == 23 && outputCount == 4) { + if (inputCount == 23) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - } else if (inputCount == 27 && outputCount == 4) { + } else { + NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); for (int i = 0; i < 4; ++i) { inExpectedTypes.push_back(inputType); } - NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); - } else { - LOG(ERROR) << "Invalid number of input operands (" << inputCount - << ", expected 23 or 27) or output operands (" << outputCount - << ", expected 4) for operation " << getOperationName(opType); - return ANEURALNETWORKS_BAD_DATA; } return validateOperationOperandTypes(operands, inputCount, inputIndexes, inExpectedTypes, outputCount, outputIndexes, -- cgit v1.2.3