author     Michael Butler <butlermichael@google.com>  2020-05-27 23:06:31 -0700
committer  Michael Butler <butlermichael@google.com>  2020-06-09 02:46:13 +0000
commit     dde5b4dd1fbb4c75ee6a962dffd3fa996d95958b (patch)
tree       7d407d20825e1c9600424d21544b47b62470edf5 /nn/common/operations/LSTM.cpp
parent     66e5923200afc965bf19b880737e9180e9f5c909 (diff)
Verify non-optional tensors have values in CpuExecutor
This change adds additional validation for non-optional tensors for the
following operations:
* EMBEDDING_LOOKUP
* HASHTABLE_LOOKUP
* LSH_PROJECTION
* BIDIRECTIONAL_SEQUENCE_LSTM
* LSTM
* RANDOM_MULTINOMIAL
* RNN
* SVDF
* SPLIT

Some operations such as SVDF unpack the scalar values without checking
if the value is present, leading to a failed CHECK. This CL adds
protections to use default values in these cases, and relies on a
corresponding Prepare method to cause these cases to fail validation.

Bug: 157516274
Test: mma
Test: CtsNNAPITestCases
Test: NeuralNetworksTest_static
Test: libneuralnetworks_fuzzer
Change-Id: I6bb804ec40205c9741b04231022894c714ad28ec
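The getScalarDataWithDefault helper that the new LSTM.cpp code calls is
introduced outside this file (the diff below is limited to LSTM.cpp), so
its definition is not shown here. A minimal sketch, assuming it lives
alongside getScalarData in CpuExecutor.h and uses the same
RunTimeOperandInfo buffer/length fields that the diff itself reads:

// Sketch only -- the actual definition is outside this diff. Returns
// the operand's scalar value when its buffer is large enough to hold
// one, and the caller-supplied default otherwise, so the LSTMCell
// constructor no longer hits a fatal CHECK on an omitted operand;
// Prepare() is then responsible for failing validation instead.
template <typename T>
T getScalarDataWithDefault(const RunTimeOperandInfo& info, T defaultValue) {
    if (info.length < sizeof(T)) {
        return defaultValue;  // operand omitted or undersized
    }
    return *reinterpret_cast<T*>(info.buffer);
}

This split is why the Prepare hunk below gains matching length checks:
the constructor silently falls back to the default value, and Prepare
rejects the model.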
Diffstat (limited to 'nn/common/operations/LSTM.cpp')
-rw-r--r--  nn/common/operations/LSTM.cpp  54
1 file changed, 46 insertions, 8 deletions
diff --git a/nn/common/operations/LSTM.cpp b/nn/common/operations/LSTM.cpp
index 6020353a3..ba5d46a71 100644
--- a/nn/common/operations/LSTM.cpp
+++ b/nn/common/operations/LSTM.cpp
@@ -83,16 +83,20 @@ LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
- params_.activation = static_cast<TfLiteFusedActivation>(
- getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
+ const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
+ params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
+ activationOperand, TfLiteFusedActivation::kTfLiteActNone));
+
+ const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
+ const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
if (input_->type == OperandType::TENSOR_FLOAT32) {
- params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
- params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
+ params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
+ params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
} else {
- params_.cell_clip = static_cast<float>(
- getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
- params_.proj_clip = static_cast<float>(
- getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
+ params_.cell_clip =
+ static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
+ params_.proj_clip =
+ static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
}
// We check the version of LSTM by checking the number of the inputs to the
@@ -302,8 +306,42 @@ bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
// Check we have all the inputs and outputs we need.
NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
NumInputsWithValues(operation, operands) <= 27);
+ constexpr int requiredInputs[] = {
+ kInputTensor,
+ kInputToForgetWeightsTensor,
+ kInputToCellWeightsTensor,
+ kInputToOutputWeightsTensor,
+ kRecurrentToForgetWeightsTensor,
+ kRecurrentToCellWeightsTensor,
+ kRecurrentToOutputWeightsTensor,
+ kForgetGateBiasTensor,
+ kCellGateBiasTensor,
+ kOutputGateBiasTensor,
+ kOutputStateInTensor,
+ kCellStateInTensor,
+ kActivationParam,
+ kCellClipParam,
+ kProjClipParam,
+ };
+ for (const int requiredInput : requiredInputs) {
+ NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
+ << "required input " << requiredInput << " is omitted";
+ }
NN_CHECK_EQ(NumOutputs(operation), 4);
+ // Check that the scalar operands' buffers are large enough.
+ const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
+ NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
+ const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
+ const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
+ if (input_->type == OperandType::TENSOR_FLOAT32) {
+ NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
+ NN_RET_CHECK(projClipOperand.length >= sizeof(float));
+ } else {
+ NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
+ NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
+ }
+
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
NN_CHECK(NumDimensions(input_) > 1);
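The IsNullInput helper used by the requiredInputs loop above is also
defined outside this hunk. A minimal sketch, assuming it follows the
TFLite-derived helper style already present near the top of LSTM.cpp,
reduces to a lifetime check:

// Sketch, under the assumption stated above: an input is "null" when
// the operand was declared with no value at model-build time, which is
// exactly the omitted-operand case the new requiredInputs loop rejects
// with "required input ... is omitted".
static inline bool IsNullInput(const RunTimeOperandInfo* input) {
    return input->lifetime == OperandLifeTime::NO_VALUE;
}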