diff options
Diffstat (limited to 'nn/common/operations/BidirectionalSequenceLSTM.cpp')
-rw-r--r-- | nn/common/operations/BidirectionalSequenceLSTM.cpp | 78 |
1 files changed, 68 insertions, 10 deletions
diff --git a/nn/common/operations/BidirectionalSequenceLSTM.cpp b/nn/common/operations/BidirectionalSequenceLSTM.cpp index d4d32b964..12ac43f20 100644 --- a/nn/common/operations/BidirectionalSequenceLSTM.cpp +++ b/nn/common/operations/BidirectionalSequenceLSTM.cpp @@ -169,19 +169,24 @@ BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation, bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor); bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor); - params_.activation = static_cast<TfLiteFusedActivation>( - getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam))); + const auto& activationOperand = *GetInput(operation, operands, kActivationParam); + params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>( + activationOperand, TfLiteFusedActivation::kTfLiteActNone)); + const auto& clipOperand = *GetInput(operation, operands, kCellClipParam); + const auto& projOperand = *GetInput(operation, operands, kProjClipParam); if (input_->type == OperandType::TENSOR_FLOAT32) { - params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam)); - params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam)); + params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f); + params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f); } else { - params_.cell_clip = static_cast<float>( - getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam))); - params_.proj_clip = static_cast<float>( - getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam))); + params_.cell_clip = + static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f)); + params_.proj_clip = + static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f)); } - params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam)); - params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam)); + const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam); + params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false); + const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam); + params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false); params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_); fw_output_ = GetOutput(operation, operands, kFwOutputTensor); @@ -205,6 +210,59 @@ bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOpera Shape* fwOutputShape, Shape* bwOutputShape, Shape* fwOutputActivationState, Shape* fwOutputCellState, Shape* bwOutputActivationState, Shape* bwOutputCellState) { + // Check we have all the inputs and outputs we need. + constexpr int requiredInputs[] = { + kInputTensor, + kFwInputToForgetWeightsTensor, + kFwInputToCellWeightsTensor, + kFwInputToOutputWeightsTensor, + kFwRecurrentToForgetWeightsTensor, + kFwRecurrentToCellWeightsTensor, + kFwRecurrentToOutputWeightsTensor, + kFwForgetGateBiasTensor, + kFwCellGateBiasTensor, + kFwOutputGateBiasTensor, + kBwInputToForgetWeightsTensor, + kBwInputToCellWeightsTensor, + kBwInputToOutputWeightsTensor, + kBwRecurrentToForgetWeightsTensor, + kBwRecurrentToCellWeightsTensor, + kBwRecurrentToOutputWeightsTensor, + kBwForgetGateBiasTensor, + kBwCellGateBiasTensor, + kBwOutputGateBiasTensor, + kFwInputActivationStateTensor, + kFwInputCellStateTensor, + kBwInputActivationStateTensor, + kBwInputCellStateTensor, + kActivationParam, + kCellClipParam, + kProjClipParam, + kMergeOutputsParam, + kTimeMajorParam, + }; + for (const int requiredInput : requiredInputs) { + NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput))) + << "required input " << requiredInput << " is omitted"; + } + + // Check that the scalar operands' buffers are large enough. + const auto& activationOperand = *GetInput(operation, operands, kActivationParam); + NN_RET_CHECK(activationOperand.length >= sizeof(int32_t)); + const auto& cellOperand = *GetInput(operation, operands, kCellClipParam); + const auto& projOperand = *GetInput(operation, operands, kProjClipParam); + if (input_->type == OperandType::TENSOR_FLOAT32) { + NN_RET_CHECK(cellOperand.length >= sizeof(float)); + NN_RET_CHECK(projOperand.length >= sizeof(float)); + } else { + NN_RET_CHECK(cellOperand.length >= sizeof(_Float16)); + NN_RET_CHECK(projOperand.length >= sizeof(_Float16)); + } + const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam); + NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool)); + const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam); + NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool)); + // Inferring batch size, number of outputs and number of cells from the // input tensors. NN_CHECK(NumDimensions(input_) == 3); |