Diffstat (limited to 'nn/common/operations/BidirectionalSequenceLSTM.cpp')
-rw-r--r--  nn/common/operations/BidirectionalSequenceLSTM.cpp | 78
1 file changed, 68 insertions(+), 10 deletions(-)
diff --git a/nn/common/operations/BidirectionalSequenceLSTM.cpp b/nn/common/operations/BidirectionalSequenceLSTM.cpp
index d4d32b964..12ac43f20 100644
--- a/nn/common/operations/BidirectionalSequenceLSTM.cpp
+++ b/nn/common/operations/BidirectionalSequenceLSTM.cpp
@@ -169,19 +169,24 @@ BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
     bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
     bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);
 
-    params_.activation = static_cast<TfLiteFusedActivation>(
-            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
+    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
+    params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
+            activationOperand, TfLiteFusedActivation::kTfLiteActNone));
+    const auto& clipOperand = *GetInput(operation, operands, kCellClipParam);
+    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
     if (input_->type == OperandType::TENSOR_FLOAT32) {
-        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
-        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
+        params_.cell_clip = getScalarDataWithDefault<float>(clipOperand, 0.0f);
+        params_.proj_clip = getScalarDataWithDefault<float>(projOperand, 0.0f);
     } else {
-        params_.cell_clip = static_cast<float>(
-                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
-        params_.proj_clip = static_cast<float>(
-                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
+        params_.cell_clip =
+                static_cast<float>(getScalarDataWithDefault<_Float16>(clipOperand, 0.0f));
+        params_.proj_clip =
+                static_cast<float>(getScalarDataWithDefault<_Float16>(projOperand, 0.0f));
     }
-    params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
-    params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
+    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
+    params_.merge_outputs = getScalarDataWithDefault<bool>(mergeOutputsOperand, false);
+    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
+    params_.time_major = getScalarDataWithDefault<bool>(timeMajorOperand, false);
     params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);
 
     fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
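
The constructor half of the change swaps getScalarData<T>() for getScalarDataWithDefault<T>(), so an omitted optional scalar yields a benign default instead of an out-of-bounds read. The helper itself is not shown in this diff; the sketch below is a plausible reconstruction of its semantics inferred from the call sites, not verbatim source:

template <typename T>
static T getScalarDataWithDefault(const RunTimeOperandInfo& info, T defaultValue) {
    // If the operand was omitted, or its buffer is shorter than sizeof(T),
    // fall back to the caller-supplied default rather than reading past the
    // end of the buffer, which is what getScalarData<T>() would do.
    if (info.length < sizeof(T)) {
        return defaultValue;
    }
    return getScalarData<T>(info);
}

With this in place, kActivationParam falls back to kTfLiteActNone, the clip parameters fall back to 0.0f (no clipping), and merge_outputs/time_major fall back to false.
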
@@ -205,6 +210,59 @@ bool BidirectionalSequenceLSTM::Prepare(const Operation& operation, RunTimeOpera
                                         Shape* fwOutputShape, Shape* bwOutputShape,
                                         Shape* fwOutputActivationState, Shape* fwOutputCellState,
                                         Shape* bwOutputActivationState, Shape* bwOutputCellState) {
+    // Check we have all the inputs and outputs we need.
+    constexpr int requiredInputs[] = {
+            kInputTensor,
+            kFwInputToForgetWeightsTensor,
+            kFwInputToCellWeightsTensor,
+            kFwInputToOutputWeightsTensor,
+            kFwRecurrentToForgetWeightsTensor,
+            kFwRecurrentToCellWeightsTensor,
+            kFwRecurrentToOutputWeightsTensor,
+            kFwForgetGateBiasTensor,
+            kFwCellGateBiasTensor,
+            kFwOutputGateBiasTensor,
+            kBwInputToForgetWeightsTensor,
+            kBwInputToCellWeightsTensor,
+            kBwInputToOutputWeightsTensor,
+            kBwRecurrentToForgetWeightsTensor,
+            kBwRecurrentToCellWeightsTensor,
+            kBwRecurrentToOutputWeightsTensor,
+            kBwForgetGateBiasTensor,
+            kBwCellGateBiasTensor,
+            kBwOutputGateBiasTensor,
+            kFwInputActivationStateTensor,
+            kFwInputCellStateTensor,
+            kBwInputActivationStateTensor,
+            kBwInputCellStateTensor,
+            kActivationParam,
+            kCellClipParam,
+            kProjClipParam,
+            kMergeOutputsParam,
+            kTimeMajorParam,
+    };
+    for (const int requiredInput : requiredInputs) {
+        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
+                << "required input " << requiredInput << " is omitted";
+    }
+
+    // Check that the scalar operands' buffers are large enough.
+    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
+    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
+    const auto& cellOperand = *GetInput(operation, operands, kCellClipParam);
+    const auto& projOperand = *GetInput(operation, operands, kProjClipParam);
+    if (input_->type == OperandType::TENSOR_FLOAT32) {
+        NN_RET_CHECK(cellOperand.length >= sizeof(float));
+        NN_RET_CHECK(projOperand.length >= sizeof(float));
+    } else {
+        NN_RET_CHECK(cellOperand.length >= sizeof(_Float16));
+        NN_RET_CHECK(projOperand.length >= sizeof(_Float16));
+    }
+    const auto& mergeOutputsOperand = *GetInput(operation, operands, kMergeOutputsParam);
+    NN_RET_CHECK(mergeOutputsOperand.length >= sizeof(bool));
+    const auto& timeMajorOperand = *GetInput(operation, operands, kTimeMajorParam);
+    NN_RET_CHECK(timeMajorOperand.length >= sizeof(bool));
+
     // Inferring batch size, number of outputs and number of cells from the
     // input tensors.
     NN_CHECK(NumDimensions(input_) == 3);
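
The Prepare() half rejects malformed models up front: every listed input must be present, and each scalar's backing buffer must be at least as large as the type that will be read from it (the float/_Float16 split exists because sizeof(_Float16) == 2 while sizeof(float) == 4). As a hypothetical illustration of what now gets rejected, assuming RunTimeOperandInfo exposes the lifetime and length fields these checks rely on:

// A model that declares kCellClipParam but supplies no value.
RunTimeOperandInfo cellClip = {};
cellClip.lifetime = OperandLifeTime::NO_VALUE;  // omitted optional input
cellClip.length = 0;                            // no backing buffer

// Previously this reached getScalarData<float>() and read sizeof(float)
// bytes from a missing buffer. With this patch, the IsNullInput() loop in
// Prepare() fails (and a present-but-undersized buffer would fail the
// length >= sizeof(float) check), so the operation is rejected before any
// scalar is dereferenced.
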