aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java')
-rw-r--r--tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java74
1 files changed, 74 insertions, 0 deletions
diff --git a/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
new file mode 100644
index 00000000..c2de8c0b
--- /dev/null
+++ b/tensorflow_lite_support/java/src/java/org/tensorflow/lite/support/label/ops/LabelAxisOp.java
@@ -0,0 +1,74 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite.support.label.ops;
+
+import android.content.Context;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.tensorflow.lite.support.common.FileUtil;
+import org.tensorflow.lite.support.common.SupportPreconditions;
+import org.tensorflow.lite.support.label.TensorLabel;
+import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;
+
+/**
+ * Labels TensorBuffer with axisLabels for outputs.
+ *
+ * <p>Apply on a {@code TensorBuffer} to get a {@code TensorLabel} that could output a Map, which is
+ * a pair of the label name and the corresponding TensorBuffer value.
+ */
+public class LabelAxisOp {
+ // Axis and its corresponding label names.
+ private final Map<Integer, List<String>> axisLabels;
+
+ protected LabelAxisOp(Builder builder) {
+ axisLabels = builder.axisLabels;
+ }
+
+ public TensorLabel apply(@NonNull TensorBuffer buffer) {
+ SupportPreconditions.checkNotNull(buffer, "Tensor buffer cannot be null.");
+ return new TensorLabel(axisLabels, buffer);
+ }
+
+ /** The inner builder class to build a LabelTensor Operator. */
+ public static class Builder {
+ private final Map<Integer, List<String>> axisLabels;
+
+ protected Builder() {
+ axisLabels = new HashMap<>();
+ }
+
+ public Builder addAxisLabel(@NonNull Context context, int axis, @NonNull String filePath)
+ throws IOException {
+ SupportPreconditions.checkNotNull(context, "Context cannot be null.");
+ SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
+ List<String> labels = FileUtil.loadLabels(context, filePath);
+ axisLabels.put(axis, labels);
+ return this;
+ }
+
+ public Builder addAxisLabel(int axis, @NonNull List<String> labels) {
+ axisLabels.put(axis, labels);
+ return this;
+ }
+
+ public LabelAxisOp build() {
+ return new LabelAxisOp(this);
+ }
+ }
+}