aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/vision/segmenter/OutputType.java
blob: 03d82c6df4dbfaec5a1370a99a35121da578c446 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.lite.task.vision.segmenter;

import static org.tensorflow.lite.DataType.FLOAT32;
import static org.tensorflow.lite.DataType.UINT8;
import static org.tensorflow.lite.support.common.SupportPreconditions.checkArgument;
import static org.tensorflow.lite.support.image.ColorSpaceType.GRAYSCALE;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.lite.support.image.TensorImage;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

/**
 * Output mask type. This allows specifying the type of post-processing to perform on the raw model
 * results.
 */
public enum OutputType {

  /**
   * Gives a single output mask where each pixel represents the class which the pixel in the
   * original image was predicted to belong to.
   */
  CATEGORY_MASK(0) {
    /**
     * {@inheritDoc}
     *
     * <p>{@code coloredLabels} is not inspected for this output type.
     *
     * @throws IllegalArgumentException if {@code masks} does not contain exactly one {@link
     *     TensorImage}, or if the color space of that {@link TensorImage} is not {@link
     *     org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}
     */
    @Override
    void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
      checkArgument(
          masks.size() == 1,
          "CATEGORY_MASK only allows one TensorImage in the list, providing " + masks.size());

      // Safe: the size check above guarantees exactly one element.
      TensorImage mask = masks.get(0);
      checkArgument(
          mask.getColorSpaceType() == GRAYSCALE,
          "CATEGORY_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
              + mask.getColorSpaceType());
    }

    /**
     * {@inheritDoc}
     *
     * <p>The single buffer is loaded as a {@code UINT8} tensor of shape {@code maskShape} and
     * wrapped in a {@code UINT8} grayscale {@link TensorImage}.
     *
     * @throws IllegalArgumentException if {@code buffers} does not contain exactly one {@link
     *     ByteBuffer}
     */
    @Override
    List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
      checkArgument(
          buffers.size() == 1,
          "CATEGORY_MASK only allows one mask in the buffer list, providing " + buffers.size());

      List<TensorImage> masks = new ArrayList<>(1);
      TensorBuffer tensorBuffer = TensorBuffer.createDynamic(UINT8);
      tensorBuffer.loadBuffer(buffers.get(0), maskShape);
      TensorImage tensorImage = new TensorImage(UINT8);
      tensorImage.load(tensorBuffer, GRAYSCALE);
      masks.add(tensorImage);

      return masks;
    }
  },

  /**
   * Gives a list of output masks where, for each mask, each pixel represents the prediction
   * confidence, usually in the [0, 1] range.
   */
  CONFIDENCE_MASK(1) {
    /**
     * {@inheritDoc}
     *
     * @throws IllegalArgumentException if the size of {@code masks} does not match the size of
     *     {@code coloredLabels}, or if the color space type of any mask is not {@link
     *     org.tensorflow.lite.support.image.ColorSpaceType#GRAYSCALE}
     */
    @Override
    void assertMasksMatchColoredLabels(List<TensorImage> masks, List<ColoredLabel> coloredLabels) {
      checkArgument(
          masks.size() == coloredLabels.size(),
          String.format(
              "When using CONFIDENCE_MASK, the number of masks (%d) should match the number of"
                  + " coloredLabels (%d).",
              masks.size(), coloredLabels.size()));

      for (TensorImage mask : masks) {
        checkArgument(
            mask.getColorSpaceType() == GRAYSCALE,
            "CONFIDENCE_MASK only supports masks of ColorSpaceType, GRAYSCALE, providing "
                + mask.getColorSpaceType());
      }
    }

    /**
     * {@inheritDoc}
     *
     * <p>Each buffer is loaded as a {@code FLOAT32} tensor of shape {@code maskShape} and wrapped
     * in a {@code FLOAT32} grayscale {@link TensorImage}. The returned list has one mask per
     * buffer, in the same order as {@code buffers}. The number of buffers is not validated here.
     */
    @Override
    List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape) {
      List<TensorImage> masks = new ArrayList<>(buffers.size());
      for (ByteBuffer buffer : buffers) {
        TensorBuffer tensorBuffer = TensorBuffer.createDynamic(FLOAT32);
        tensorBuffer.loadBuffer(buffer, maskShape);
        TensorImage tensorImage = new TensorImage(FLOAT32);
        tensorImage.load(tensorBuffer, GRAYSCALE);
        masks.add(tensorImage);
      }
      return masks;
    }
  };

  /** Integer code of this output type (0 for CATEGORY_MASK, 1 for CONFIDENCE_MASK). */
  private final int value;

  // Enum constructors are implicitly private; the explicit modifier is redundant.
  OutputType(int value) {
    this.value = value;
  }

  /**
   * Returns the integer code assigned to this output type: {@code 0} for {@link #CATEGORY_MASK}
   * and {@code 1} for {@link #CONFIDENCE_MASK}.
   *
   * <p>NOTE(review): these codes presumably need to stay in sync with a native-side counterpart;
   * confirm before renumbering.
   */
  public int getValue() {
    return value;
  }

  /**
   * Verifies that the given list of masks matches the list of colored labels.
   *
   * @param masks the masks to validate
   * @param coloredLabels the colored labels the masks are expected to correspond to
   * @throws IllegalArgumentException if {@code masks} and {@code coloredLabels} do not match the
   *     output type
   */
  abstract void assertMasksMatchColoredLabels(
      List<TensorImage> masks, List<ColoredLabel> coloredLabels);

  /**
   * Creates the masks in {@link TensorImage} based on the data in {@link ByteBuffer}.
   *
   * @param buffers the raw mask data, one buffer per mask
   * @param maskShape the shape each buffer is loaded with
   * @return a newly created, mutable list of masks
   */
  abstract List<TensorImage> createMasksFromBuffer(List<ByteBuffer> buffers, int[] maskShape);
}