aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc')
-rw-r--r--tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index a246066b..4c8b3fd7 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -245,13 +245,17 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
segment_ids_tensor.dims->data[1]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
- if (ids_tensor.dims_signature->data[1] == -1 &&
- mask_tensor.dims_signature->data[1] == -1 &&
- segment_ids_tensor.dims_signature->data[1] == -1) {
+
+ bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
+ mask_tensor.dims_signature->size == 2 &&
+ segment_ids_tensor.dims_signature->size == 2;
+ int num_dynamic_tensors = ids_tensor.dims_signature->data[1] == -1 +
+ mask_tensor.dims_signature->data[1] == -1 +
+ segment_ids_tensor.dims_signature->data[1] == -1;
+
+ if (has_valid_dims_signature && num_dynamic_tensors == 3) {
input_tensors_are_dynamic_ = true;
- } else if (ids_tensor.dims_signature->data[1] == -1 ||
- mask_tensor.dims_signature->data[1] == -1 ||
- segment_ids_tensor.dims_signature->data[1] == -1) {
+ } else if (has_valid_dims_signature && num_dynamic_tensors > 0 && num_dynamic_tensors < 3) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Input tensors contain a mix of static and dynamic tensors",