diff options
Diffstat (limited to 'android/view/textclassifier/EntityConfidence.java')
-rw-r--r-- | android/view/textclassifier/EntityConfidence.java | 76 |
1 files changed, 58 insertions, 18 deletions
diff --git a/android/view/textclassifier/EntityConfidence.java b/android/view/textclassifier/EntityConfidence.java index 19660d95..69a59a5b 100644 --- a/android/view/textclassifier/EntityConfidence.java +++ b/android/view/textclassifier/EntityConfidence.java @@ -18,6 +18,8 @@ package android.view.textclassifier; import android.annotation.FloatRange; import android.annotation.NonNull; +import android.os.Parcel; +import android.os.Parcelable; import android.util.ArrayMap; import com.android.internal.util.Preconditions; @@ -30,17 +32,16 @@ import java.util.Map; /** * Helper object for setting and getting entity scores for classified text. * - * @param <T> the entity type. * @hide */ -final class EntityConfidence<T> { +final class EntityConfidence implements Parcelable { - private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>(); - private final ArrayList<T> mSortedEntities = new ArrayList<>(); + private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>(); + private final ArrayList<String> mSortedEntities = new ArrayList<>(); EntityConfidence() {} - EntityConfidence(@NonNull EntityConfidence<T> source) { + EntityConfidence(@NonNull EntityConfidence source) { Preconditions.checkNotNull(source); mEntityConfidence.putAll(source.mEntityConfidence); mSortedEntities.addAll(source.mSortedEntities); @@ -54,24 +55,16 @@ final class EntityConfidence<T> { * @param source a map from entity to a confidence value in the range 0 (low confidence) to * 1 (high confidence). */ - EntityConfidence(@NonNull Map<T, Float> source) { + EntityConfidence(@NonNull Map<String, Float> source) { Preconditions.checkNotNull(source); // Prune non-existent entities and clamp to 1. mEntityConfidence.ensureCapacity(source.size()); - for (Map.Entry<T, Float> it : source.entrySet()) { + for (Map.Entry<String, Float> it : source.entrySet()) { if (it.getValue() <= 0) continue; mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); } - - // Create a list of entities sorted by decreasing confidence for getEntities(). - mSortedEntities.ensureCapacity(mEntityConfidence.size()); - mSortedEntities.addAll(mEntityConfidence.keySet()); - mSortedEntities.sort((e1, e2) -> { - float score1 = mEntityConfidence.get(e1); - float score2 = mEntityConfidence.get(e2); - return Float.compare(score2, score1); - }); + resetSortedEntitiesFromMap(); } /** @@ -79,7 +72,7 @@ final class EntityConfidence<T> { * high confidence to low confidence. */ @NonNull - public List<T> getEntities() { + public List<String> getEntities() { return Collections.unmodifiableList(mSortedEntities); } @@ -89,7 +82,7 @@ final class EntityConfidence<T> { * classified text. */ @FloatRange(from = 0.0, to = 1.0) - public float getConfidenceScore(T entity) { + public float getConfidenceScore(String entity) { if (mEntityConfidence.containsKey(entity)) { return mEntityConfidence.get(entity); } @@ -100,4 +93,51 @@ final class EntityConfidence<T> { public String toString() { return mEntityConfidence.toString(); } + + @Override + public int describeContents() { + return 0; + } + + @Override + public void writeToParcel(Parcel dest, int flags) { + dest.writeInt(mEntityConfidence.size()); + for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) { + dest.writeString(entry.getKey()); + dest.writeFloat(entry.getValue()); + } + } + + public static final Parcelable.Creator<EntityConfidence> CREATOR = + new Parcelable.Creator<EntityConfidence>() { + @Override + public EntityConfidence createFromParcel(Parcel in) { + return new EntityConfidence(in); + } + + @Override + public EntityConfidence[] newArray(int size) { + return new EntityConfidence[size]; + } + }; + + private EntityConfidence(Parcel in) { + final int numEntities = in.readInt(); + mEntityConfidence.ensureCapacity(numEntities); + for (int i = 0; i < numEntities; ++i) { + mEntityConfidence.put(in.readString(), in.readFloat()); + } + resetSortedEntitiesFromMap(); + } + + private void resetSortedEntitiesFromMap() { + mSortedEntities.clear(); + mSortedEntities.ensureCapacity(mEntityConfidence.size()); + mSortedEntities.addAll(mEntityConfidence.keySet()); + mSortedEntities.sort((e1, e2) -> { + float score1 = mEntityConfidence.get(e1); + float score2 = mEntityConfidence.get(e2); + return Float.compare(score2, score1); + }); + } } |