diff options
Diffstat (limited to 'tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc')
-rw-r--r-- | tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc index a246066b..4c8b3fd7 100644 --- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc +++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc @@ -245,13 +245,17 @@ absl::Status BertNLClassifier::InitializeFromMetadata() { segment_ids_tensor.dims->data[1]), TfLiteSupportStatus::kInvalidInputTensorSizeError); } - if (ids_tensor.dims_signature->data[1] == -1 && - mask_tensor.dims_signature->data[1] == -1 && - segment_ids_tensor.dims_signature->data[1] == -1) { + + bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 && + mask_tensor.dims_signature->size == 2 && + segment_ids_tensor.dims_signature->size == 2; + int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 + + mask_tensor.dims_signature->data[1] == -1 + + segment_ids_tensor.dims_signature->data[1] == -1; + + if (has_valid_dims_signature && num_dynamic_tensors == 3) { input_tensors_are_dynamic_ = true; - } else if (ids_tensor.dims_signature->data[1] == -1 || - mask_tensor.dims_signature->data[1] == -1 || - segment_ids_tensor.dims_signature->data[1] == -1) { + } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) { return CreateStatusWithPayload( absl::StatusCode::kInternal, "Input tensors contain a mix of static and dynamic tensors", |