diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-02-05 13:02:37 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2022-02-05 13:02:37 +0000 |
commit | ecbfc3f0e29e0d4b4a6e5c7679df5a3bc05c6a5d (patch) | |
tree | 7ff8d841461eef3d6a4fc61221dab93ab7eb2fe7 | |
parent | 6b4ecf498c9d82ef8ba29729670c10c5fb4a710b (diff) | |
parent | faf03992dc4e169d214b17726a82f664efd6b57a (diff) | |
download | libtextclassifier-android12-mainline-media-swcodec-release.tar.gz |
Snap for 8152310 from faf03992dc4e169d214b17726a82f664efd6b57a to mainline-media-swcodec-releaseandroid-mainline-12.0.0_r91android12-mainline-media-swcodec-release
Change-Id: I0ba3ed5e29e26af2034e05fc9e46dac96bfa4225
52 files changed, 825 insertions, 155 deletions
diff --git a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java index dae0442..67e300d 100644 --- a/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java +++ b/java/src/com/android/textclassifier/common/logging/ResultIdUtils.java @@ -66,8 +66,8 @@ public final class ResultIdUtils { } /** Returns if the result id was generated from the default text classifier. */ - public static boolean isFromDefaultTextClassifier(String resultId) { - return resultId.startsWith(CLASSIFIER_ID + '|'); + public static boolean isFromDefaultTextClassifier(@Nullable String resultId) { + return resultId != null && resultId.startsWith(CLASSIFIER_ID + '|'); } /** Returns all the model names encoded in the signature. */ diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java index 0e3842c..71f9a4f 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java @@ -52,16 +52,21 @@ import java.util.concurrent.Executor; import java.util.stream.Collectors; import org.junit.After; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) public class DefaultTextClassifierServiceTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + /** A statsd config ID, which is arbitrary. */ private static final long CONFIG_ID = 689777; @@ -79,7 +84,6 @@ public class DefaultTextClassifierServiceTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); testInjector = new TestInjector(ApplicationProvider.getApplicationContext()); defaultTextClassifierService = new DefaultTextClassifierService(testInjector); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java index 5297640..20ae592 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerImplTest.java @@ -53,7 +53,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) @@ -67,6 +68,7 @@ public final class ModelFileManagerImplTest { @Mock private DownloadedModelManager downloadedModelManager; @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); private File rootTestDir; private ModelFileManagerImpl modelFileManager; @@ -75,7 +77,6 @@ public final class ModelFileManagerImplTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); deviceConfig = new TestingDeviceConfig(); rootTestDir = new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir"); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java index 216cd5d..3aab211 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/TemplateIntentFactoryTest.java @@ -27,14 +27,18 @@ import com.google.android.textclassifier.NamedVariant; import com.google.android.textclassifier.RemoteActionTemplate; import java.util.List; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @SmallTest @RunWith(AndroidJUnit4.class) public class TemplateIntentFactoryTest { + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String TITLE_WITHOUT_ENTITY = "Map"; private static final String TITLE_WITH_ENTITY = "Map NW14D1"; private static final String DESCRIPTION = "Check the map"; @@ -71,7 +75,6 @@ public class TemplateIntentFactoryTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); templateIntentFactory = new TemplateIntentFactory(); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java index ffd2ee4..3a8fefc 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/StatsdTestUtils.java @@ -86,8 +86,8 @@ public class StatsdTestUtils { return ImmutableList.copyOf( metricsList.stream() .flatMap(statsLogReport -> statsLogReport.getEventMetrics().getDataList().stream()) - .flatMap(eventMetricData -> backfillAggregatedAtomsinEventMetric( - eventMetricData).stream()) + .flatMap( + eventMetricData -> backfillAggregatedAtomsinEventMetric(eventMetricData).stream()) .sorted(Comparator.comparing(EventMetricData::getElapsedTimestampNanos)) .map(EventMetricData::getAtom) .collect(Collectors.toList())); @@ -136,7 +136,7 @@ public class StatsdTestUtils { } private static ImmutableList<EventMetricData> backfillAggregatedAtomsinEventMetric( - EventMetricData metricData) { + EventMetricData metricData) { if (metricData.hasAtom()) { return ImmutableList.of(metricData); } diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java index c626ed7..394b7ad 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java @@ -46,7 +46,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @RunWith(AndroidJUnit4.class) public final class ModelDownloadManagerTest { @@ -61,6 +62,8 @@ public final class ModelDownloadManagerTest { public final TextClassifierDownloadLoggerTestRule loggerTestRule = new TextClassifierDownloadLoggerTestRule(); + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private TestingDeviceConfig deviceConfig; private WorkManager workManager; private ModelDownloadManager downloadManager; @@ -68,7 +71,6 @@ public final class ModelDownloadManagerTest { @Before public void setUp() { - MockitoAnnotations.initMocks(this); Context context = ApplicationProvider.getApplicationContext(); WorkManagerTestInitHelper.initializeTestWorkManager(context); diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java index eac2af3..76d04e0 100644 --- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java +++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderServiceImplTest.java @@ -37,14 +37,19 @@ import com.google.common.util.concurrent.SettableFuture; import java.io.File; import java.net.URI; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @RunWith(JUnit4.class) public final class ModelDownloaderServiceImplTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final long BYTES_WRITTEN = 1L; private static final String DOWNLOAD_URI = "https://www.gstatic.com/android/text_classifier/r/v999/en.fb"; @@ -66,7 +71,6 @@ public final class ModelDownloaderServiceImplTest { @Before public void setUp() { - MockitoAnnotations.initMocks(this); this.targetModelFile = new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb"); diff --git a/native/actions/actions-entity-data.bfbs b/native/actions/actions-entity-data.bfbs Binary files differindex 7421579..6ebf1cf 100644 --- a/native/actions/actions-entity-data.bfbs +++ b/native/actions/actions-entity-data.bfbs diff --git a/native/actions/actions-entity-data.fbs b/native/actions/actions-entity-data.fbs index 21584b6..e906f93 100644 --- a/native/actions/actions-entity-data.fbs +++ b/native/actions/actions-entity-data.fbs @@ -18,7 +18,7 @@ namespace libtextclassifier3; table ActionsEntityData { // Extracted text. - text:string (shared); + text:string (key, shared); } root_type libtextclassifier3.ActionsEntityData; diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index b1a042c..9f9a8d4 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -17,6 +17,7 @@ #include "actions/actions-suggestions.h" #include <memory> +#include <string> #include <vector> #include "utils/base/statusor.h" @@ -40,6 +41,7 @@ #include "utils/strings/stringpiece.h" #include "utils/strings/utf8.h" #include "utils/utf8/unicodetext.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/lite/string_util.h" namespace libtextclassifier3 { @@ -809,12 +811,14 @@ bool ActionsSuggestions::SetupModelInput( void ActionsSuggestions::PopulateTextReplies( const tflite::Interpreter* interpreter, int suggestion_index, - int score_index, const std::string& type, + int score_index, const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, ActionsSuggestionsResponse* response) const { const std::vector<tflite::StringRef> replies = model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter); const TensorView<float> scores = model_executor_->OutputView<float>(score_index, interpreter); + for (int i = 0; i < replies.size(); i++) { if (replies[i].len == 0) { continue; @@ -823,8 +827,12 @@ void ActionsSuggestions::PopulateTextReplies( if (score < preconditions_.min_reply_score_threshold) { continue; } - response->actions.push_back( - {std::string(replies[i].str, replies[i].len), type, score}); + std::string response_text(replies[i].str, replies[i].len); + if (blocklist.contains(response_text)) { + continue; + } + + response->actions.push_back({response_text, type, score, priority_score}); } } @@ -909,10 +917,12 @@ bool ActionsSuggestions::ReadModelOutput( // Read smart reply predictions. if (!response->output_filtered_min_triggering_score && model_->tflite_model_spec()->output_replies() >= 0) { + absl::flat_hash_set<std::string> empty_blocklist; PopulateTextReplies(interpreter, model_->tflite_model_spec()->output_replies(), model_->tflite_model_spec()->output_replies_scores(), - model_->smart_reply_action_type()->str(), response); + model_->smart_reply_action_type()->str(), + /* priority_score */ 0.0, empty_blocklist, response); } // Read actions suggestions. @@ -950,17 +960,26 @@ bool ActionsSuggestions::ReadModelOutput( const int suggestions_index = metadata->output_suggestions(); const int suggestions_scores_index = metadata->output_suggestions_scores(); + absl::flat_hash_set<std::string> response_text_blocklist; switch (metadata->prediction_type()) { case PredictionType_NEXT_MESSAGE_PREDICTION: if (!task_spec || task_spec->type()->size() == 0) { TC3_LOG(WARNING) << "Task type not provided, use default " "smart_reply_action_type!"; } + if (task_spec) { + if (task_spec->response_text_blocklist()) { + for (const auto& val : *task_spec->response_text_blocklist()) { + response_text_blocklist.insert(val->str()); + } + } + } PopulateTextReplies( interpreter, suggestions_index, suggestions_scores_index, task_spec ? task_spec->type()->str() : model_->smart_reply_action_type()->str(), - response); + task_spec ? task_spec->priority_score() : 0.0, + response_text_blocklist, response); break; case PredictionType_INTENT_TRIGGERING: PopulateIntentTriggering(interpreter, suggestions_index, diff --git a/native/actions/actions-suggestions.h b/native/actions/actions-suggestions.h index 32edc78..87f55fb 100644 --- a/native/actions/actions-suggestions.h +++ b/native/actions/actions-suggestions.h @@ -43,6 +43,7 @@ #include "utils/utf8/unilib.h" #include "utils/variant.h" #include "utils/zlib/zlib.h" +#include "absl/container/flat_hash_set.h" namespace libtextclassifier3 { @@ -176,7 +177,8 @@ class ActionsSuggestions { void PopulateTextReplies(const tflite::Interpreter* interpreter, int suggestion_index, int score_index, - const std::string& type, + const std::string& type, float priority_score, + const absl::flat_hash_set<std::string>& blocklist, ActionsSuggestionsResponse* response) const; void PopulateIntentTriggering(const tflite::Interpreter* interpreter, diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc index 062d527..b51ebc7 100644 --- a/native/actions/actions-suggestions_test.cc +++ b/native/actions/actions-suggestions_test.cc @@ -1798,6 +1798,7 @@ TEST_F(ActionsSuggestionsTest, TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) { std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(kMultiTaskSrEmojiModelFileName); + const ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( {{{/*user_id=*/1, "hello?", @@ -1807,9 +1808,31 @@ TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) { /*locales=*/"en"}}}); EXPECT_EQ(response.actions.size(), 5); EXPECT_EQ(response.actions[0].response_text, "😁"); - EXPECT_EQ(response.actions[0].type, "EMOJI_CONCEPT"); - EXPECT_EQ(response.actions[1].response_text, "Yes"); - EXPECT_EQ(response.actions[1].type, "REPLY_SUGGESTION"); + EXPECT_EQ(response.actions[0].type, "text_reply"); + EXPECT_EQ(response.actions[1].response_text, "👋"); + EXPECT_EQ(response.actions[1].type, "text_reply"); + EXPECT_EQ(response.actions[2].response_text, "Yes"); + EXPECT_EQ(response.actions[2].type, "text_reply"); +} + +TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) { + std::unique_ptr<ActionsSuggestions> actions_suggestions = + LoadTestModel(kMultiTaskSrEmojiModelFileName); + + const ActionsSuggestionsResponse response = + actions_suggestions->SuggestActions( + {{{/*user_id=*/1, "a pleasure chatting", + /*reference_time_ms_utc=*/0, + /*reference_timezone=*/"Europe/Zurich", + /*annotations=*/{}, + /*locales=*/"en"}}}); + EXPECT_EQ(response.actions.size(), 3); + EXPECT_EQ(response.actions[0].response_text, "😁"); + EXPECT_EQ(response.actions[0].type, "text_reply"); + EXPECT_EQ(response.actions[1].response_text, "😘"); + EXPECT_EQ(response.actions[1].type, "text_reply"); + EXPECT_EQ(response.actions[2].response_text, "Okay"); + EXPECT_EQ(response.actions[2].type, "text_reply"); } TEST_F(ActionsSuggestionsTest, LiveRelayModel) { diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs index 8c03eeb..0d8c7ad 100644 --- a/native/actions/actions_model.fbs +++ b/native/actions/actions_model.fbs @@ -36,6 +36,17 @@ enum PredictionType : int { ENTITY_ANNOTATION = 3, } +namespace libtextclassifier3; +enum RankingOptionsSortType : int { + SORT_TYPE_UNSPECIFIED = 0, + + // Rank results (or groups) by score, then type + SORT_TYPE_SCORE = 1, + + // Rank results (or groups) by priority score, then score, then type + SORT_TYPE_PRIORITY_SCORE = 2, +} + // Prediction metadata for an arbitrary task. namespace libtextclassifier3; table PredictionMetadata { @@ -315,10 +326,11 @@ table ActionSuggestionSpec { // Additional entity information. serialized_entity_data:string (shared); - // Priority score used for internal conflict resolution. + // For ranking and internal conflict resolution. priority_score:float = 0; entity_data:ActionsEntityData; + response_text_blocklist:[string]; } // Options to specify triggering behaviour per action class. @@ -416,6 +428,8 @@ table RankingOptions { // If true, keep actions from the same entities together for ranking. group_by_annotations:bool = true; + + sort_type:RankingOptionsSortType = SORT_TYPE_SCORE; } // Entity data to set from capturing groups. diff --git a/native/actions/ranker.cc b/native/actions/ranker.cc index d52ecaa..46e392a 100644 --- a/native/actions/ranker.cc +++ b/native/actions/ranker.cc @@ -20,6 +20,8 @@ #include <set> #include <vector> +#include "actions/actions_model_generated.h" + #if !defined(TC3_DISABLE_LUA) #include "actions/lua-ranker.h" #endif @@ -34,11 +36,22 @@ namespace libtextclassifier3 { namespace { void SortByScoreAndType(std::vector<ActionSuggestion>* actions) { - std::sort(actions->begin(), actions->end(), - [](const ActionSuggestion& a, const ActionSuggestion& b) { - return a.score > b.score || - (a.score >= b.score && a.type < b.type); - }); + std::stable_sort(actions->begin(), actions->end(), + [](const ActionSuggestion& a, const ActionSuggestion& b) { + return a.score > b.score || + (a.score >= b.score && a.type < b.type); + }); +} + +void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) { + std::stable_sort( + actions->begin(), actions->end(), + [](const ActionSuggestion& a, const ActionSuggestion& b) { + return a.priority_score > b.priority_score || + (a.priority_score >= b.priority_score && a.score > b.score) || + (a.priority_score >= b.priority_score && a.score >= b.score && + a.type < b.type); + }); } template <typename T> @@ -241,13 +254,8 @@ bool ActionsSuggestionsRanker::RankActions( const reflection::Schema* annotations_entity_data_schema) const { if (options_->deduplicate_suggestions() || options_->deduplicate_suggestions_by_span()) { - // First order suggestions by priority score for deduplication. - std::sort( - response->actions.begin(), response->actions.end(), - [](const ActionSuggestion& a, const ActionSuggestion& b) { - return a.priority_score > b.priority_score || - (a.priority_score >= b.priority_score && a.score > b.score); - }); + // Order suggestions by [priority score -> score] for deduplication + SortByPriorityAndScoreAndType(&response->actions); // Deduplicate, keeping the higher score actions. if (options_->deduplicate_suggestions()) { @@ -275,6 +283,8 @@ bool ActionsSuggestionsRanker::RankActions( } } + bool sort_by_priority = + options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; // Suppress smart replies if actions are present. if (options_->suppress_smart_replies_with_actions()) { std::vector<ActionSuggestion> non_smart_reply_actions; @@ -316,17 +326,35 @@ bool ActionsSuggestionsRanker::RankActions( // Sort within each group by score. for (std::vector<ActionSuggestion>& group : groups) { - SortByScoreAndType(&group); + if (sort_by_priority) { + SortByPriorityAndScoreAndType(&group); + } else { + SortByScoreAndType(&group); + } } - // Sort groups by maximum score. - std::sort(groups.begin(), groups.end(), - [](const std::vector<ActionSuggestion>& a, - const std::vector<ActionSuggestion>& b) { - return a.begin()->score > b.begin()->score || - (a.begin()->score >= b.begin()->score && - a.begin()->type < b.begin()->type); - }); + // Sort groups by maximum score or priority score. + if (sort_by_priority) { + std::stable_sort( + groups.begin(), groups.end(), + [](const std::vector<ActionSuggestion>& a, + const std::vector<ActionSuggestion>& b) { + return (a.begin()->priority_score > b.begin()->priority_score) || + (a.begin()->priority_score >= b.begin()->priority_score && + a.begin()->score > b.begin()->score) || + (a.begin()->priority_score >= b.begin()->priority_score && + a.begin()->score >= b.begin()->score && + a.begin()->type < b.begin()->type); + }); + } else { + std::stable_sort(groups.begin(), groups.end(), + [](const std::vector<ActionSuggestion>& a, + const std::vector<ActionSuggestion>& b) { + return a.begin()->score > b.begin()->score || + (a.begin()->score >= b.begin()->score && + a.begin()->type < b.begin()->type); + }); + } // Flatten result. const size_t num_actions = response->actions.size(); @@ -336,9 +364,9 @@ bool ActionsSuggestionsRanker::RankActions( response->actions.insert(response->actions.end(), actions.begin(), actions.end()); } - + } else if (sort_by_priority) { + SortByPriorityAndScoreAndType(&response->actions); } else { - // Order suggestions independently by score. SortByScoreAndType(&response->actions); } diff --git a/native/actions/ranker_test.cc b/native/actions/ranker_test.cc index b52cf45..5eba45f 100644 --- a/native/actions/ranker_test.cc +++ b/native/actions/ranker_test.cc @@ -18,6 +18,7 @@ #include <string> +#include "actions/actions_model_generated.h" #include "actions/types.h" #include "utils/zlib/zlib.h" #include "gmock/gmock.h" @@ -308,12 +309,12 @@ TEST(RankingTest, GroupsActionsByAnnotations) { response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone", /*score=*/1.0, - /*priority_score=*/1.0, + /*priority_score=*/0.0, /*annotations=*/{annotation}}); response.actions.push_back({/*response_text=*/"", /*type=*/"add_contact", /*score=*/0.0, - /*priority_score=*/0.0, + /*priority_score=*/1.0, /*annotations=*/{annotation}}); } response.actions.push_back({/*response_text=*/"How are you?", @@ -338,23 +339,75 @@ TEST(RankingTest, GroupsActionsByAnnotations) { IsAction("text_reply", "How are you?", 0.5)})); } -TEST(RankingTest, SortsActionsByScore) { +TEST(RankingTest, GroupsByAnnotationsSortedByPriority) { const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}}; ActionsSuggestionsResponse response; + response.actions.push_back({/*response_text=*/"How are you?", + /*type=*/"text_reply", + /*score=*/2.0, + /*priority_score=*/0.0}); { ActionSuggestionAnnotation annotation; annotation.span = {/*message_index=*/0, /*span=*/{5, 8}, /*text=*/"911"}; annotation.entity = ClassificationResult("phone", 1.0); response.actions.push_back({/*response_text=*/"", + /*type=*/"add_contact", + /*score=*/0.0, + /*priority_score=*/1.0, + /*annotations=*/{annotation}}); + response.actions.push_back({/*response_text=*/"", /*type=*/"call_phone", /*score=*/1.0, + /*priority_score=*/0.0, + /*annotations=*/{annotation}}); + response.actions.push_back({/*response_text=*/"", + /*type=*/"add_contact2", + /*score=*/0.5, /*priority_score=*/1.0, /*annotations=*/{annotation}}); + } + RankingOptionsT options; + options.group_by_annotations = true; + options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(RankingOptions::Pack(builder, &options)); + auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( + flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()), + /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply"); + + ranker->RankActions(conversation, &response); + + // The text reply should be last, even though it's score is higher than + // any other scores -- because it's priority_score is lower than the max + // of those with the 'phone' annotation + EXPECT_THAT(response.actions, + testing::ElementsAreArray({ + // Group 1 (Phone annotation) + IsAction("add_contact2", "", 0.5), // priority_score=1.0 + IsAction("add_contact", "", 0.0), // priority_score=1.0 + IsAction("call_phone", "", 1.0), // priority_score=0.0 + IsAction("text_reply", "How are you?", 2.0), // Group 2 + })); +} + +TEST(RankingTest, SortsActionsByScore) { + const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}}; + ActionsSuggestionsResponse response; + { + ActionSuggestionAnnotation annotation; + annotation.span = {/*message_index=*/0, /*span=*/{5, 8}, + /*text=*/"911"}; + annotation.entity = ClassificationResult("phone", 1.0); + response.actions.push_back({/*response_text=*/"", + /*type=*/"call_phone", + /*score=*/1.0, + /*priority_score=*/0.0, + /*annotations=*/{annotation}}); response.actions.push_back({/*response_text=*/"", /*type=*/"add_contact", /*score=*/0.0, - /*priority_score=*/0.0, + /*priority_score=*/1.0, /*annotations=*/{annotation}}); } response.actions.push_back({/*response_text=*/"How are you?", @@ -378,5 +431,40 @@ TEST(RankingTest, SortsActionsByScore) { IsAction("add_contact", "", 0.0)})); } +TEST(RankingTest, SortsActionsByPriority) { + const Conversation conversation = {{{/*user_id=*/1, "hello?"}}}; + ActionsSuggestionsResponse response; + // emoji replies given higher priority_score + response.actions.push_back({/*response_text=*/"😁", + /*type=*/"text_reply", + /*score=*/0.5, + /*priority_score=*/1.0}); + response.actions.push_back({/*response_text=*/"👋", + /*type=*/"text_reply", + /*score=*/0.4, + /*priority_score=*/1.0}); + response.actions.push_back({/*response_text=*/"Yes", + /*type=*/"text_reply", + /*score=*/1.0, + /*priority_score=*/0.0}); + RankingOptionsT options; + // Don't group by annotation. + options.group_by_annotations = false; + options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; + flatbuffers::FlatBufferBuilder builder; + builder.Finish(RankingOptions::Pack(builder, &options)); + auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( + flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()), + /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply"); + + ranker->RankActions(conversation, &response); + + EXPECT_THAT(response.actions, testing::ElementsAreArray( + {IsAction("text_reply", "😁", 0.5), + IsAction("text_reply", "👋", 0.4), + // Ranked last because of priority score + IsAction("text_reply", "Yes", 1.0)})); +} + } // namespace } // namespace libtextclassifier3 diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model Binary files differindex 77e556c..0fa7f7e 100644 --- a/native/actions/test_data/actions_suggestions_grammar_test.model +++ b/native/actions/test_data/actions_suggestions_grammar_test.model diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model Binary files differindex c468bd5..6107e98 100644 --- a/native/actions/test_data/actions_suggestions_test.model +++ b/native/actions/test_data/actions_suggestions_test.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model Binary files differindex ec421a1..436ed93 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model Binary files differindex 24be6c6..935691d 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model Binary files differindex fd7ddf2..2c9f74b 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model Binary files differindex c969c56..cdb7523 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model Binary files differindex d171898..ac28fa2 100644 --- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model +++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model Binary files differindex 937552b..d864b79 100644 --- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model +++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc index 32bd29c..e0d4241 100644 --- a/native/annotator/annotator.cc +++ b/native/annotator/annotator.cc @@ -973,11 +973,11 @@ CodepointSpan Annotator::SuggestSelection( // Sort candidates according to their position in the input, so that the next // code can assume that any connected component of overlapping spans forms a // contiguous block. - std::sort(candidates.annotated_spans[0].begin(), - candidates.annotated_spans[0].end(), - [](const AnnotatedSpan& a, const AnnotatedSpan& b) { - return a.span.first < b.span.first; - }); + std::stable_sort(candidates.annotated_spans[0].begin(), + candidates.annotated_spans[0].end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + return a.span.first < b.span.first; + }); std::vector<int> candidate_indices; if (!ResolveConflicts(candidates.annotated_spans[0], context, tokens, @@ -987,13 +987,14 @@ CodepointSpan Annotator::SuggestSelection( return original_click_indices; } - std::sort(candidate_indices.begin(), candidate_indices.end(), - [this, &candidates](int a, int b) { - return GetPriorityScore( - candidates.annotated_spans[0][a].classification) > - GetPriorityScore( - candidates.annotated_spans[0][b].classification); - }); + std::stable_sort( + candidate_indices.begin(), candidate_indices.end(), + [this, &candidates](int a, int b) { + return GetPriorityScore( + candidates.annotated_spans[0][a].classification) > + GetPriorityScore( + candidates.annotated_spans[0][b].classification); + }); for (const int i : candidate_indices) { if (SpansOverlap(candidates.annotated_spans[0][i].span, click_indices) && @@ -1173,7 +1174,7 @@ bool Annotator::ResolveConflict( } } - std::sort( + std::stable_sort( conflicting_indices.begin(), conflicting_indices.end(), [this, &scores_lengths, candidates, conflicting_indices](int i, int j) { if (scores_lengths[i].first == scores_lengths[j].first && @@ -1241,7 +1242,7 @@ bool Annotator::ResolveConflict( chosen_indices_for_source_ptr->insert(considered_candidate); } - std::sort(chosen_indices->begin(), chosen_indices->end()); + std::stable_sort(chosen_indices->begin(), chosen_indices->end()); return true; } @@ -1414,10 +1415,11 @@ namespace { // Sorts the classification results from high score to low score. void SortClassificationResults( std::vector<ClassificationResult>* classification_results) { - std::sort(classification_results->begin(), classification_results->end(), - [](const ClassificationResult& a, const ClassificationResult& b) { - return a.score > b.score; - }); + std::stable_sort( + classification_results->begin(), classification_results->end(), + [](const ClassificationResult& a, const ClassificationResult& b) { + return a.score > b.score; + }); } } // namespace @@ -1936,10 +1938,11 @@ std::vector<ClassificationResult> Annotator::ClassifyText( } // Sort results according to score. - std::sort(results.begin(), results.end(), - [](const ClassificationResult& a, const ClassificationResult& b) { - return a.score > b.score; - }); + std::stable_sort( + results.begin(), results.end(), + [](const ClassificationResult& a, const ClassificationResult& b) { + return a.score > b.score; + }); if (results.empty()) { results = {{Collections::Other(), 1.0}}; @@ -2297,19 +2300,19 @@ Status Annotator::AnnotateSingleInput( // Also sort them according to the end position and collection, so that the // deduplication code below can assume that same spans and classifications // form contiguous blocks. - std::sort(candidates->begin(), candidates->end(), - [](const AnnotatedSpan& a, const AnnotatedSpan& b) { - if (a.span.first != b.span.first) { - return a.span.first < b.span.first; - } + std::stable_sort(candidates->begin(), candidates->end(), + [](const AnnotatedSpan& a, const AnnotatedSpan& b) { + if (a.span.first != b.span.first) { + return a.span.first < b.span.first; + } - if (a.span.second != b.span.second) { - return a.span.second < b.span.second; - } + if (a.span.second != b.span.second) { + return a.span.second < b.span.second; + } - return a.classification[0].collection < - b.classification[0].collection; - }); + return a.classification[0].collection < + b.classification[0].collection; + }); std::vector<int> candidate_indices; if (!ResolveConflicts(*candidates, context, tokens, @@ -2904,10 +2907,10 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest, return false; } } - std::sort(scored_chunks.rbegin(), scored_chunks.rend(), - [](const ScoredChunk& lhs, const ScoredChunk& rhs) { - return lhs.score < rhs.score; - }); + std::stable_sort(scored_chunks.rbegin(), scored_chunks.rend(), + [](const ScoredChunk& lhs, const ScoredChunk& rhs) { + return lhs.score < rhs.score; + }); // Traverse the candidate chunks from highest-scoring to lowest-scoring. Pick // them greedily as long as they do not overlap with any previously picked @@ -2936,7 +2939,7 @@ bool Annotator::ModelChunk(int num_tokens, const TokenSpan& span_of_interest, chunks->push_back(scored_chunk.token_span); } - std::sort(chunks->begin(), chunks->end()); + std::stable_sort(chunks->begin(), chunks->end()); return true; } diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc index 7d5f440..ff0c775 100644 --- a/native/annotator/datetime/datetime-grounder.cc +++ b/native/annotator/datetime/datetime-grounder.cc @@ -16,6 +16,7 @@ #include "annotator/datetime/datetime-grounder.h" +#include <algorithm> #include <limits> #include <unordered_map> #include <vector> @@ -250,10 +251,10 @@ StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground( } // Sort the date time units by component type. - std::sort(date_components.begin(), date_components.end(), - [](DatetimeComponent a, DatetimeComponent b) { - return a.component_type > b.component_type; - }); + std::stable_sort(date_components.begin(), date_components.end(), + [](DatetimeComponent a, DatetimeComponent b) { + return a.component_type > b.component_type; + }); result.datetime_components.swap(date_components); datetime_parse_result.push_back(result); } diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc index 867c886..94a0961 100644 --- a/native/annotator/datetime/extractor.cc +++ b/native/annotator/datetime/extractor.cc @@ -16,6 +16,8 @@ #include "annotator/datetime/extractor.h" +#include <algorithm> + #include "annotator/datetime/utils.h" #include "annotator/model_generated.h" #include "annotator/types.h" @@ -347,10 +349,11 @@ bool DatetimeExtractor::ParseWrittenNumber(const UnicodeText& input, } } - std::sort(found_numbers.begin(), found_numbers.end(), - [](const std::pair<int, int>& a, const std::pair<int, int>& b) { - return a.first < b.first; - }); + std::stable_sort( + found_numbers.begin(), found_numbers.end(), + [](const std::pair<int, int>& a, const std::pair<int, int>& b) { + return a.first < b.first; + }); int sum = 0; int running_value = -1; diff --git a/native/annotator/datetime/regex-parser.cc b/native/annotator/datetime/regex-parser.cc index 4dc9c56..5daabd5 100644 --- a/native/annotator/datetime/regex-parser.cc +++ b/native/annotator/datetime/regex-parser.cc @@ -16,6 +16,7 @@ #include "annotator/datetime/regex-parser.h" +#include <algorithm> #include <iterator> #include <set> #include <unordered_set> @@ -191,17 +192,17 @@ StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse( // Resolve conflicts by always picking the longer span and breaking ties by // selecting the earlier entry in the list for a given locale. - std::sort(indexed_found_spans.begin(), indexed_found_spans.end(), - [](const std::pair<DatetimeParseResultSpan, int>& a, - const std::pair<DatetimeParseResultSpan, int>& b) { - if ((a.first.span.second - a.first.span.first) != - (b.first.span.second - b.first.span.first)) { - return (a.first.span.second - a.first.span.first) > - (b.first.span.second - b.first.span.first); - } else { - return a.second < b.second; - } - }); + std::stable_sort(indexed_found_spans.begin(), indexed_found_spans.end(), + [](const std::pair<DatetimeParseResultSpan, int>& a, + const std::pair<DatetimeParseResultSpan, int>& b) { + if ((a.first.span.second - a.first.span.first) != + (b.first.span.second - b.first.span.first)) { + return (a.first.span.second - a.first.span.first) > + (b.first.span.second - b.first.span.first); + } else { + return a.second < b.second; + } + }); std::vector<DatetimeParseResultSpan> results; std::vector<DatetimeParseResultSpan> resolved_found_spans; @@ -394,10 +395,10 @@ bool RegexDatetimeParser::ExtractDatetime( } // Sort the date time units by component type. - std::sort(date_components.begin(), date_components.end(), - [](DatetimeComponent a, DatetimeComponent b) { - return a.component_type > b.component_type; - }); + std::stable_sort(date_components.begin(), date_components.end(), + [](DatetimeComponent a, DatetimeComponent b) { + return a.component_type > b.component_type; + }); result.datetime_components.swap(date_components); results->push_back(result); } diff --git a/native/annotator/translate/translate.cc b/native/annotator/translate/translate.cc index 640ceec..2c5a43c 100644 --- a/native/annotator/translate/translate.cc +++ b/native/annotator/translate/translate.cc @@ -16,6 +16,7 @@ #include "annotator/translate/translate.h" +#include <algorithm> #include <memory> #include "annotator/collections.h" @@ -142,11 +143,11 @@ TranslateAnnotator::BackoffDetectLanguages( result.push_back({key, value}); } - std::sort(result.begin(), result.end(), - [](TranslateAnnotator::LanguageConfidence& a, - TranslateAnnotator::LanguageConfidence& b) { - return a.confidence > b.confidence; - }); + std::stable_sort(result.begin(), result.end(), + [](const TranslateAnnotator::LanguageConfidence& a, + const TranslateAnnotator::LanguageConfidence& b) { + return a.confidence > b.confidence; + }); return result; } diff --git a/native/lang_id/common/embedding-network.cc b/native/lang_id/common/embedding-network.cc index 469cb1f..49c9ca0 100644 --- a/native/lang_id/common/embedding-network.cc +++ b/native/lang_id/common/embedding-network.cc @@ -16,6 +16,8 @@ #include "lang_id/common/embedding-network.h" +#include <vector> + #include "lang_id/common/lite_base/integral-types.h" #include "lang_id/common/lite_base/logging.h" diff --git a/native/lang_id/common/fel/feature-extractor.cc b/native/lang_id/common/fel/feature-extractor.cc index ab8a1a6..4e304fe 100644 --- a/native/lang_id/common/fel/feature-extractor.cc +++ b/native/lang_id/common/fel/feature-extractor.cc @@ -17,6 +17,7 @@ #include "lang_id/common/fel/feature-extractor.h" #include <string> +#include <vector> #include "lang_id/common/fel/feature-types.h" #include "lang_id/common/fel/fel-parser.h" diff --git a/native/lang_id/common/fel/workspace.cc b/native/lang_id/common/fel/workspace.cc index af41e29..60dcc46 100644 --- a/native/lang_id/common/fel/workspace.cc +++ b/native/lang_id/common/fel/workspace.cc @@ -18,6 +18,7 @@ #include <atomic> #include <string> +#include <vector> namespace libtextclassifier3 { namespace mobile { diff --git a/native/lang_id/common/fel/workspace.h b/native/lang_id/common/fel/workspace.h index f13d802..2ac5b26 100644 --- a/native/lang_id/common/fel/workspace.h +++ b/native/lang_id/common/fel/workspace.h @@ -23,6 +23,7 @@ #include <stddef.h> +#include <algorithm> #include <string> #include <unordered_map> #include <utility> diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc index 19afcc4..fc925ea 100644 --- a/native/lang_id/common/file/mmap.cc +++ b/native/lang_id/common/file/mmap.cc @@ -29,6 +29,8 @@ #endif #include <sys/stat.h> +#include <string> + #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/lite_base/macros.h" diff --git a/native/lang_id/common/lite_strings/str-split.cc b/native/lang_id/common/lite_strings/str-split.cc index 199bb69..d227eec 100644 --- a/native/lang_id/common/lite_strings/str-split.cc +++ b/native/lang_id/common/lite_strings/str-split.cc @@ -16,6 +16,8 @@ #include "lang_id/common/lite_strings/str-split.h" +#include <vector> + namespace libtextclassifier3 { namespace mobile { diff --git a/native/lang_id/common/math/softmax.cc b/native/lang_id/common/math/softmax.cc index 750341d..249ed57 100644 --- a/native/lang_id/common/math/softmax.cc +++ b/native/lang_id/common/math/softmax.cc @@ -17,6 +17,7 @@ #include "lang_id/common/math/softmax.h" #include <algorithm> +#include <vector> #include "lang_id/common/lite_base/logging.h" #include "lang_id/common/math/fastexp.h" diff --git a/native/lang_id/fb_model/lang-id-from-fb.cc b/native/lang_id/fb_model/lang-id-from-fb.cc index dc36fb7..51c8c47 100644 --- a/native/lang_id/fb_model/lang-id-from-fb.cc +++ b/native/lang_id/fb_model/lang-id-from-fb.cc @@ -16,7 +16,9 @@ #include "lang_id/fb_model/lang-id-from-fb.h" +#include <memory> #include <string> +#include <utility> #include "lang_id/fb_model/model-provider-from-fb.h" diff --git a/native/lang_id/fb_model/model-provider-from-fb.cc b/native/lang_id/fb_model/model-provider-from-fb.cc index 43bf860..d14d403 100644 --- a/native/lang_id/fb_model/model-provider-from-fb.cc +++ b/native/lang_id/fb_model/model-provider-from-fb.cc @@ -16,7 +16,9 @@ #include "lang_id/fb_model/model-provider-from-fb.h" +#include <memory> #include <string> +#include <utility> #include "lang_id/common/file/file-utils.h" #include "lang_id/common/file/mmap.h" diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc index 92359a9..f7c66f7 100644 --- a/native/lang_id/lang-id.cc +++ b/native/lang_id/lang-id.cc @@ -21,6 +21,7 @@ #include <memory> #include <string> #include <unordered_map> +#include <utility> #include <vector> #include "lang_id/common/embedding-feature-interface.h" diff --git a/native/utils/codepoint-range.cc b/native/utils/codepoint-range.cc index e26b160..a4cd485 100644 --- a/native/utils/codepoint-range.cc +++ b/native/utils/codepoint-range.cc @@ -31,10 +31,11 @@ void SortCodepointRanges( CodepointRangeStruct(range->start(), range->end())); } - std::sort(sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(), - [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) { - return a.start < b.start; - }); + std::stable_sort( + sorted_codepoint_ranges->begin(), sorted_codepoint_ranges->end(), + [](const CodepointRangeStruct& a, const CodepointRangeStruct& b) { + return a.start < b.start; + }); } // Returns true if given codepoint is covered by the given sorted vector of diff --git a/native/utils/grammar/parsing/parser.cc b/native/utils/grammar/parsing/parser.cc index 4e39a98..a9e99ba 100644 --- a/native/utils/grammar/parsing/parser.cc +++ b/native/utils/grammar/parsing/parser.cc @@ -16,6 +16,7 @@ #include "utils/grammar/parsing/parser.h" +#include <algorithm> #include <unordered_map> #include "utils/grammar/parsing/parse-tree.h" @@ -177,14 +178,14 @@ std::vector<Symbol> Parser::SortedSymbolsForInput(const TextContext& input, } } - std::sort(symbols.begin(), symbols.end(), - [](const Symbol& a, const Symbol& b) { - // Sort by increasing (end, start) position to guarantee the - // matcher requirement that the tokens are fed in non-decreasing - // end position order. - return std::tie(a.codepoint_span.second, a.codepoint_span.first) < - std::tie(b.codepoint_span.second, b.codepoint_span.first); - }); + std::stable_sort( + symbols.begin(), symbols.end(), [](const Symbol& a, const Symbol& b) { + // Sort by increasing (end, start) position to guarantee the + // matcher requirement that the tokens are fed in non-decreasing + // end position order. + return std::tie(a.codepoint_span.second, a.codepoint_span.first) < + std::tie(b.codepoint_span.second, b.codepoint_span.first); + }); return symbols; } diff --git a/native/utils/grammar/utils/ir.cc b/native/utils/grammar/utils/ir.cc index dd29e3c..c134550 100644 --- a/native/utils/grammar/utils/ir.cc +++ b/native/utils/grammar/utils/ir.cc @@ -16,6 +16,8 @@ #include "utils/grammar/utils/ir.h" +#include <algorithm> + #include "utils/i18n/locale.h" #include "utils/strings/append.h" #include "utils/strings/stringpiece.h" @@ -28,14 +30,16 @@ constexpr size_t kMaxHashTableSize = 100; template <typename T> void SortForBinarySearchLookup(T* entries) { - std::sort(entries->begin(), entries->end(), - [](const auto& a, const auto& b) { return a->key < b->key; }); + std::stable_sort( + entries->begin(), entries->end(), + [](const auto& a, const auto& b) { return a->key < b->key; }); } template <typename T> void SortStructsForBinarySearchLookup(T* entries) { - std::sort(entries->begin(), entries->end(), - [](const auto& a, const auto& b) { return a.key() < b.key(); }); + std::stable_sort( + entries->begin(), entries->end(), + [](const auto& a, const auto& b) { return a.key() < b.key(); }); } bool IsSameLhs(const Ir::Lhs& lhs, const RulesSet_::Lhs& other) { @@ -76,13 +80,14 @@ bool IsSameLhsSet(const Ir::LhsSet& lhs_set, Ir::LhsSet SortedLhsSet(const Ir::LhsSet& lhs_set) { Ir::LhsSet sorted_lhs = lhs_set; - std::sort(sorted_lhs.begin(), sorted_lhs.end(), - [](const Ir::Lhs& a, const Ir::Lhs& b) { - return std::tie(a.nonterminal, a.callback.id, a.callback.param, - a.preconditions.max_whitespace_gap) < - std::tie(b.nonterminal, b.callback.id, b.callback.param, - b.preconditions.max_whitespace_gap); - }); + std::stable_sort( + sorted_lhs.begin(), sorted_lhs.end(), + [](const Ir::Lhs& a, const Ir::Lhs& b) { + return std::tie(a.nonterminal, a.callback.id, a.callback.param, + a.preconditions.max_whitespace_gap) < + std::tie(b.nonterminal, b.callback.id, b.callback.param, + b.preconditions.max_whitespace_gap); + }); return lhs_set; } @@ -300,10 +305,10 @@ void Ir::SerializeTerminalRules( TerminalEntry{it.first, /*set_index=*/i, /*index=*/0, it.second}); } } - std::sort(terminal_rules.begin(), terminal_rules.end(), - [](const TerminalEntry& a, const TerminalEntry& b) { - return a.terminal < b.terminal; - }); + std::stable_sort(terminal_rules.begin(), terminal_rules.end(), + [](const TerminalEntry& a, const TerminalEntry& b) { + return a.terminal < b.terminal; + }); // Index the entries in sorted order. std::vector<int> index(terminal_rules_sets.size(), 0); diff --git a/native/utils/grammar/utils/locale-shard-map.cc b/native/utils/grammar/utils/locale-shard-map.cc index e6db06d..141ce5d 100644 --- a/native/utils/grammar/utils/locale-shard-map.cc +++ b/native/utils/grammar/utils/locale-shard-map.cc @@ -40,8 +40,8 @@ std::vector<Locale> LocaleTagsToLocaleList(const std::string& locale_tags) { locale_list.emplace_back(locale); } } - std::sort(locale_list.begin(), locale_list.end(), - [](const Locale& a, const Locale& b) { return a < b; }); + std::stable_sort(locale_list.begin(), locale_list.end(), + [](const Locale& a, const Locale& b) { return a < b; }); return locale_list; } diff --git a/native/utils/testing/test_data_generator.h b/native/utils/testing/test_data_generator.h index 30c7aed..c23b5dc 100644 --- a/native/utils/testing/test_data_generator.h +++ b/native/utils/testing/test_data_generator.h @@ -20,6 +20,7 @@ #include <algorithm> #include <iostream> #include <random> +#include <string> #include "utils/strings/stringpiece.h" @@ -35,6 +36,18 @@ class TestDataGenerator { return dist(random_engine_); } + template <> + bool generate() { + std::bernoulli_distribution dist(0.5); + return dist(random_engine_); + } + + template <> + char generate() { + std::uniform_int_distribution<int> dist(0, 25); + return dist(random_engine_) + 'a'; + } + template <typename T, typename std::enable_if_t< std::is_floating_point<T>::value>* = nullptr> T generate() { diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc index 463d910..644dde8 100644 --- a/native/utils/tflite-model-executor.cc +++ b/native/utils/tflite-model-executor.cc @@ -27,6 +27,8 @@ namespace builtin { TfLiteRegistration* Register_ADD(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); TfLiteRegistration* Register_EQUAL(); TfLiteRegistration* Register_FULLY_CONNECTED(); TfLiteRegistration* Register_GREATER_EQUAL(); @@ -89,7 +91,9 @@ TfLiteRegistration* Register_GREATER(); #include "utils/tflite/dist_diversification.h" #include "utils/tflite/string_projection.h" #include "utils/tflite/text_encoder.h" +#include "utils/tflite/text_encoder3s.h" #include "utils/tflite/token_encoder.h" + namespace tflite { namespace ops { namespace custom { @@ -114,6 +118,14 @@ void RegisterSelectedOps(tflite::MutableOpResolver* resolver) { tflite::ops::builtin::Register_CONV_2D(), /*min_version=*/1, /*max_version=*/5); + resolver->AddBuiltin(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, + tflite::ops::builtin::Register_DEPTHWISE_CONV_2D(), + /*min_version=*/1, + /*max_version=*/6); + resolver->AddBuiltin(tflite::BuiltinOperator_AVERAGE_POOL_2D, + tflite::ops::builtin::Register_AVERAGE_POOL_2D(), + /*min_version=*/1, + /*max_version=*/1); resolver->AddBuiltin(::tflite::BuiltinOperator_EQUAL, ::tflite::ops::builtin::Register_EQUAL()); @@ -289,6 +301,8 @@ std::unique_ptr<tflite::OpResolver> BuildOpResolver( tflite::ops::custom::Register_DISTANCE_DIVERSIFICATION()); resolver->AddCustom("TextEncoder", tflite::ops::custom::Register_TEXT_ENCODER()); + resolver->AddCustom("TextEncoder3S", + tflite::ops::custom::Register_TEXT_ENCODER3S()); resolver->AddCustom("TokenEncoder", tflite::ops::custom::Register_TOKEN_ENCODER()); resolver->AddCustom( diff --git a/native/utils/tflite/encoder_common.cc b/native/utils/tflite/encoder_common.cc index 8f9f2a8..eb319f9 100644 --- a/native/utils/tflite/encoder_common.cc +++ b/native/utils/tflite/encoder_common.cc @@ -58,6 +58,11 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate( out->data.i32 + output_offset + from_this_element, in.data.i32[value_index]); } break; + case kTfLiteInt64: { + std::fill(out->data.i64 + output_offset, + out->data.i64 + output_offset + from_this_element, + in.data.i64[value_index]); + } break; case kTfLiteFloat32: { std::fill(out->data.f + output_offset, out->data.f + output_offset + from_this_element, @@ -78,6 +83,12 @@ TfLiteStatus CopyValuesToTensorAndPadOrTruncate( std::fill(out->data.i32 + output_offset, out->data.i32 + output_size, value); } break; + case kTfLiteInt64: { + const int64_t value = + (output_offset > 0) ? out->data.i64[output_offset - 1] : 0; + std::fill(out->data.i64 + output_offset, out->data.i64 + output_size, + value); + } break; case kTfLiteFloat32: { const float value = (output_offset > 0) ? out->data.f[output_offset - 1] : 0; diff --git a/native/utils/tflite/text_encoder3s.cc b/native/utils/tflite/text_encoder3s.cc new file mode 100644 index 0000000..0b5e65b --- /dev/null +++ b/native/utils/tflite/text_encoder3s.cc @@ -0,0 +1,243 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tflite/text_encoder3s.h" + +#include <memory> +#include <vector> + +#include "utils/base/logging.h" +#include "utils/strings/stringpiece.h" +#include "utils/tflite/encoder_common.h" +#include "utils/tflite/text_encoder_config_generated.h" +#include "utils/tokenfree/byte_encoder.h" +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/flexbuffers.h" +#include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/model.h" +#include "tensorflow/lite/string_util.h" + +namespace libtextclassifier3 { +namespace { + +// Input parameters for the op. +constexpr int kInputTextInd = 0; + +constexpr int kTextLengthInd = 1; +constexpr int kMaxLengthInd = 2; +constexpr int kInputAttrInd = 3; + +// Output parameters for the op. +constexpr int kOutputEncodedInd = 0; +constexpr int kOutputPositionInd = 1; +constexpr int kOutputLengthsInd = 2; +constexpr int kOutputAttrInd = 3; + +// Initializes text encoder object from serialized parameters. +void* Initialize(TfLiteContext* context, const char* buffer, size_t length) { + std::unique_ptr<ByteEncoder> encoder(new ByteEncoder()); + return encoder.release(); +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<ByteEncoder*>(buffer); +} + +namespace { +TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, + int max_output_length) { + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_encoded, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output_positions, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& output = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor( + context, &output, + CreateIntArray({kEncoderBatchSize, max_output_length}))); + } + return kTfLiteOk; +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check that the batch dimension is kEncoderBatchSize. + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputTextInd]]; + TF_LITE_ENSURE_EQ(context, input_text.dims->size, kEncoderInputRank); + TF_LITE_ENSURE_EQ(context, input_text.dims->data[0], kEncoderBatchSize); + + TfLiteTensor& output_lengths = + context->tensors[node->outputs->data[kOutputLengthsInd]]; + + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + output_encoded.type = kTfLiteInt32; + output_positions.type = kTfLiteInt32; + output_lengths.type = kTfLiteInt32; + + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, &output_lengths, + CreateIntArray({kEncoderBatchSize}))); + + // Check that there are enough outputs for attributes. + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd, + num_output_attrs); + + // Copy attribute types from input to output tensors. + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& input = + context->tensors[node->inputs->data[kInputAttrInd + i]]; + TfLiteTensor& output = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + output.type = input.type; + } + + const TfLiteTensor& output_length = + context->tensors[node->inputs->data[kMaxLengthInd]]; + + if (tflite::IsConstantTensor(&output_length)) { + return ResizeOutputTensors(context, node, output_length.data.i64[0]); + } else { + tflite::SetTensorToDynamic(&output_encoded); + tflite::SetTensorToDynamic(&output_positions); + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteTensor& output_attr = + context->tensors[node->outputs->data[kOutputAttrInd + i]]; + tflite::SetTensorToDynamic(&output_attr); + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + if (node->user_data == nullptr) { + return kTfLiteError; + } + auto text_encoder = reinterpret_cast<ByteEncoder*>(node->user_data); + const TfLiteTensor& input_text = + context->tensors[node->inputs->data[kInputTextInd]]; + const int num_strings_in_tensor = tflite::GetStringCount(&input_text); + const int num_strings = + context->tensors[node->inputs->data[kTextLengthInd]].data.i32[0]; + + // Check that the number of strings is not bigger than the input tensor size. + TF_LITE_ENSURE(context, num_strings_in_tensor >= num_strings); + + TfLiteTensor& output_encoded = + context->tensors[node->outputs->data[kOutputEncodedInd]]; + if (tflite::IsDynamicTensor(&output_encoded)) { + const TfLiteTensor& output_length = + context->tensors[node->inputs->data[kMaxLengthInd]]; + TF_LITE_ENSURE_OK( + context, ResizeOutputTensors(context, node, output_length.data.i64[0])); + } + TfLiteTensor& output_positions = + context->tensors[node->outputs->data[kOutputPositionInd]]; + + std::vector<int> encoded_total; + std::vector<int> encoded_positions; + std::vector<int> encoded_offsets; + encoded_offsets.reserve(num_strings); + const int max_output_length = output_encoded.dims->data[1]; + const int max_encoded_position = max_output_length; + + for (int i = 0; i < num_strings; ++i) { + const auto& strref = tflite::GetString(&input_text, i); + std::vector<int64_t> encoded; + text_encoder->Encode( + libtextclassifier3::StringPiece(strref.str, strref.len), &encoded); + encoded_total.insert(encoded_total.end(), encoded.begin(), encoded.end()); + encoded_offsets.push_back(encoded_total.size()); + for (int i = 0; i < encoded.size(); ++i) { + encoded_positions.push_back(std::min(i, max_encoded_position - 1)); + } + } + + // Copy encoding to output tensor. + const int start_offset = + std::max(0, static_cast<int>(encoded_total.size()) - max_output_length); + int output_offset = 0; + int32_t* output_buffer = output_encoded.data.i32; + int32_t* output_positions_buffer = output_positions.data.i32; + for (int i = start_offset; i < encoded_total.size(); ++i, ++output_offset) { + output_buffer[output_offset] = encoded_total[i]; + output_positions_buffer[output_offset] = encoded_positions[i]; + } + + // Save output encoded length. + TfLiteTensor& output_lengths = + context->tensors[node->outputs->data[kOutputLengthsInd]]; + output_lengths.data.i32[0] = output_offset; + + // Do padding. + for (; output_offset < max_output_length; ++output_offset) { + output_buffer[output_offset] = 0; + output_positions_buffer[output_offset] = 0; + } + + // Process attributes, all checks of sizes and types are done in Prepare. + const int num_output_attrs = node->outputs->size - kOutputAttrInd; + TF_LITE_ENSURE_EQ(context, node->inputs->size - kInputAttrInd, + num_output_attrs); + for (int i = 0; i < num_output_attrs; ++i) { + TfLiteStatus attr_status = CopyValuesToTensorAndPadOrTruncate( + context->tensors[node->inputs->data[kInputAttrInd + i]], + encoded_offsets, start_offset, context, + &context->tensors[node->outputs->data[kOutputAttrInd + i]]); + if (attr_status != kTfLiteOk) { + return attr_status; + } + } + + return kTfLiteOk; +} + +} // namespace +} // namespace libtextclassifier3 + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_TEXT_ENCODER3S() { + static TfLiteRegistration registration = { + libtextclassifier3::Initialize, libtextclassifier3::Free, + libtextclassifier3::Prepare, libtextclassifier3::Eval}; + return ®istration; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/native/utils/tflite/text_encoder3s.h b/native/utils/tflite/text_encoder3s.h new file mode 100644 index 0000000..50e1e64 --- /dev/null +++ b/native/utils/tflite/text_encoder3s.h @@ -0,0 +1,35 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// An encoder that produces positional and attributes encodings for a +// transformer style model based on byte segmentation of text. + +#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ +#define LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ + +#include "tensorflow/lite/context.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_TEXT_ENCODER3S(); + +} // namespace custom +} // namespace ops +} // namespace tflite + +#endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_TEXT_ENCODER3S_H_ diff --git a/native/utils/tokenfree/byte_encoder.cc b/native/utils/tokenfree/byte_encoder.cc new file mode 100644 index 0000000..c79d3a2 --- /dev/null +++ b/native/utils/tokenfree/byte_encoder.cc @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tokenfree/byte_encoder.h" + +#include <vector> +namespace libtextclassifier3 { + +bool ByteEncoder::Encode(StringPiece input_text, + std::vector<int64_t>* encoded_text) const { + const int len = input_text.size(); + if (len <= 0) { + *encoded_text = {}; + return true; + } + + int size = input_text.size(); + encoded_text->resize(size); + + const auto& text = input_text.ToString(); + for (int i = 0; i < size; i++) { + int64_t encoding = static_cast<int64_t>(text[i]); + (*encoded_text)[i] = encoding; + } + + return true; +} + +} // namespace libtextclassifier3 diff --git a/native/utils/tokenfree/byte_encoder.h b/native/utils/tokenfree/byte_encoder.h new file mode 100644 index 0000000..1a495ec --- /dev/null +++ b/native/utils/tokenfree/byte_encoder.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ +#define LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ + +#include <vector> + +#include "utils/base/logging.h" +#include "utils/container/string-set.h" +#include "utils/strings/stringpiece.h" + +namespace libtextclassifier3 { + +// Encoder to segment/tokenize strings into bytes +class ByteEncoder { + public: + bool Encode(StringPiece input_text, std::vector<int64_t>* encoded_text) const; + ByteEncoder() {} +}; + +} // namespace libtextclassifier3 + +#endif // LIBTEXTCLASSIFIER_UTILS_TOKENFREE_BYTE_ENCODER_H_ diff --git a/native/utils/tokenfree/byte_encoder_test.cc b/native/utils/tokenfree/byte_encoder_test.cc new file mode 100644 index 0000000..d4d119e --- /dev/null +++ b/native/utils/tokenfree/byte_encoder_test.cc @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2018 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/tokenfree/byte_encoder.h" + +#include <memory> +#include <vector> + +#include "utils/base/integral_types.h" +#include "utils/container/sorted-strings-table.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace libtextclassifier3 { +namespace { + +using testing::ElementsAre; + +TEST(EncoderTest, SimpleTokenization) { + const ByteEncoder encoder; + { + std::vector<int64_t> encoded_text; + EXPECT_TRUE(encoder.Encode("hellothere", &encoded_text)); + EXPECT_THAT(encoded_text, + ElementsAre(104, 101, 108, 108, 111, 116, 104, 101, 114, 101)); + } +} + +TEST(EncoderTest, SimpleTokenization2) { + const ByteEncoder encoder; + { + std::vector<int64_t> encoded_text; + EXPECT_TRUE(encoder.Encode("Hello", &encoded_text)); + EXPECT_THAT(encoded_text, ElementsAre(72, 101, 108, 108, 111)); + } +} +} // namespace +} // namespace libtextclassifier3 diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc index 071141c..7038517 100644 --- a/native/utils/tokenizer.cc +++ b/native/utils/tokenizer.cc @@ -43,11 +43,12 @@ Tokenizer::Tokenizer( codepoint_ranges_.emplace_back(range->UnPack()); } - std::sort(codepoint_ranges_.begin(), codepoint_ranges_.end(), - [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, - const std::unique_ptr<const TokenizationCodepointRangeT>& b) { - return a->start < b->start; - }); + std::stable_sort( + codepoint_ranges_.begin(), codepoint_ranges_.end(), + [](const std::unique_ptr<const TokenizationCodepointRangeT>& a, + const std::unique_ptr<const TokenizationCodepointRangeT>& b) { + return a->start < b->start; + }); SortCodepointRanges(internal_tokenizer_codepoint_ranges, &internal_tokenizer_codepoint_ranges_); diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java index bc30fcf..f539ba7 100644 --- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java +++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsLogSessionTest.java @@ -37,15 +37,20 @@ import androidx.test.filters.LargeTest; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.junit.Before; +import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.ArgumentCaptor; import org.mockito.Mock; -import org.mockito.MockitoAnnotations; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; @LargeTest @RunWith(AndroidJUnit4.class) public class SmartSuggestionsLogSessionTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String RESULT_ID = "resultId"; private static final String REPLY = "reply"; private static final float SCORE = 0.5f; @@ -55,7 +60,6 @@ public class SmartSuggestionsLogSessionTest { @Before public void setup() { - MockitoAnnotations.initMocks(this); session = new SmartSuggestionsLogSession( |