Diffstat (limited to 'tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc')
-rw-r--r--  tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc | 113
1 file changed, 91 insertions(+), 22 deletions(-)
diff --git a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
index d689c9e8..a246066b 100644
--- a/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
+++ b/tensorflow_lite_support/cc/task/text/nlclassifier/bert_nl_classifier.cc
@@ -55,12 +55,16 @@ namespace {
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
+constexpr int kIdsTensorIndex = 0;
+constexpr int kMaskTensorIndex = 1;
+constexpr int kSegmentIdsTensorIndex = 2;
constexpr char kScoreTensorName[] = "probability";
constexpr char kClassificationToken[] = "[CLS]";
constexpr char kSeparator[] = "[SEP]";
constexpr int kTokenizerProcessUnitIndex = 0;
} // namespace
+// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
absl::Status BertNLClassifier::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors, const std::string& input) {
auto* input_tensor_metadatas =
@@ -78,39 +82,46 @@ absl::Status BertNLClassifier::Preprocess(
TokenizerResult input_tokenize_results;
input_tokenize_results = tokenizer_->Tokenize(processed_input);
- // 2 accounts for [CLS], [SEP]
- absl::Span<const std::string> query_tokens =
- absl::MakeSpan(input_tokenize_results.subwords.data(),
- input_tokenize_results.subwords.data() +
- std::min(static_cast<size_t>(kMaxSeqLen - 2),
- input_tokenize_results.subwords.size()));
-
- std::vector<std::string> tokens;
- tokens.reserve(2 + query_tokens.size());
- // Start of generating the features.
- tokens.push_back(kClassificationToken);
- // For query input.
- for (const auto& query_token : query_tokens) {
- tokens.push_back(query_token);
+ // Offset by 2 to account for [CLS] and [SEP]
+ int input_tokens_size =
+ static_cast<int>(input_tokenize_results.subwords.size()) + 2;
+ int input_tensor_length = input_tokens_size;
+ if (!input_tensors_are_dynamic_) {
+ input_tokens_size = std::min(kMaxSeqLen, input_tokens_size);
+ input_tensor_length = kMaxSeqLen;
+ } else {
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
+ {1, input_tensor_length});
+ GetTfLiteEngine()->interpreter()->AllocateTensors();
}
- // For Separation.
- tokens.push_back(kSeparator);
- std::vector<int> input_ids(kMaxSeqLen, 0);
- std::vector<int> input_mask(kMaxSeqLen, 0);
+ std::vector<std::string> input_tokens;
+ input_tokens.reserve(input_tokens_size);
+ input_tokens.push_back(std::string(kClassificationToken));
+ for (int i = 0; i < input_tokens_size - 2; ++i) {
+ input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
+ }
+ input_tokens.push_back(std::string(kSeparator));
+
+ std::vector<int> input_ids(input_tensor_length, 0);
+ std::vector<int> input_mask(input_tensor_length, 0);
// Convert tokens back into ids and set mask
- for (int i = 0; i < tokens.size(); ++i) {
- tokenizer_->LookupId(tokens[i], &input_ids[i]);
+ for (int i = 0; i < input_tokens.size(); ++i) {
+ tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_mask[i] = 1;
}
- // |<-----------kMaxSeqLen---------->|
+ // |<--------input_tensor_length------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0
PopulateTensor(input_ids, ids_tensor);
PopulateTensor(input_mask, mask_tensor);
- PopulateTensor(std::vector<int>(kMaxSeqLen, 0), segment_ids_tensor);
+ PopulateTensor(std::vector<int>(input_tensor_length, 0), segment_ids_tensor);
return absl::OkStatus();
}
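
For reference, the dynamic branch added above relies on the stock tflite::Interpreter
resize-then-reallocate pattern. Below is a minimal sketch of that pattern, not part of
this change; ResizeSequenceLength is an illustrative name, and it assumes input 0 is the
[1, seq_len] ids tensor (the classifier code above passes fixed indices instead).

#include "absl/status/status.h"
#include "tensorflow/lite/interpreter.h"

// Illustrative helper (not in the library): resize a dynamic [1, seq_len]
// input to a new sequence length and re-allocate the buffers.
absl::Status ResizeSequenceLength(tflite::Interpreter* interpreter,
                                  int new_length) {
  // ResizeInputTensorStrict only accepts new sizes for dimensions the model
  // declared as unknown (-1); fixed dimensions cannot be changed this way.
  if (interpreter->ResizeInputTensorStrict(interpreter->inputs()[0],
                                           {1, new_length}) != kTfLiteOk) {
    return absl::InternalError("Failed to resize input tensor.");
  }
  // The resize only records the new shape; buffers are re-created here.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return absl::InternalError("Failed to allocate tensors after resize.");
  }
  return absl::OkStatus();
}
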
@@ -189,6 +200,64 @@ absl::Status BertNLClassifier::InitializeFromMetadata() {
TrySetLabelFromMetadata(
GetMetadataExtractor()->GetOutputTensorMetadata(kOutputTensorIndex))
.IgnoreError();
+
+ auto* input_tensor_metadatas =
+ GetMetadataExtractor()->GetInputTensorMetadata();
+ const auto& input_tensors = GetInputTensors();
+ const auto& ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kIdsTensorName);
+ const auto& mask_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kMaskTensorName);
+ const auto& segment_ids_tensor = *FindTensorByName(input_tensors, input_tensor_metadatas,
+ kSegmentIdsTensorName);
+ if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
+ segment_ids_tensor.dims->size != 2) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat(
+ "The three input tensors in Bert models are expected to have dim "
+ "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
+ "(%d).",
+ ids_tensor.dims->size, mask_tensor.dims->size,
+ segment_ids_tensor.dims->size),
+ TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+ }
+ if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
+ segment_ids_tensor.dims->data[0] != 1) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat(
+ "The three input tensors in Bert models are expected to have same "
+ "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
+ "segment_ids_tensor (%d).",
+ ids_tensor.dims->data[0], mask_tensor.dims->data[0],
+ segment_ids_tensor.dims->data[0]),
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+ if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
+ ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ absl::StrFormat("The three input tensors in Bert models are "
+ "expected to have same length, but got ids_tensor "
+ "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
+ ids_tensor.dims->data[1], mask_tensor.dims->data[1],
+ segment_ids_tensor.dims->data[1]),
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+ if (ids_tensor.dims_signature->data[1] == -1 &&
+ mask_tensor.dims_signature->data[1] == -1 &&
+ segment_ids_tensor.dims_signature->data[1] == -1) {
+ input_tensors_are_dynamic_ = true;
+ } else if (ids_tensor.dims_signature->data[1] == -1 ||
+ mask_tensor.dims_signature->data[1] == -1 ||
+ segment_ids_tensor.dims_signature->data[1] == -1) {
+ return CreateStatusWithPayload(
+ absl::StatusCode::kInternal,
+ "Input tensors contain a mix of static and dynamic tensors",
+ TfLiteSupportStatus::kInvalidInputTensorSizeError);
+ }
+
return absl::OkStatus();
}
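
For reference, the dims_signature checks added above are what distinguish a static from a
dynamic sequence length. The sketch below, not part of this change, shows how that looks
at the TfLiteTensor level; IsDynamicSeqLen is an illustrative helper assuming a
[1, seq_len] input.

#include "tensorflow/lite/c/common.h"

// Illustrative helper (not in the library): report whether a [1, seq_len]
// input declares its sequence length as dynamic.
bool IsDynamicSeqLen(const TfLiteTensor& tensor) {
  // `dims` holds the concrete shape from the last allocation (e.g. {1, 128}),
  // while `dims_signature` keeps the shape declared in the model (e.g. {1, -1});
  // -1 marks a dimension callers may resize at run time.
  return tensor.dims_signature != nullptr &&
         tensor.dims_signature->size == 2 &&
         tensor.dims_signature->data[1] == -1;
}
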