aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java')
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java195
1 files changed, 195 insertions, 0 deletions
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
new file mode 100644
index 00000000..76f562ef
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/qa/BertQuestionAnswerer.java
@@ -0,0 +1,195 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.task.text.qa;
+
+import android.content.Context;
+import android.os.ParcelFileDescriptor;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+import org.tensorflow.lite.task.core.BaseTaskApi;
+import org.tensorflow.lite.task.core.TaskJniUtils;
+import org.tensorflow.lite.task.core.TaskJniUtils.EmptyHandleProvider;
+import org.tensorflow.lite.task.core.TaskJniUtils.MultipleBuffersHandleProvider;
+
+/** Task API for BertQA models. */
+public class BertQuestionAnswerer extends BaseTaskApi implements QuestionAnswerer {
+ private static final String BERT_QUESTION_ANSWERER_NATIVE_LIBNAME = "task_text_jni";
+
+ private BertQuestionAnswerer(long nativeHandle) {
+ super(nativeHandle);
+ }
+
+ /**
+ * Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
+ * expects a Bert based TFLite model with metadata containing the following information:
+ *
+ * <ul>
+ * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
+ * used for a <a
+ * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
+ * model, Sentencepiece Tokenizer Tokenizer can be used for an <a
+ * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
+ * model.
+ * <li>3 input tensors with names "ids", "mask" and "segment_ids".
+ * <li>2 output tensors with names "end_logits" and "start_logits".
+ * </ul>
+ *
+ * @param context android context
+ * @param pathToModel file path to the model with metadata. Note: The model should not be
+ * compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createFromFile(Context context, String pathToModel)
+ throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithModelWithMetadataByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel));
+ }
+
+ /**
+ * Generic API to create the QuestionAnswerer for bert models with metadata populated. The API
+ * expects a Bert based TFLite model with metadata containing the following information:
+ *
+ * <ul>
+ * <li>input_process_units for Wordpiece/Sentencepiece Tokenizer - Wordpiece Tokenizer can be
+ * used for a <a
+ * href="https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1">MobileBert</a>
+ * model, Sentencepiece Tokenizer Tokenizer can be used for an <a
+ * href="https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1">Albert</a>
+ * model.
+ * <li>3 input tensors with names "ids", "mask" and "segment_ids".
+ * <li>2 output tensors with names "end_logits" and "start_logits".
+ * </ul>
+ *
+ * @param modelFile {@link File} object of the model
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createFromFile(File modelFile)
+ throws IOException {
+ try (ParcelFileDescriptor descriptor =
+ ParcelFileDescriptor.open(modelFile, ParcelFileDescriptor.MODE_READ_ONLY)) {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleFromLibrary(
+ new EmptyHandleProvider() {
+ @Override
+ public long createHandle() {
+ return initJniWithFileDescriptor(descriptor.getFd());
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME));
+ }
+ }
+
+ /**
+ * Creates the API instance with a bert model and vocabulary file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
+ *
+ * @param context android context
+ * @param pathToModel file path to the bert model. Note: The model should not be compressed
+ * @param pathToVocab file path to the vocabulary file. Note: The file should not be compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createBertQuestionAnswererFromFile(
+ Context context, String pathToModel, String pathToVocab) throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithBertByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel,
+ pathToVocab));
+ }
+
+ /**
+ * Creates the API instance with an albert model and sentence piece model file.
+ *
+ * <p>One suitable model is: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
+ *
+ * @param context android context
+ * @param pathToModel file path to the albert model. Note: The model should not be compressed
+ * @param pathToSentencePieceModel file path to the sentence piece model file. Note: The model
+ * should not be compressed
+ * @return {@link BertQuestionAnswerer} instance
+ * @throws IOException If model file fails to load.
+ */
+ public static BertQuestionAnswerer createAlbertQuestionAnswererFromFile(
+ Context context, String pathToModel, String pathToSentencePieceModel) throws IOException {
+ return new BertQuestionAnswerer(
+ TaskJniUtils.createHandleWithMultipleAssetFilesFromLibrary(
+ context,
+ new MultipleBuffersHandleProvider() {
+ @Override
+ public long createHandle(ByteBuffer... buffers) {
+ return BertQuestionAnswerer.initJniWithAlbertByteBuffers(buffers);
+ }
+ },
+ BERT_QUESTION_ANSWERER_NATIVE_LIBNAME,
+ pathToModel,
+ pathToSentencePieceModel));
+ }
+
+ @Override
+ public List<QaAnswer> answer(String context, String question) {
+ checkNotClosed();
+ return answerNative(getNativeHandle(), context, question);
+ }
+
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is vocab file buffer.
+ private static native long initJniWithBertByteBuffers(ByteBuffer... modelBuffers);
+
+ // modelBuffers[0] is tflite model file buffer, and modelBuffers[1] is sentencepiece model file
+ // buffer.
+ private static native long initJniWithAlbertByteBuffers(ByteBuffer... modelBuffers);
+
+ // modelBuffers[0] is tflite model file buffer with metadata to specify which tokenizer to use.
+ private static native long initJniWithModelWithMetadataByteBuffers(ByteBuffer... modelBuffers);
+
+ private static native long initJniWithFileDescriptor(int fd);
+
+ private static native List<QaAnswer> answerNative(
+ long nativeHandle, String context, String question);
+
+ @Override
+ protected void deinit(long nativeHandle) {
+ deinitJni(nativeHandle);
+ }
+
+ /**
+ * Native implementation to release memory pointed by the pointer.
+ *
+ * @param nativeHandle pointer to memory allocated
+ */
+ private native void deinitJni(long nativeHandle);
+}