diff options
author | Lev Proleev <levp@google.com> | 2020-05-20 21:33:12 +0100 |
---|---|---|
committer | Lev Proleev <levp@google.com> | 2020-05-20 22:10:58 +0100 |
commit | 0824d7c6d6821941bde2d1b82efb7982ff7cc8a4 (patch) | |
tree | 1cc80d6e5e71d1414d921d6ccba03ba010852e14 /nn | |
parent | 3594be18553db032e3af6dae5c6b9e310eb2f3d8 (diff) | |
download | ml-0824d7c6d6821941bde2d1b82efb7982ff7cc8a4.tar.gz |
Fix sample driver segfault in BIDIRECTIONAL_SEQUENCE_LSTM and LSTM
The segfault could happen if a model provided to the sample driver
contained BIDIRECTIONAL_SEQUENCE_LSTM or LSTM with no inputs.
The CL moves input and output count checks to the beginning of
validation logic of these operations.
Fix: 156306557
Test: NNTest_static + NNAPI_BSLstmFailure from ag/11514800
Change-Id: Ic7b0d8bd4ca954f03cfe2a4d2deca9b1e0022cee
Diffstat (limited to 'nn')
-rw-r--r-- | nn/common/Utils.cpp | 45 |
1 files changed, 23 insertions, 22 deletions
diff --git a/nn/common/Utils.cpp b/nn/common/Utils.cpp index 81e5cf1e1..fedc8cb30 100644 --- a/nn/common/Utils.cpp +++ b/nn/common/Utils.cpp @@ -1083,6 +1083,20 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, outExpectedTypes); } case ANEURALNETWORKS_BIDIRECTIONAL_SEQUENCE_LSTM: { + const uint32_t kNumOutputs = 2; + const uint32_t kNumOutputsMerged = 1; + const uint32_t kNumOutputsWithState = 6; + const uint32_t kNumOutputsMergedWithState = 5; + if (inputCount != 61 || + (outputCount != kNumOutputs && outputCount != kNumOutputsMerged && + outputCount != kNumOutputsWithState && + outputCount != kNumOutputsMergedWithState)) { + LOG(ERROR) << "Invalid number of input operands (" << inputCount + << ", expected 61) or output operands (" << outputCount + << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType); + return ANEURALNETWORKS_BAD_DATA; + } + std::vector<OperandType> inExpectedTypes; auto inputType = operands[inputIndexes[0]].type; if (inputType != OperandType::TENSOR_FLOAT32 && @@ -1109,20 +1123,6 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, inExpectedTypes.push_back(inputType); } - const uint32_t kNumOutputs = 2; - const uint32_t kNumOutputsMerged = 1; - const uint32_t kNumOutputsWithState = 6; - const uint32_t kNumOutputsMergedWithState = 5; - - if (inputCount != 61 || - (outputCount != kNumOutputs && outputCount != kNumOutputsMerged && - outputCount != kNumOutputsWithState && - outputCount != kNumOutputsMergedWithState)) { - LOG(ERROR) << "Invalid number of input operands (" << inputCount - << ", expected 61) or output operands (" << outputCount - << ", expected 1, 2, 5 or 6) for operation " << getOperationName(opType); - return ANEURALNETWORKS_BAD_DATA; - } HalVersion minSupportedHalVersion = HalVersion::V1_2; if (outputCount == kNumOutputsWithState || outputCount == kNumOutputsMergedWithState) { minSupportedHalVersion = HalVersion::V1_3; @@ -1135,6 +1135,12 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, return status; } case ANEURALNETWORKS_LSTM: { + if ((inputCount != 23 && inputCount != 27) || outputCount != 4) { + LOG(ERROR) << "Invalid number of input operands (" << inputCount + << ", expected 23 or 27) or output operands (" << outputCount + << ", expected 4) for operation " << getOperationName(opType); + return ANEURALNETWORKS_BAD_DATA; + } std::vector<OperandType> inExpectedTypes; std::vector<OperandType> outExpectedTypes; auto inputType = operands[inputIndexes[0]].type; @@ -1160,18 +1166,13 @@ int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount, } outExpectedTypes = {inputType, inputType, inputType, inputType}; - if (inputCount == 23 && outputCount == 4) { + if (inputCount == 23) { NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0)); - } else if (inputCount == 27 && outputCount == 4) { + } else { + NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); for (int i = 0; i < 4; ++i) { inExpectedTypes.push_back(inputType); } - NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2)); - } else { - LOG(ERROR) << "Invalid number of input operands (" << inputCount - << ", expected 23 or 27) or output operands (" << outputCount - << ", expected 4) for operation " << getOperationName(opType); - return ANEURALNETWORKS_BAD_DATA; } return validateOperationOperandTypes(operands, inputCount, inputIndexes, inExpectedTypes, outputCount, outputIndexes, |