diff options
Diffstat (limited to 'nn/common/operations/LSTM.cpp')
-rw-r--r-- | nn/common/operations/LSTM.cpp | 54 |
1 files changed, 46 insertions, 8 deletions
diff --git a/nn/common/operations/LSTM.cpp b/nn/common/operations/LSTM.cpp index 6020353a3..ba5d46a71 100644 --- a/nn/common/operations/LSTM.cpp +++ b/nn/common/operations/LSTM.cpp @@ -83,16 +83,20 @@ LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) { output_state_in_ = GetInput(operation, operands, kOutputStateInTensor); cell_state_in_ = GetInput(operation, operands, kCellStateInTensor); - params_.activation = static_cast<TfLiteFusedActivation>( - getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam))); + const auto& activationOperand = *GetInput(operation, operands, kActivationParam); + params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>( + activationOperand, TfLiteFusedActivation::kTfLiteActNone)); + + const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam); + const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam); if (input_->type == OperandType::TENSOR_FLOAT32) { - params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam)); - params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam)); + params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f); + params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f); } else { - params_.cell_clip = static_cast<float>( - getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam))); - params_.proj_clip = static_cast<float>( - getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam))); + params_.cell_clip = + static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f)); + params_.proj_clip = + static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f)); } // We check the version of LSTM by checking the number of the inputs to the @@ -302,8 +306,42 @@ bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands, // Check we have all the inputs and outputs we need. NN_CHECK(NumInputsWithValues(operation, operands) >= 15 && NumInputsWithValues(operation, operands) <= 27); + constexpr int requiredInputs[] = { + kInputTensor, + kInputToForgetWeightsTensor, + kInputToCellWeightsTensor, + kInputToOutputWeightsTensor, + kRecurrentToForgetWeightsTensor, + kRecurrentToCellWeightsTensor, + kRecurrentToOutputWeightsTensor, + kForgetGateBiasTensor, + kCellGateBiasTensor, + kOutputGateBiasTensor, + kOutputStateInTensor, + kCellStateInTensor, + kActivationParam, + kCellClipParam, + kProjClipParam, + }; + for (const int requiredInput : requiredInputs) { + NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput))) + << "required input " << requiredInput << " is omitted"; + } NN_CHECK_EQ(NumOutputs(operation), 4); + // Check that the scalar operands' buffers are large enough. + const auto& activationOperand = *GetInput(operation, operands, kActivationParam); + NN_RET_CHECK(activationOperand.length >= sizeof(int32_t)); + const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam); + const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam); + if (input_->type == OperandType::TENSOR_FLOAT32) { + NN_RET_CHECK(cellClipOperand.length >= sizeof(float)); + NN_RET_CHECK(projClipOperand.length >= sizeof(float)); + } else { + NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16)); + NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16)); + } + // Inferring batch size, number of outputs and number of cells from the // input tensors. NN_CHECK(NumDimensions(input_) > 1); |