summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-02-05 13:02:37 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2022-02-05 13:02:37 +0000
commitecbfc3f0e29e0d4b4a6e5c7679df5a3bc05c6a5d (patch)
tree7ff8d841461eef3d6a4fc61221dab93ab7eb2fe7
parent6b4ecf498c9d82ef8ba29729670c10c5fb4a710b (diff)
parentfaf03992dc4e169d214b17726a82f664efd6b57a (diff)
downloadlibtextclassifier-android12-mainline-media-swcodec-release.tar.gz
Snap for 8152310 from faf03992dc4e169d214b17726a82f664efd6b57a to mainline-media-swcodec-release (refs: android-mainline-12.0.0_r91, android12-mainline-media-swcodec-release)
Change-Id: I0ba3ed5e29e26af2034e05fc9e46dac96bfa4225
-rw-r--r--java/src/com/android/textclassifier/common/logging/ResultIdUtils.java4
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java8
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java5
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java7
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java6
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java6
-rw-r--r--java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java8
-rw-r--r--native/actions/actions-entity-data.bfbsbin880 -> 888 bytes
-rw-r--r--native/actions/actions-entity-data.fbs2
-rw-r--r--native/actions/actions-suggestions.cc29
-rw-r--r--native/actions/actions-suggestions.h4
-rw-r--r--native/actions/actions-suggestions_test.cc29
-rw-r--r--native/actions/actions_model.fbs16
-rw-r--r--native/actions/ranker.cc74
-rw-r--r--native/actions/ranker_test.cc96
-rw-r--r--native/actions/test_data/actions_suggestions_grammar_test.modelbin145160 -> 145176 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.modelbin3387328 -> 3387360 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_9heads.modelbin3874528 -> 3874704 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.modelbin3808528 -> 3812304 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.modelbin3853520 -> 3853520 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.modelbin4671808 -> 4671840 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.modelbin5045280 -> 5045408 bytes
-rw-r--r--native/actions/test_data/actions_suggestions_test.sensitive_tflite.modelbin7111552 -> 7111552 bytes
-rw-r--r--native/annotator/annotator.cc79
-rw-r--r--native/annotator/datetime/datetime-grounder.cc9
-rw-r--r--native/annotator/datetime/extractor.cc11
-rw-r--r--native/annotator/datetime/regex-parser.cc31
-rw-r--r--native/annotator/translate/translate.cc11
-rw-r--r--native/lang_id/common/embedding-network.cc2
-rw-r--r--native/lang_id/common/fel/feature-extractor.cc1
-rw-r--r--native/lang_id/common/fel/workspace.cc1
-rw-r--r--native/lang_id/common/fel/workspace.h1
-rw-r--r--native/lang_id/common/file/mmap.cc2
-rw-r--r--native/lang_id/common/lite_strings/str-split.cc2
-rw-r--r--native/lang_id/common/math/softmax.cc1
-rw-r--r--native/lang_id/fb_model/lang-id-from-fb.cc2
-rw-r--r--native/lang_id/fb_model/model-provider-from-fb.cc2
-rw-r--r--native/lang_id/lang-id.cc1
-rw-r--r--native/utils/codepoint-range.cc9
-rw-r--r--native/utils/grammar/parsing/parser.cc17
-rw-r--r--native/utils/grammar/utils/ir.cc35
-rw-r--r--native/utils/grammar/utils/locale-shard-map.cc4
-rw-r--r--native/utils/testing/test_data_generator.h13
-rw-r--r--native/utils/tflite-model-executor.cc14
-rw-r--r--native/utils/tflite/encoder_common.cc11
-rw-r--r--native/utils/tflite/text_encoder3s.cc243
-rw-r--r--native/utils/tflite/text_encoder3s.h35
-rw-r--r--native/utils/tokenfree/byte_encoder.cc42
-rw-r--r--native/utils/tokenfree/byte_encoder.h37
-rw-r--r--native/utils/tokenfree/byte_encoder_test.cc51
-rw-r--r--native/utils/tokenizer.cc11
-rw-r--r--notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java8
52 files changed, 825 insertions, 155 deletions
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
index dae0442..67e300d 100644
--- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java
@@ -66,8 +66,8 @@ public final class ResultIdUtils {
}
/** Returns if the result id was generated from the default text classifier. */
- public static boolean isFromDefaultTextClassifier(String resultId) {
- return resultId.startsWith(CLASSIFIER_ID + '|');
+ public static boolean isFromDefaultTextClassifier(@Nullable String resultId) {
+ return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|');
}
/** Returns all the model names encoded in the signature. */
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 0e3842c..71f9a4f 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -52,16 +52,21 @@ import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class DefaultTextClassifierServiceTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
/** A statsd config ID, which is arbitrary. */
private static final long CONFIG_ID = 689777;
@@ -79,7 +84,6 @@ public class DefaultTextClassifierServiceTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
testInjector = new TestInjector(ApplicationProvider.getApplicationContext());
defaultTextClassifierService = new DefaultTextClassifierService(testInjector);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
index 5297640..20ae592 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java
@@ -53,7 +53,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
@@ -67,6 +68,7 @@ public final class ModelFileManagerImplTest {
@Mock private DownloadedModelManager downloadedModelManager;
@Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
private File rootTestDir;
private ModelFileManagerImpl modelFileManager;
@@ -75,7 +77,6 @@ public final class ModelFileManagerImplTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
deviceConfig = new TestingDeviceConfig();
rootTestDir =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
index 216cd5d..3aab211 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java
@@ -27,14 +27,18 @@ import com.google.android.textclassifier.NamedVariant;
import com.google.android.textclassifier.RemoteActionTemplate;
import java.util.List;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@SmallTest
@RunWith(AndroidJUnit4.class)
public class TemplateIntentFactoryTest {
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final String TITLE_WITHOUT_ENTITY = "Map";
private static final String TITLE_WITH_ENTITY = "Map NW14D1";
private static final String DESCRIPTION = "Check the map";
@@ -71,7 +75,6 @@ public class TemplateIntentFactoryTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
templateIntentFactory = new TemplateIntentFactory();
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
index ffd2ee4..3a8fefc 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java
@@ -86,8 +86,8 @@ public class StatsdTestUtils {
return ImmutableList.copyOf(
metricsList.stream()
.flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream())
- .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric(
- eventMetricData).stream())
+ .flatMap(
+ eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream())
.sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos))
.map(EventMetricData::getAtom)
.collect(Collectors.toList()));
@@ -136,7 +136,7 @@ public class StatsdTestUtils {
}
private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric(
- EventMetricData metricData) {
+ EventMetricData metricData) {
if (metricData.hasAtom()) {
return ImmutableList.of(metricData);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index c626ed7..394b7ad 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -46,7 +46,8 @@ import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@RunWith(AndroidJUnit4.class)
public final class ModelDownloadManagerTest {
@@ -61,6 +62,8 @@ public final class ModelDownloadManagerTest {
public final TextClassifierDownloadLoggerTestRule loggerTestRule =
new TextClassifierDownloadLoggerTestRule();
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private TestingDeviceConfig deviceConfig;
private WorkManager workManager;
private ModelDownloadManager downloadManager;
@@ -68,7 +71,6 @@ public final class ModelDownloadManagerTest {
@Before
public void setUp() {
- MockitoAnnotations.initMocks(this);
Context context = ApplicationProvider.getApplicationContext();
WorkManagerTestInitHelper.initializeTestWorkManager(context);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
index eac2af3..76d04e0 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java
@@ -37,14 +37,19 @@ import com.google.common.util.concurrent.SettableFuture;
import java.io.File;
import java.net.URI;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@RunWith(JUnit4.class)
public final class ModelDownloaderServiceImplTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final long BYTES_WRITTEN = 1L;
private static final String DOWNLOAD_URI =
"https://www.gstatic.com/android/text_classifier/r/v999/en.fb";
@@ -66,7 +71,6 @@ public final class ModelDownloaderServiceImplTest {
@Before
public void setUp() {
- MockitoAnnotations.initMocks(this);
this.targetModelFile =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb");
diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs
index 7421579..6ebf1cf 100644
--- a/native/actions/actions-entity-data.bfbs
+++ b/native/actions/actions-entity-data.bfbs
Binary files differ
diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs
index 21584b6..e906f93 100644
--- a/native/actions/actions-entity-data.fbs
+++ b/native/actions/actions-entity-data.fbs
@@ -18,7 +18,7 @@
namespace libtextclassifier3;
table ActionsEntityData {
// Extracted text.
- text:string (shared);
+ text:string (key, shared);
}
root_type libtextclassifier3.ActionsEntityData;
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index b1a042c..9f9a8d4 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -17,6 +17,7 @@
#include "actions/actions-suggestions.h"
#include <memory>
+#include <string>
#include <vector>
#include "utils/base/statusor.h"
@@ -40,6 +41,7 @@
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/string_util.h"
namespace libtextclassifier3 {
@@ -809,12 +811,14 @@ bool ActionsSuggestions::SetupModelInput(
void ActionsSuggestions::PopulateTextReplies(
const tflite::Interpreter* interpreter, int suggestion_index,
- int score_index, const std::string& type,
+ int score_index, const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const {
const std::vector<tflite::StringRef> replies =
model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
const TensorView<float> scores =
model_executor_->OutputView<float>(score_index, interpreter);
+
for (int i = 0; i < replies.size(); i++) {
if (replies[i].len == 0) {
continue;
@@ -823,8 +827,12 @@ void ActionsSuggestions::PopulateTextReplies(
if (score < preconditions_.min_reply_score_threshold) {
continue;
}
- response->actions.push_back(
- {std::string(replies[i].str, replies[i].len), type, score});
+ std::string response_text(replies[i].str, replies[i].len);
+ if (blocklist.contains(response_text)) {
+ continue;
+ }
+
+ response->actions.push_back({response_text, type, score, priority_score});
}
}
@@ -909,10 +917,12 @@ bool ActionsSuggestions::ReadModelOutput(
// Read smart reply predictions.
if (!response->output_filtered_min_triggering_score &&
model_->tflite_model_spec()->output_replies() >= 0) {
+ absl::flat_hash_set<std::string> empty_blocklist;
PopulateTextReplies(interpreter,
model_->tflite_model_spec()->output_replies(),
model_->tflite_model_spec()->output_replies_scores(),
- model_->smart_reply_action_type()->str(), response);
+ model_->smart_reply_action_type()->str(),
+ /* priority_score */ 0.0, empty_blocklist, response);
}
// Read actions suggestions.
@@ -950,17 +960,26 @@ bool ActionsSuggestions::ReadModelOutput(
const int suggestions_index = metadata->output_suggestions();
const int suggestions_scores_index =
metadata->output_suggestions_scores();
+ absl::flat_hash_set<std::string> response_text_blocklist;
switch (metadata->prediction_type()) {
case PredictionType_NEXT_MESSAGE_PREDICTION:
if (!task_spec || task_spec->type()->size() == 0) {
TC3_LOG(WARNING) << "Task type not provided, use default "
"smart_reply_action_type!";
}
+ if (task_spec) {
+ if (task_spec->response_text_blocklist()) {
+ for (const auto& val : *task_spec->response_text_blocklist()) {
+ response_text_blocklist.insert(val->str());
+ }
+ }
+ }
PopulateTextReplies(
interpreter, suggestions_index, suggestions_scores_index,
task_spec ? task_spec->type()->str()
: model_->smart_reply_action_type()->str(),
- response);
+ task_spec ? task_spec->priority_score() : 0.0,
+ response_text_blocklist, response);
break;
case PredictionType_INTENT_TRIGGERING:
PopulateIntentTriggering(interpreter, suggestions_index,
diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h
index 32edc78..87f55fb 100644
--- a/native/actions/actions-suggestions.h
+++ b/native/actions/actions-suggestions.h
@@ -43,6 +43,7 @@
#include "utils/utf8/unilib.h"
#include "utils/variant.h"
#include "utils/zlib/zlib.h"
+#include "absl/container/flat_hash_set.h"
namespace libtextclassifier3 {
@@ -176,7 +177,8 @@ class ActionsSuggestions {
void PopulateTextReplies(const tflite::Interpreter* interpreter,
int suggestion_index, int score_index,
- const std::string& type,
+ const std::string& type, float priority_score,
+ const absl::flat_hash_set<std::string>& blocklist,
ActionsSuggestionsResponse* response) const;
void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 062d527..b51ebc7 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -1798,6 +1798,7 @@ TEST_F(ActionsSuggestionsTest,
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello?",
@@ -1807,9 +1808,31 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 5);
EXPECT_EQ(response.actions[0].response_text, "😁");
- EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT");
- EXPECT_EQ(response.actions[1].response_text, "Yes");
- EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION");
+ EXPECT_EQ(response.actions[0].type, "text_reply");
+ EXPECT_EQ(response.actions[1].response_text, "👋");
+ EXPECT_EQ(response.actions[1].type, "text_reply");
+ EXPECT_EQ(response.actions[2].response_text, "Yes");
+ EXPECT_EQ(response.actions[2].type, "text_reply");
+}
+
+TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kMultiTaskSrEmojiModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "a pleasure chatting",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{},
+ /*locales=*/"en"}}});
+ EXPECT_EQ(response.actions.size(), 3);
+ EXPECT_EQ(response.actions[0].response_text, "😁");
+ EXPECT_EQ(response.actions[0].type, "text_reply");
+ EXPECT_EQ(response.actions[1].response_text, "😘");
+ EXPECT_EQ(response.actions[1].type, "text_reply");
+ EXPECT_EQ(response.actions[2].response_text, "Okay");
+ EXPECT_EQ(response.actions[2].type, "text_reply");
}
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 8c03eeb..0d8c7ad 100644
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -36,6 +36,17 @@ enum PredictionType : int {
ENTITY_ANNOTATION = 3,
}
+namespace libtextclassifier3;
+enum RankingOptionsSortType : int {
+ SORT_TYPE_UNSPECIFIED = 0,
+
+ // Rank results (or groups) by score, then type
+ SORT_TYPE_SCORE = 1,
+
+ // Rank results (or groups) by priority score, then score, then type
+ SORT_TYPE_PRIORITY_SCORE = 2,
+}
+
// Prediction metadata for an arbitrary task.
namespace libtextclassifier3;
table PredictionMetadata {
@@ -315,10 +326,11 @@ table ActionSuggestionSpec {
// Additional entity information.
serialized_entity_data:string (shared);
- // Priority score used for internal conflict resolution.
+ // Priority score used for ranking and internal conflict resolution.
priority_score:float = 0;
entity_data:ActionsEntityData;
+ response_text_blocklist:[string];
}
// Options to specify triggering behaviour per action class.
@@ -416,6 +428,8 @@ table RankingOptions {
// If true, keep actions from the same entities together for ranking.
group_by_annotations:bool = true;
+
+ sort_type:RankingOptionsSortType = SORT_TYPE_SCORE;
}
// Entity data to set from capturing groups.
diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc
index d52ecaa..46e392a 100644
--- a/native/actions/ranker.cc
+++ b/native/actions/ranker.cc
@@ -20,6 +20,8 @@
#include <set>
#include <vector>
+#include "actions/actions_model_generated.h"
+
#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
@@ -34,11 +36,22 @@ namespace libtextclassifier3 {
namespace {
void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
- std::sort(actions->begin(), actions->end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.score > b.score ||
- (a.score >= b.score && a.type < b.type);
- });
+ std::stable_sort(actions->begin(), actions->end(),
+ [](const ActionSuggestion& a, const ActionSuggestion& b) {
+ return a.score > b.score ||
+ (a.score >= b.score && a.type < b.type);
+ });
+}
+
+void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
+ std::stable_sort(
+ actions->begin(), actions->end(),
+ [](const ActionSuggestion& a, const ActionSuggestion& b) {
+ return a.priority_score > b.priority_score ||
+ (a.priority_score >= b.priority_score && a.score > b.score) ||
+ (a.priority_score >= b.priority_score && a.score >= b.score &&
+ a.type < b.type);
+ });
}
template <typename T>
@@ -241,13 +254,8 @@ bool ActionsSuggestionsRanker::RankActions(
const reflection::Schema* annotations_entity_data_schema) const {
if (options_->deduplicate_suggestions() ||
options_->deduplicate_suggestions_by_span()) {
- // First order suggestions by priority score for deduplication.
- std::sort(
- response->actions.begin(), response->actions.end(),
- [](const ActionSuggestion& a, const ActionSuggestion& b) {
- return a.priority_score > b.priority_score ||
- (a.priority_score >= b.priority_score && a.score > b.score);
- });
+ // Order suggestions by [priority score -> score -> type] for deduplication.
+ SortByPriorityAndScoreAndType(&response->actions);
// Deduplicate, keeping the higher score actions.
if (options_->deduplicate_suggestions()) {
@@ -275,6 +283,8 @@ bool ActionsSuggestionsRanker::RankActions(
}
}
+ bool sort_by_priority =
+ options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
// Suppress smart replies if actions are present.
if (options_->suppress_smart_replies_with_actions()) {
std::vector<ActionSuggestion> non_smart_reply_actions;
@@ -316,17 +326,35 @@ bool ActionsSuggestionsRanker::RankActions(
// Sort within each group by score.
for (std::vector<ActionSuggestion>& group : groups) {
- SortByScoreAndType(&group);
+ if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&group);
+ } else {
+ SortByScoreAndType(&group);
+ }
}
- // Sort groups by maximum score.
- std::sort(groups.begin(), groups.end(),
- [](const std::vector<ActionSuggestion>& a,
- const std::vector<ActionSuggestion>& b) {
- return a.begin()->score > b.begin()->score ||
- (a.begin()->score >= b.begin()->score &&
- a.begin()->type < b.begin()->type);
- });
+ // Sort groups by maximum score or priority score.
+ if (sort_by_priority) {
+ std::stable_sort(
+ groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return (a.begin()->priority_score > b.begin()->priority_score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score > b.begin()->score) ||
+ (a.begin()->priority_score >= b.begin()->priority_score &&
+ a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ } else {
+ std::stable_sort(groups.begin(), groups.end(),
+ [](const std::vector<ActionSuggestion>& a,
+ const std::vector<ActionSuggestion>& b) {
+ return a.begin()->score > b.begin()->score ||
+ (a.begin()->score >= b.begin()->score &&
+ a.begin()->type < b.begin()->type);
+ });
+ }
// Flatten result.
const size_t num_actions = response->actions.size();
@@ -336,9 +364,9 @@ bool ActionsSuggestionsRanker::RankActions(
response->actions.insert(response->actions.end(), actions.begin(),
actions.end());
}
-
+ } else if (sort_by_priority) {
+ SortByPriorityAndScoreAndType(&response->actions);
} else {
- // Order suggestions independently by score.
SortByScoreAndType(&response->actions);
}
diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc
index b52cf45..5eba45f 100644
--- a/native/actions/ranker_test.cc
+++ b/native/actions/ranker_test.cc
@@ -18,6 +18,7 @@
#include <string>
+#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
@@ -308,12 +309,12 @@ TEST(RankingTest, GroupsActionsByAnnotations) {
response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
- /*priority_score=*/1.0,
+ /*priority_score=*/0.0,
/*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -338,23 +339,75 @@ TEST(RankingTest, GroupsActionsByAnnotations) {
IsAction("text_reply", "How are you?", 0.5)}));
}
-TEST(RankingTest, SortsActionsByScore) {
+TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
ActionsSuggestionsResponse response;
+ response.actions.push_back({/*response_text=*/"How are you?",
+ /*type=*/"text_reply",
+ /*score=*/2.0,
+ /*priority_score=*/0.0});
{
ActionSuggestionAnnotation annotation;
annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
/*text=*/"911"};
annotation.entity = ClassificationResult("phone", 1.0);
response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact",
+ /*score=*/0.0,
+ /*priority_score=*/1.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
/*type=*/"call_phone",
/*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"add_contact2",
+ /*score=*/0.5,
/*priority_score=*/1.0,
/*annotations=*/{annotation}});
+ }
+ RankingOptionsT options;
+ options.group_by_annotations = true;
+ options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ // The text reply should be last, even though its score is higher than
+ // any other score -- because its priority_score is lower than the max
+ // of those with the 'phone' annotation.
+ EXPECT_THAT(response.actions,
+ testing::ElementsAreArray({
+ // Group 1 (Phone annotation)
+ IsAction("add_contact2", "", 0.5), // priority_score=1.0
+ IsAction("add_contact", "", 0.0), // priority_score=1.0
+ IsAction("call_phone", "", 1.0), // priority_score=0.0
+ IsAction("text_reply", "How are you?", 2.0), // Group 2
+ }));
+}
+
+TEST(RankingTest, SortsActionsByScore) {
+ const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
+ ActionsSuggestionsResponse response;
+ {
+ ActionSuggestionAnnotation annotation;
+ annotation.span = {/*message_index=*/0, /*span=*/{5, 8},
+ /*text=*/"911"};
+ annotation.entity = ClassificationResult("phone", 1.0);
+ response.actions.push_back({/*response_text=*/"",
+ /*type=*/"call_phone",
+ /*score=*/1.0,
+ /*priority_score=*/0.0,
+ /*annotations=*/{annotation}});
response.actions.push_back({/*response_text=*/"",
/*type=*/"add_contact",
/*score=*/0.0,
- /*priority_score=*/0.0,
+ /*priority_score=*/1.0,
/*annotations=*/{annotation}});
}
response.actions.push_back({/*response_text=*/"How are you?",
@@ -378,5 +431,40 @@ TEST(RankingTest, SortsActionsByScore) {
IsAction("add_contact", "", 0.0)}));
}
+TEST(RankingTest, SortsActionsByPriority) {
+ const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
+ ActionsSuggestionsResponse response;
+ // emoji replies given higher priority_score
+ response.actions.push_back({/*response_text=*/"😁",
+ /*type=*/"text_reply",
+ /*score=*/0.5,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"👋",
+ /*type=*/"text_reply",
+ /*score=*/0.4,
+ /*priority_score=*/1.0});
+ response.actions.push_back({/*response_text=*/"Yes",
+ /*type=*/"text_reply",
+ /*score=*/1.0,
+ /*priority_score=*/0.0});
+ RankingOptionsT options;
+ // Don't group by annotation.
+ options.group_by_annotations = false;
+ options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(RankingOptions::Pack(builder, &options));
+ auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
+ flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
+ /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
+
+ ranker->RankActions(conversation, &response);
+
+ EXPECT_THAT(response.actions, testing::ElementsAreArray(
+ {IsAction("text_reply", "😁", 0.5),
+ IsAction("text_reply", "👋", 0.4),
+ // Ranked last because of priority score
+ IsAction("text_reply", "Yes", 1.0)}));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 77e556c..0fa7f7e 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index c468bd5..6107e98 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index ec421a1..436ed93 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index 24be6c6..935691d 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index fd7ddf2..2c9f74b 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index c969c56..cdb7523 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index d171898..ac28fa2 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 937552b..d864b79 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 32bd29c..e0d4241 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -973,11 +973,11 @@ CodepointSpan Annotator::SuggestSelection(
// Sort candidates according to their position in the input, so that the next
// code can assume that any connected component of overlapping spans forms a
// contiguous block.
- std::sort(candidates.annotated_spans[0].begin(),
- candidates.annotated_spans[0].end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- return a.span.first < b.span.first;
- });
+ std::stable_sort(candidates.annotated_spans[0].begin(),
+ candidates.annotated_spans[0].end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ return a.span.first < b.span.first;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens,
@@ -987,13 +987,14 @@ CodepointSpan Annotator::SuggestSelection(
return original_click_indices;
}
- std::sort(candidate_indices.begin(), candidate_indices.end(),
- [this, &candidates](int a, int b) {
- return GetPriorityScore(
- candidates.annotated_spans[0][a].classification) >
- GetPriorityScore(
- candidates.annotated_spans[0][b].classification);
- });
+ std::stable_sort(
+ candidate_indices.begin(), candidate_indices.end(),
+ [this, &candidates](int a, int b) {
+ return GetPriorityScore(
+ candidates.annotated_spans[0][a].classification) >
+ GetPriorityScore(
+ candidates.annotated_spans[0][b].classification);
+ });
for (const int i : candidate_indices) {
if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) &&
@@ -1173,7 +1174,7 @@ bool Annotator::ResolveConflict(
}
}
- std::sort(
+ std::stable_sort(
conflicting_indices.begin(), conflicting_indices.end(),
[this, &scores_lengths, candidates, conflicting_indices](int i, int j) {
if (scores_lengths[i].first == scores_lengths[j].first &&
@@ -1241,7 +1242,7 @@ bool Annotator::ResolveConflict(
chosen_indices_for_source_ptr->insert(considered_candidate);
}
- std::sort(chosen_indices->begin(), chosen_indices->end());
+ std::stable_sort(chosen_indices->begin(), chosen_indices->end());
return true;
}
@@ -1414,10 +1415,11 @@ namespace {
// Sorts the classification results from high score to low score.
void SortClassificationResults(
std::vector<ClassificationResult>* classification_results) {
- std::sort(classification_results->begin(), classification_results->end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ // Stable sort: results with equal scores keep their incoming relative order.
+ std::stable_sort(
+ classification_results->begin(), classification_results->end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
}
} // namespace
@@ -1936,10 +1938,11 @@ std::vector<ClassificationResult> Annotator::ClassifyText(
}
// Sort results according to score.
- std::sort(results.begin(), results.end(),
- [](const ClassificationResult& a, const ClassificationResult& b) {
- return a.score > b.score;
- });
+ std::stable_sort(
+ results.begin(), results.end(),
+ [](const ClassificationResult& a, const ClassificationResult& b) {
+ return a.score > b.score;
+ });
if (results.empty()) {
results = {{Collections::Other(), 1.0}};
@@ -2297,19 +2300,19 @@ Status Annotator::AnnotateSingleInput(
// Also sort them according to the end position and collection, so that the
// deduplication code below can assume that same spans and classifications
// form contiguous blocks.
- std::sort(candidates->begin(), candidates->end(),
- [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
- if (a.span.first != b.span.first) {
- return a.span.first < b.span.first;
- }
+ std::stable_sort(candidates->begin(), candidates->end(),
+ [](const AnnotatedSpan& a, const AnnotatedSpan& b) {
+ if (a.span.first != b.span.first) {
+ return a.span.first < b.span.first;
+ }
- if (a.span.second != b.span.second) {
- return a.span.second < b.span.second;
- }
+ if (a.span.second != b.span.second) {
+ return a.span.second < b.span.second;
+ }
- return a.classification[0].collection <
- b.classification[0].collection;
- });
+ return a.classification[0].collection <
+ b.classification[0].collection;
+ });
std::vector<int> candidate_indices;
if (!ResolveConflicts(*candidates, context, tokens,
@@ -2904,10 +2907,10 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
return false;
}
}
- std::sort(scored_chunks.rbegin(), scored_chunks.rend(),
- [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
- return lhs.score < rhs.score;
- });
+ std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(),
+ [](const ScoredChunk& lhs, const ScoredChunk& rhs) {
+ return lhs.score < rhs.score;
+ });
// Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick
// them greedily as long as they do not overlap with any previously picked
@@ -2936,7 +2939,7 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest,
chunks->push_back(scored_chunk.token_span);
}
- std::sort(chunks->begin(), chunks->end());
+ std::stable_sort(chunks->begin(), chunks->end());
return true;
}
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
index 7d5f440..ff0c775 100644
--- a/native/annotator/datetime/datetime-grounder.cc
+++ b/native/annotator/datetime/datetime-grounder.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/datetime-grounder.h"
+#include <algorithm>
#include <limits>
#include <unordered_map>
#include <vector>
@@ -250,10 +251,10 @@ StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground(
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
datetime_parse_result.push_back(result);
}
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index 867c886..94a0961 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -16,6 +16,8 @@
#include "annotator/datetime/extractor.h"
+#include <algorithm>
+
#include "annotator/datetime/utils.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
@@ -347,10 +349,11 @@ bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input,
}
}
- std::sort(found_numbers.begin(), found_numbers.end(),
- [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
- return a.first < b.first;
- });
+ std::stable_sort(
+ found_numbers.begin(), found_numbers.end(),
+ [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
+ return a.first < b.first;
+ });
int sum = 0;
int running_value = -1;
diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc
index 4dc9c56..5daabd5 100644
--- a/native/annotator/datetime/regex-parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -16,6 +16,7 @@
#include "annotator/datetime/regex-parser.h"
+#include <algorithm>
#include <iterator>
#include <set>
#include <unordered_set>
@@ -191,17 +192,17 @@ StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
// Resolve conflicts by always picking the longer span and breaking ties by
// selecting the earlier entry in the list for a given locale.
- std::sort(indexed_found_spans.begin(), indexed_found_spans.end(),
- [](const std::pair<DatetimeParseResultSpan, int>& a,
- const std::pair<DatetimeParseResultSpan, int>& b) {
- if ((a.first.span.second - a.first.span.first) !=
- (b.first.span.second - b.first.span.first)) {
- return (a.first.span.second - a.first.span.first) >
- (b.first.span.second - b.first.span.first);
- } else {
- return a.second < b.second;
- }
- });
+ std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(),
+ [](const std::pair<DatetimeParseResultSpan, int>& a,
+ const std::pair<DatetimeParseResultSpan, int>& b) {
+ if ((a.first.span.second - a.first.span.first) !=
+ (b.first.span.second - b.first.span.first)) {
+ return (a.first.span.second - a.first.span.first) >
+ (b.first.span.second - b.first.span.first);
+ } else {
+ return a.second < b.second;
+ }
+ });
std::vector<DatetimeParseResultSpan> results;
std::vector<DatetimeParseResultSpan> resolved_found_spans;
@@ -394,10 +395,10 @@ bool RegexDatetimeParser::ExtractDatetime(
}
// Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
+ std::stable_sort(date_components.begin(), date_components.end(),
+ [](DatetimeComponent a, DatetimeComponent b) {
+ return a.component_type > b.component_type;
+ });
result.datetime_components.swap(date_components);
results->push_back(result);
}
diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc
index 640ceec..2c5a43c 100644
--- a/native/annotator/translate/translate.cc
+++ b/native/annotator/translate/translate.cc
@@ -16,6 +16,7 @@
#include "annotator/translate/translate.h"
+#include <algorithm>
#include <memory>
#include "annotator/collections.h"
@@ -142,11 +143,11 @@ TranslateAnnotator::BackoffDetectLanguages(
result.push_back({key, value});
}
- std::sort(result.begin(), result.end(),
- [](TranslateAnnotator::LanguageConfidence& a,
- TranslateAnnotator::LanguageConfidence& b) {
- return a.confidence > b.confidence;
- });
+ std::stable_sort(result.begin(), result.end(),
+ [](const TranslateAnnotator::LanguageConfidence& a,
+ const TranslateAnnotator::LanguageConfidence& b) {
+ return a.confidence > b.confidence;
+ });
return result;
}
diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc
index 469cb1f..49c9ca0 100644
--- a/native/lang_id/common/embedding-network.cc
+++ b/native/lang_id/common/embedding-network.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/embedding-network.h"
+#include <vector>
+
#include "lang_id/common/lite_base/integral-types.h"
#include "lang_id/common/lite_base/logging.h"
diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc
index ab8a1a6..4e304fe 100644
--- a/native/lang_id/common/fel/feature-extractor.cc
+++ b/native/lang_id/common/fel/feature-extractor.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/fel/feature-extractor.h"
#include <string>
+#include <vector>
#include "lang_id/common/fel/feature-types.h"
#include "lang_id/common/fel/fel-parser.h"
diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc
index af41e29..60dcc46 100644
--- a/native/lang_id/common/fel/workspace.cc
+++ b/native/lang_id/common/fel/workspace.cc
@@ -18,6 +18,7 @@
#include <atomic>
#include <string>
+#include <vector>
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h
index f13d802..2ac5b26 100644
--- a/native/lang_id/common/fel/workspace.h
+++ b/native/lang_id/common/fel/workspace.h
@@ -23,6 +23,7 @@
#include <stddef.h>
+#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 19afcc4..fc925ea 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -29,6 +29,8 @@
#endif
#include <sys/stat.h>
+#include <string>
+
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_base/macros.h"
diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc
index 199bb69..d227eec 100644
--- a/native/lang_id/common/lite_strings/str-split.cc
+++ b/native/lang_id/common/lite_strings/str-split.cc
@@ -16,6 +16,8 @@
#include "lang_id/common/lite_strings/str-split.h"
+#include <vector>
+
namespace libtextclassifier3 {
namespace mobile {
diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc
index 750341d..249ed57 100644
--- a/native/lang_id/common/math/softmax.cc
+++ b/native/lang_id/common/math/softmax.cc
@@ -17,6 +17,7 @@
#include "lang_id/common/math/softmax.h"
#include <algorithm>
+#include <vector>
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/math/fastexp.h"
diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc
index dc36fb7..51c8c47 100644
--- a/native/lang_id/fb_model/lang-id-from-fb.cc
+++ b/native/lang_id/fb_model/lang-id-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/lang-id-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/fb_model/model-provider-from-fb.h"
diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc
index 43bf860..d14d403 100644
--- a/native/lang_id/fb_model/model-provider-from-fb.cc
+++ b/native/lang_id/fb_model/model-provider-from-fb.cc
@@ -16,7 +16,9 @@
#include "lang_id/fb_model/model-provider-from-fb.h"
+#include <memory>
#include <string>
+#include <utility>
#include "lang_id/common/file/file-utils.h"
#include "lang_id/common/file/mmap.h"
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index 92359a9..f7c66f7 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -21,6 +21,7 @@
#include <memory>
#include <string>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "lang_id/common/embedding-feature-interface.h"
diff --git a/native/utils/codepoint-range.cc b/native/utils/codepoint-range.cc
index e26b160..a4cd485 100644
--- a/native/utils/codepoint-range.cc
+++ b/native/utils/codepoint-range.cc
@@ -31,10 +31,11 @@ void SortCodepointRanges(
CodepointRangeStruct(range->start(), range->end()));
}
- std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
- [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
- return a.start < b.start;
- });
+ std::stable_sort(
+ sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(),
+ [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) {
+ return a.start < b.start;
+ });
}
// Returns true if given codepoint is covered by the given sorted vector of
diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc
index 4e39a98..a9e99ba 100644
--- a/native/utils/grammar/parsing/parser.cc
+++ b/native/utils/grammar/parsing/parser.cc
@@ -16,6 +16,7 @@
#include "utils/grammar/parsing/parser.h"
+#include <algorithm>
#include <unordered_map>
#include "utils/grammar/parsing/parse-tree.h"
@@ -177,14 +178,14 @@ std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input,
}
}
- std::sort(symbols.begin(), symbols.end(),
- [](const Symbol& a, const Symbol& b) {
- // Sort by increasing (end, start) position to guarantee the
- // matcher requirement that the tokens are fed in non-decreasing
- // end position order.
- return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
- std::tie(b.codepoint_span.second, b.codepoint_span.first);
- });
+ std::stable_sort(
+ symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) {
+ // Sort by increasing (end, start) position to guarantee the
+ // matcher requirement that the tokens are fed in non-decreasing
+ // end position order.
+ return std::tie(a.codepoint_span.second, a.codepoint_span.first) <
+ std::tie(b.codepoint_span.second, b.codepoint_span.first);
+ });
return symbols;
}
diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc
index dd29e3c..c134550 100644
--- a/native/utils/grammar/utils/ir.cc
+++ b/native/utils/grammar/utils/ir.cc
@@ -16,6 +16,8 @@
#include "utils/grammar/utils/ir.h"
+#include <algorithm>
+
#include "utils/i18n/locale.h"
#include "utils/strings/append.h"
#include "utils/strings/stringpiece.h"
@@ -28,14 +30,16 @@ constexpr size_t kMaxHashTableSize = 100;
template <typename T>
void SortForBinarySearchLookup(T* entries) {
- std::sort(entries->begin(), entries->end(),
- [](const auto& a, const auto& b) { return a->key < b->key; });
+ // Orders the pointed-to entries by ascending `key`; entries with equal keys
+ // keep their relative order because the sort is stable.
+ std::stable_sort(
+ entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a->key < b->key; });
}
template <typename T>
void SortStructsForBinarySearchLookup(T* entries) {
- std::sort(entries->begin(), entries->end(),
- [](const auto& a, const auto& b) { return a.key() < b.key(); });
+ // Same as SortForBinarySearchLookup, but for value entries that expose the
+ // key through a `key()` accessor; ties keep their relative order.
+ std::stable_sort(
+ entries->begin(), entries->end(),
+ [](const auto& a, const auto& b) { return a.key() < b.key(); });
}
bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) {
@@ -76,13 +80,14 @@ bool IsSameLhsSet(const Ir::LhsSet& lhs_set,
+// Returns a copy of `lhs_set` ordered by (nonterminal, callback id, callback
+// param, max whitespace gap). The input set itself is not modified.
Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) {
Ir::LhsSet sorted_lhs = lhs_set;
- std::sort(sorted_lhs.begin(), sorted_lhs.end(),
- [](const Ir::Lhs& a, const Ir::Lhs& b) {
- return std::tie(a.nonterminal, a.callback.id, a.callback.param,
- a.preconditions.max_whitespace_gap) <
- std::tie(b.nonterminal, b.callback.id, b.callback.param,
- b.preconditions.max_whitespace_gap);
- });
+ std::stable_sort(
+ sorted_lhs.begin(), sorted_lhs.end(),
+ [](const Ir::Lhs& a, const Ir::Lhs& b) {
+ return std::tie(a.nonterminal, a.callback.id, a.callback.param,
+ a.preconditions.max_whitespace_gap) <
+ std::tie(b.nonterminal, b.callback.id, b.callback.param,
+ b.preconditions.max_whitespace_gap);
+ });
- return lhs_set;
+ // Hand back the sorted copy; returning `lhs_set` threw the sort away.
+ return sorted_lhs;
}
@@ -300,10 +305,10 @@ void Ir::SerializeTerminalRules(
TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second});
}
}
- std::sort(terminal_rules.begin(), terminal_rules.end(),
- [](const TerminalEntry& a, const TerminalEntry& b) {
- return a.terminal < b.terminal;
- });
+ std::stable_sort(terminal_rules.begin(), terminal_rules.end(),
+ [](const TerminalEntry& a, const TerminalEntry& b) {
+ return a.terminal < b.terminal;
+ });
// Index the entries in sorted order.
std::vector<int> index(terminal_rules_sets.size(), 0);
diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc
index e6db06d..141ce5d 100644
--- a/native/utils/grammar/utils/locale-shard-map.cc
+++ b/native/utils/grammar/utils/locale-shard-map.cc
@@ -40,8 +40,8 @@ std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) {
locale_list.emplace_back(locale);
}
}
- std::sort(locale_list.begin(), locale_list.end(),
- [](const Locale& a, const Locale& b) { return a < b; });
+ std::stable_sort(locale_list.begin(), locale_list.end(),
+ [](const Locale& a, const Locale& b) { return a < b; });
return locale_list;
}
diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h
index 30c7aed..c23b5dc 100644
--- a/native/utils/testing/test_data_generator.h
+++ b/native/utils/testing/test_data_generator.h
@@ -20,6 +20,7 @@
#include <algorithm>
#include <iostream>
#include <random>
+#include <string>
#include "utils/strings/stringpiece.h"
@@ -35,6 +36,18 @@ class TestDataGenerator {
return dist(random_engine_);
}
+  // Specialization for bool: draws true or false with equal probability.
+  template <>
+  bool generate() {
+    std::bernoulli_distribution coin_flip(0.5);
+    const bool outcome = coin_flip(random_engine_);
+    return outcome;
+  }
+
+  // Specialization for char: draws a uniformly distributed letter 'a'..'z'.
+  template <>
+  char generate() {
+    std::uniform_int_distribution<int> letter_offset(0, 25);
+    return 'a' + letter_offset(random_engine_);
+  }
+
template <typename T, typename std::enable_if_t<
std::is_floating_point<T>::value>* = nullptr>
T generate() {
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 463d910..644dde8 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -27,6 +27,8 @@ namespace builtin {
TfLiteRegistration* Register_ADD();
TfLiteRegistration* Register_CONCATENATION();
TfLiteRegistration* Register_CONV_2D();
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Register_AVERAGE_POOL_2D();
TfLiteRegistration* Register_EQUAL();
TfLiteRegistration* Register_FULLY_CONNECTED();
TfLiteRegistration* Register_GREATER_EQUAL();
@@ -89,7 +91,9 @@ TfLiteRegistration* Register_GREATER();
#include "utils/tflite/dist_diversification.h"
#include "utils/tflite/string_projection.h"
#include "utils/tflite/text_encoder.h"
+#include "utils/tflite/text_encoder3s.h"
#include "utils/tflite/token_encoder.h"
+
namespace tflite {
namespace ops {
namespace custom {
@@ -114,6 +118,14 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) {
tflite::ops::builtin::Register_CONV_2D(),
/*min_version=*/1,
/*max_version=*/5);
+ resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
+ tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(),
+ /*min_version=*/1,
+ /*max_version=*/6);
+ resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D,
+ tflite::ops::builtin::Register_AVERAGE_POOL_2D(),
+ /*min_version=*/1,
+ /*max_version=*/1);
resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL,
::tflite::ops::builtin::Register_EQUAL());
@@ -289,6 +301,8 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver(
tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION());
resolver->AddCustom("TextEncoder",
tflite::ops::custom::Register_TEXT_ENCODER());
+ resolver->AddCustom("TextEncoder3S",
+ tflite::ops::custom::Register_TEXT_ENCODER3S());
resolver->AddCustom("TokenEncoder",
tflite::ops::custom::Register_TOKEN_ENCODER());
resolver->AddCustom(
diff --git a/native/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc
index 8f9f2a8..eb319f9 100644
--- a/native/utils/tflite/encoder_common.cc
+++ b/native/utils/tflite/encoder_common.cc
@@ -58,6 +58,11 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
out->data.i32 + output_offset + from_this_element,
in.data.i32[value_index]);
} break;
+ case kTfLiteInt64: {
+ std::fill(out->data.i64 + output_offset,
+ out->data.i64 + output_offset + from_this_element,
+ in.data.i64[value_index]);
+ } break;
case kTfLiteFloat32: {
std::fill(out->data.f + output_offset,
out->data.f + output_offset + from_this_element,
@@ -78,6 +83,12 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate(
std::fill(out->data.i32 + output_offset, out->data.i32 + output_size,
value);
} break;
+ case kTfLiteInt64: {
+ const int64_t value =
+ (output_offset > 0) ? out->data.i64[output_offset - 1] : 0;
+ std::fill(out->data.i64 + output_offset, out->data.i64 + output_size,
+ value);
+ } break;
case kTfLiteFloat32: {
const float value =
(output_offset > 0) ? out->data.f[output_offset - 1] : 0;
diff --git a/native/utils/tflite/text_encoder3s.cc b/native/utils/tflite/text_encoder3s.cc
new file mode 100644
index 0000000..0b5e65b
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.cc
@@ -0,0 +1,243 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tflite/text_encoder3s.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/tflite/encoder_common.h"
+#include "utils/tflite/text_encoder_config_generated.h"
+#include "utils/tokenfree/byte_encoder.h"
+#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/string_util.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+// Input parameters for the op.
+constexpr int kInputTextInd = 0;
+
+constexpr int kTextLengthInd = 1;
+constexpr int kMaxLengthInd = 2;
+constexpr int kInputAttrInd = 3;
+
+// Output parameters for the op.
+constexpr int kOutputEncodedInd = 0;
+constexpr int kOutputPositionInd = 1;
+constexpr int kOutputLengthsInd = 2;
+constexpr int kOutputAttrInd = 3;
+
+// Creates the ByteEncoder that backs one instance of the op. The serialized
+// parameter buffer is not consulted: the encoder is default-constructed.
+// The returned raw pointer is owned by the caller and is deleted by Free().
+void* Initialize(TfLiteContext* context, const char* buffer, size_t length) {
+  return new ByteEncoder();
+}
+
+// Destroys the ByteEncoder that Initialize() handed out as `buffer`.
+void Free(TfLiteContext* context, void* buffer) {
+  auto* encoder = static_cast<ByteEncoder*>(buffer);
+  delete encoder;
+}
+
+namespace {
+// Resizes the encoded-ids output, the positions output and every attribute
+// output to shape [kEncoderBatchSize, max_output_length]. Returns the status
+// of the first resize that fails, or kTfLiteOk when all of them succeed.
+TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
+                                 int max_output_length) {
+  // Resizes the output tensor stored in slot `output_index` of the node.
+  const auto resize_to_max_length = [&](int output_index) -> TfLiteStatus {
+    TfLiteTensor* tensor =
+        &context->tensors[node->outputs->data[output_index]];
+    return context->ResizeTensor(
+        context, tensor,
+        CreateIntArray({kEncoderBatchSize, max_output_length}));
+  };
+
+  TF_LITE_ENSURE_OK(context, resize_to_max_length(kOutputEncodedInd));
+  TF_LITE_ENSURE_OK(context, resize_to_max_length(kOutputPositionInd));
+
+  // Attribute outputs occupy every slot from kOutputAttrInd to the end.
+  for (int output_index = kOutputAttrInd;
+       output_index < node->outputs->size; ++output_index) {
+    TF_LITE_ENSURE_OK(context, resize_to_max_length(output_index));
+  }
+  return kTfLiteOk;
+}
+}  // namespace
+
+// Validates the node's inputs and outputs, fixes the output tensor types and
+// sizes the outputs: immediately when the max-length input is a constant
+// tensor, otherwise by marking them dynamic so that Eval() resizes them.
+// Returns kTfLiteOk on success and an error status when validation or a
+// resize fails.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // The fixed inputs (text, text length, max length) and the fixed outputs
+  // (encoded ids, positions, lengths) are indexed unconditionally below, so
+  // verify the node provides them before touching the index arrays.
+  TF_LITE_ENSURE(context, node->inputs->size >= kInputAttrInd);
+  TF_LITE_ENSURE(context, node->outputs->size >= kOutputAttrInd);
+
+  // Check that the batch dimension is kEncoderBatchSize.
+  const TfLiteTensor& input_text =
+      context->tensors[node->inputs->data[kInputTextInd]];
+  TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank);
+  TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize);
+
+  TfLiteTensor& output_lengths =
+      context->tensors[node->outputs->data[kOutputLengthsInd]];
+  TfLiteTensor& output_encoded =
+      context->tensors[node->outputs->data[kOutputEncodedInd]];
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[kOutputPositionInd]];
+  output_encoded.type = kTfLiteInt32;
+  output_positions.type = kTfLiteInt32;
+  output_lengths.type = kTfLiteInt32;
+
+  // The lengths output always holds one value per batch entry.
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, &output_lengths,
+                                          CreateIntArray({kEncoderBatchSize})));
+
+  // Check that there are enough outputs for attributes.
+  const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+                    num_output_attrs);
+
+  // Copy attribute types from input to output tensors.
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteTensor& input =
+        context->tensors[node->inputs->data[kInputAttrInd + i]];
+    TfLiteTensor& output =
+        context->tensors[node->outputs->data[kOutputAttrInd + i]];
+    output.type = input.type;
+  }
+
+  // NOTE(review): the max-length tensor is read as int64 without checking its
+  // declared type -- confirm that models always provide an int64 tensor here.
+  const TfLiteTensor& max_length_tensor =
+      context->tensors[node->inputs->data[kMaxLengthInd]];
+
+  if (tflite::IsConstantTensor(&max_length_tensor)) {
+    // The width is known up front, so the outputs can be sized right away.
+    return ResizeOutputTensors(
+        context, node, static_cast<int>(max_length_tensor.data.i64[0]));
+  }
+
+  // The width is only known at run time; defer sizing to Eval().
+  tflite::SetTensorToDynamic(&output_encoded);
+  tflite::SetTensorToDynamic(&output_positions);
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteTensor& output_attr =
+        context->tensors[node->outputs->data[kOutputAttrInd + i]];
+    tflite::SetTensorToDynamic(&output_attr);
+  }
+
+  return kTfLiteOk;
+}
+
+// Runs the op: byte-encodes the first `num_strings` strings of the text input
+// and fills the output tensors.
+//  - encoded ids: concatenation of all encodings. When it is longer than the
+//    output width only the trailing ids are kept; shorter results are padded
+//    with 0 at the end.
+//  - positions: index of each id within its own string, capped at width - 1.
+//  - lengths: number of ids written before padding.
+//  - attributes: filled by CopyValuesToTensorAndPadOrTruncate from the
+//    per-string end offsets.
+// Returns kTfLiteError when the node has no encoder attached or when the
+// requested string count is invalid.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  if (node->user_data == nullptr) {
+    return kTfLiteError;
+  }
+  auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data);
+  const TfLiteTensor& input_text =
+      context->tensors[node->inputs->data[kInputTextInd]];
+  const int num_strings_in_tensor = tflite::GetStringCount(&input_text);
+  const int num_strings =
+      context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0];
+
+  // The string count comes from a runtime tensor. Reject negative values
+  // (they would be converted to a huge size by reserve() below) as well as
+  // counts bigger than the number of strings actually present.
+  TF_LITE_ENSURE(context, num_strings >= 0);
+  TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings);
+
+  TfLiteTensor& output_encoded =
+      context->tensors[node->outputs->data[kOutputEncodedInd]];
+  if (tflite::IsDynamicTensor(&output_encoded)) {
+    // Prepare() could not size the outputs; do it now from the runtime value.
+    const TfLiteTensor& max_length_tensor =
+        context->tensors[node->inputs->data[kMaxLengthInd]];
+    TF_LITE_ENSURE_OK(
+        context,
+        ResizeOutputTensors(
+            context, node, static_cast<int>(max_length_tensor.data.i64[0])));
+  }
+  TfLiteTensor& output_positions =
+      context->tensors[node->outputs->data[kOutputPositionInd]];
+
+  std::vector<int> encoded_total;
+  std::vector<int> encoded_positions;
+  std::vector<int> encoded_offsets;
+  encoded_offsets.reserve(num_strings);
+  const int max_output_length = output_encoded.dims->data[1];
+  const int max_encoded_position = max_output_length;
+
+  for (int string_index = 0; string_index < num_strings; ++string_index) {
+    const auto& strref = tflite::GetString(&input_text, string_index);
+    std::vector<int64_t> encoded;
+    text_encoder->Encode(
+        libtextclassifier3::StringPiece(strref.str, strref.len), &encoded);
+    encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end());
+    // Remember where this string's ids end in the concatenated sequence.
+    encoded_offsets.push_back(encoded_total.size());
+    // Positions restart at 0 for every string and saturate at the last slot.
+    const int num_encoded = static_cast<int>(encoded.size());
+    for (int position = 0; position < num_encoded; ++position) {
+      encoded_positions.push_back(
+          std::min(position, max_encoded_position - 1));
+    }
+  }
+
+  // Copy encoding to output tensor. If there are more ids than fit, skip the
+  // leading ones so that the tail of the sequence is what gets written.
+  const int total_encoded = static_cast<int>(encoded_total.size());
+  const int start_offset = std::max(0, total_encoded - max_output_length);
+  int output_offset = 0;
+  int32_t* output_buffer = output_encoded.data.i32;
+  int32_t* output_positions_buffer = output_positions.data.i32;
+  for (int i = start_offset; i < total_encoded; ++i, ++output_offset) {
+    output_buffer[output_offset] = encoded_total[i];
+    output_positions_buffer[output_offset] = encoded_positions[i];
+  }
+
+  // Save output encoded length.
+  TfLiteTensor& output_lengths =
+      context->tensors[node->outputs->data[kOutputLengthsInd]];
+  output_lengths.data.i32[0] = output_offset;
+
+  // Do padding.
+  for (; output_offset < max_output_length; ++output_offset) {
+    output_buffer[output_offset] = 0;
+    output_positions_buffer[output_offset] = 0;
+  }
+
+  // Process attributes, all checks of sizes and types are done in Prepare.
+  const int num_output_attrs = node->outputs->size - kOutputAttrInd;
+  TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd,
+                    num_output_attrs);
+  for (int i = 0; i < num_output_attrs; ++i) {
+    TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate(
+        context->tensors[node->inputs->data[kInputAttrInd + i]],
+        encoded_offsets, start_offset, context,
+        &context->tensors[node->outputs->data[kOutputAttrInd + i]]);
+    if (attr_status != kTfLiteOk) {
+      return attr_status;
+    }
+  }
+
+  return kTfLiteOk;
+}
+
+} // namespace
+} // namespace libtextclassifier3
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+// Returns the registration record for the TextEncoder3S custom op. The record
+// is a function-local static, so every call returns the same pointer.
+TfLiteRegistration* Register_TEXT_ENCODER3S() {
+ static TfLiteRegistration registration = {
+ // Listed in the order Initialize, Free, Prepare, Eval.
+ libtextclassifier3::Initialize, libtextclassifier3::Free,
+ libtextclassifier3::Prepare, libtextclassifier3::Eval};
+ return &registration;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/native/utils/tflite/text_encoder3s.h b/native/utils/tflite/text_encoder3s.h
new file mode 100644
index 0000000..50e1e64
--- /dev/null
+++ b/native/utils/tflite/text_encoder3s.h
@@ -0,0 +1,35 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// An encoder that produces positional and attribute encodings for a
+// transformer-style model, based on byte segmentation of text.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
+
+#include "tensorflow/lite/context.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+// Returns the registration for the TEXT_ENCODER3S TfLite custom op.
+TfLiteRegistration* Register_TEXT_ENCODER3S();
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_
diff --git a/native/utils/tokenfree/byte_encoder.cc b/native/utils/tokenfree/byte_encoder.cc
new file mode 100644
index 0000000..c79d3a2
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.cc
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <vector>
+namespace libtextclassifier3 {
+
+// Encodes `input_text` as the sequence of its raw bytes.
+//
+// On return, `*encoded_text` holds exactly one element per byte of
+// `input_text`, each in the range [0, 255]; any previous contents are
+// discarded and an empty input yields an empty vector. Always returns true.
+bool ByteEncoder::Encode(StringPiece input_text,
+                         std::vector<int64_t>* encoded_text) const {
+  const auto& text = input_text.ToString();
+
+  encoded_text->clear();
+  encoded_text->reserve(text.size());
+  for (const char byte : text) {
+    // The signedness of plain `char` is implementation-defined. Where it is
+    // signed, bytes >= 0x80 (i.e. every byte of a multi-byte UTF-8 sequence)
+    // would sign-extend to negative values. Converting through unsigned char
+    // makes the encoding identical on every platform.
+    encoded_text->push_back(
+        static_cast<int64_t>(static_cast<unsigned char>(byte)));
+  }
+
+  return true;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenfree/byte_encoder.h b/native/utils/tokenfree/byte_encoder.h
new file mode 100644
index 0000000..1a495ec
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder.h
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+#define LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
+
+#include <vector>
+
+#include "utils/base/logging.h"
+#include "utils/container/string-set.h"
+#include "utils/strings/stringpiece.h"
+
+namespace libtextclassifier3 {
+
+// Stateless encoder that segments/tokenizes a string into its individual
+// bytes.
+class ByteEncoder {
+ public:
+  ByteEncoder() {}
+
+  // Writes the byte-level encoding of `input_text` into `*encoded_text`,
+  // one element per input byte. The return value is a success flag.
+  bool Encode(StringPiece input_text, std::vector<int64_t>* encoded_text) const;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_
diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc
new file mode 100644
index 0000000..d4d119e
--- /dev/null
+++ b/native/utils/tokenfree/byte_encoder_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/tokenfree/byte_encoder.h"
+
+#include <memory>
+#include <vector>
+
+#include "utils/base/integral_types.h"
+#include "utils/container/sorted-strings-table.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::ElementsAre;
+
+// An all-lowercase ASCII string is encoded as the byte value of each
+// character, in order.
+TEST(EncoderTest, SimpleTokenization) {
+  const ByteEncoder encoder;
+  std::vector<int64_t> encoded_text;
+
+  EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text));
+  EXPECT_THAT(encoded_text,
+              ElementsAre(104, 101, 108, 108, 111, 116, 104, 101, 114, 101));
+}
+
+// A mixed-case ASCII string is encoded byte by byte; the upper-case 'H'
+// maps to 72.
+TEST(EncoderTest, SimpleTokenization2) {
+  const ByteEncoder encoder;
+  std::vector<int64_t> encoded_text;
+
+  EXPECT_TRUE(encoder.Encode("Hello", &encoded_text));
+  EXPECT_THAT(encoded_text, ElementsAre(72, 101, 108, 108, 111));
+}
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index 071141c..7038517 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -43,11 +43,12 @@ Tokenizer::Tokenizer(
codepoint_ranges_.emplace_back(range->UnPack());
}
- std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(),
- [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
- const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
- return a->start < b->start;
- });
+ std::stable_sort(
+ codepoint_ranges_.begin(), codepoint_ranges_.end(),
+ [](const std::unique_ptr<const TokenizationCodepointRangeT>& a,
+ const std::unique_ptr<const TokenizationCodepointRangeT>& b) {
+ return a->start < b->start;
+ });
SortCodepointRanges(internal_tokenizer_codepoint_ranges,
&internal_tokenizer_codepoint_ranges_);
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
index bc30fcf..f539ba7 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java
@@ -37,15 +37,20 @@ import androidx.test.filters.LargeTest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
@LargeTest
@RunWith(AndroidJUnit4.class)
public class SmartSuggestionsLogSessionTest {
+
+ @Rule public final MockitoRule mocks = MockitoJUnit.rule();
+
private static final String RESULT_ID = "resultId";
private static final String REPLY = "reply";
private static final float SCORE = 0.5f;
@@ -55,7 +60,6 @@ public class SmartSuggestionsLogSessionTest {
@Before
public void setup() {
- MockitoAnnotations.initMocks(this);
session =
new SmartSuggestionsLogSession(