diff options
author | Alex Saveliev <alexsav@google.com> | 2023-01-10 19:23:34 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2023-01-10 19:23:34 +0000 |
commit | 6e15dd0d337c65739900d7b95f6408d7413c8196 (patch) | |
tree | 0ba91ee8775e34738340187614b97b6c5ffcbc8c | |
parent | 48b8f6943906165ec50ebecb9551497ac6faa450 (diff) | |
parent | 947f3d55bb1871285790facda2aa76e02c27a289 (diff) | |
download | icing-6e15dd0d337c65739900d7b95f6408d7413c8196.tar.gz |
Merge remote-tracking branch 'aosp/upstream-master' into androidx-main am: 947f3d55bb
Original change: https://android-review.googlesource.com/c/platform/external/icing/+/2381052
Change-Id: I1e1a7398fee34eb07cc5e2ed7dd521ac60dae9bb
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
35 files changed, 2097 insertions, 858 deletions
diff --git a/icing/file/posting_list/posting-list-identifier.h b/icing/file/posting_list/posting-list-identifier.h index 05c7ce5..54b2888 100644 --- a/icing/file/posting_list/posting-list-identifier.h +++ b/icing/file/posting_list/posting-list-identifier.h @@ -109,7 +109,8 @@ class PostingListIdentifier { private: uint32_t val_; -}; +} __attribute__((packed)); +static_assert(sizeof(PostingListIdentifier) == 4, ""); } // namespace lib } // namespace icing diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 33e2ca1..bf9c102 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -1191,9 +1191,13 @@ DeleteResultProto IcingSearchEngine::Delete(const std::string_view name_space, // that can support error logging. libtextclassifier3::Status status = document_store_->Delete(name_space, uri); if (!status.ok()) { - ICING_LOG(ERROR) << status.error_message() - << "Failed to delete Document. namespace: " << name_space - << ", uri: " << uri; + LogSeverity::Code severity = ERROR; + if (absl_ports::IsNotFound(status)) { + severity = DBG; + } + ICING_LOG(severity) << status.error_message() + << "Failed to delete Document. namespace: " + << name_space << ", uri: " << uri; TransformStatus(status, result_status); return result_proto; } diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index e7158ad..dff38fb 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -10581,6 +10581,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("lastName", "last1") .AddStringProperty("emailAddress", "email1@gmail.com") .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(1) .Build(); DocumentProto person2 = DocumentBuilder() @@ -10590,6 +10591,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("lastName", "last2") .AddStringProperty("emailAddress", "email2@gmail.com") .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(2) .Build(); DocumentProto person3 = DocumentBuilder() @@ -10599,6 +10601,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("lastName", "last3") .AddStringProperty("emailAddress", "email3@gmail.com") .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(3) .Build(); DocumentProto email1 = @@ -10608,6 +10611,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("subject", "test subject 1") .AddStringProperty("personQualifiedId", "pkg$db/namespace#person1") .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(3) .Build(); DocumentProto email2 = DocumentBuilder() @@ -10616,6 +10620,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("subject", "test subject 2") .AddStringProperty("personQualifiedId", "pkg$db/namespace#person2") .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(2) .Build(); DocumentProto email3 = DocumentBuilder() @@ -10625,6 +10630,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { .AddStringProperty("personQualifiedId", R"(pkg$db/name\#space\\\\#person3)") // escaped .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .SetScore(1) .Build(); IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); @@ -10644,12 +10650,12 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { // JoinSpec JoinSpecProto* join_spec = search_spec.mutable_join_spec(); - // Set max_joined_child_count as 2, so only email 3, email2 will be included - // in the nested result and email1 will be truncated. - join_spec->set_max_joined_child_count(2); + join_spec->set_max_joined_child_count(100); join_spec->set_parent_property_expression( std::string(JoinProcessor::kQualifiedIdExpr)); join_spec->set_child_property_expression("personQualifiedId"); + join_spec->set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::MAX); JoinSpecProto::NestedSpecProto* nested_spec = join_spec->mutable_nested_spec(); SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec(); @@ -10665,12 +10671,20 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { ResultSpecProto result_spec; result_spec.set_num_per_page(1); + // Since we: + // - Use MAX for aggregation scoring strategy. + // - (Default) use DOCUMENT_SCORE to score child documents. + // - (Default) use DESC as the ranking order. + // + // person1 + email1 should have the highest aggregated score (3) and be + // returned first. person2 + email2 (aggregated score = 2) should be the + // second, and person3 + email3 (aggregated score = 1) should be the last. SearchResultProto expected_result1; expected_result1.mutable_status()->set_code(StatusProto::OK); SearchResultProto::ResultProto* result_proto1 = expected_result1.mutable_results()->Add(); - *result_proto1->mutable_document() = person3; - *result_proto1->mutable_joined_results()->Add()->mutable_document() = email3; + *result_proto1->mutable_document() = person1; + *result_proto1->mutable_joined_results()->Add()->mutable_document() = email1; SearchResultProto expected_result2; expected_result2.mutable_status()->set_code(StatusProto::OK); @@ -10683,8 +10697,8 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { expected_result3.mutable_status()->set_code(StatusProto::OK); SearchResultProto::ResultProto* result_proto3 = expected_result3.mutable_results()->Add(); - *result_proto3->mutable_document() = person1; - *result_proto3->mutable_joined_results()->Add()->mutable_document() = email1; + *result_proto3->mutable_document() = person3; + *result_proto3->mutable_joined_results()->Add()->mutable_document() = email3; SearchResultProto result1 = icing.Search(search_spec, scoring_spec, result_spec); @@ -10708,133 +10722,6 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { EqualsSearchResultIgnoreStatsAndScores(expected_result3)); } -TEST_F(IcingSearchEngineTest, InvalidJoins) { - SchemaProto schema = - SchemaBuilder() - .AddType(SchemaTypeConfigBuilder() - .SetType("Person") - .AddProperty(PropertyConfigBuilder() - .SetName("firstName") - .SetDataTypeString(TERM_MATCH_PREFIX, - TOKENIZER_PLAIN) - .SetCardinality(CARDINALITY_OPTIONAL)) - .AddProperty(PropertyConfigBuilder() - .SetName("lastName") - .SetDataTypeString(TERM_MATCH_PREFIX, - TOKENIZER_PLAIN) - .SetCardinality(CARDINALITY_OPTIONAL)) - .AddProperty(PropertyConfigBuilder() - .SetName("emailAddress") - .SetDataTypeString(TERM_MATCH_PREFIX, - TOKENIZER_PLAIN) - .SetCardinality(CARDINALITY_OPTIONAL))) - .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty( - PropertyConfigBuilder() - .SetName("subjectId") - .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) - .SetCardinality(CARDINALITY_OPTIONAL))) - .Build(); - - DocumentProto person1 = - DocumentBuilder() - .SetKey("pkg$db/namespace", "person1") - .SetSchema("Person") - .AddStringProperty("firstName", "first1") - .AddStringProperty("lastName", "last1") - .AddStringProperty("emailAddress", "email1@gmail.com") - .SetCreationTimestampMs(kDefaultCreationTimestampMs) - .Build(); - DocumentProto person2 = - DocumentBuilder() - .SetKey("pkg$db/namespace\\", "person2") - .SetSchema("Person") - .AddStringProperty("firstName", "first2") - .AddStringProperty("lastName", "last2") - .AddStringProperty("emailAddress", "email2@gmail.com") - .SetCreationTimestampMs(kDefaultCreationTimestampMs) - .Build(); - - // "invalid format" does not refer to any document, so it will not be joined - // to any document. - DocumentProto email1 = - DocumentBuilder() - .SetKey("namespace", "email1") - .SetSchema("Email") - .AddStringProperty("subjectId", "invalid format") - .SetCreationTimestampMs(kDefaultCreationTimestampMs) - .Build(); - // This will not be joined because the # in the subjectId is escaped. - DocumentProto email2 = - DocumentBuilder() - .SetKey("namespace", "email2") - .SetSchema("Email") - .AddStringProperty("subjectId", "pkg$db/namespace\\#person2") - .SetCreationTimestampMs(kDefaultCreationTimestampMs) - .Build(); - - IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); - ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); - ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); - ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk()); - ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk()); - ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); - ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); - - // Parent SearchSpec - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::PREFIX); - search_spec.set_query("first"); - - // JoinSpec - JoinSpecProto* join_spec = search_spec.mutable_join_spec(); - // Set max_joined_child_count as 2, so only email 3, email2 will be included - // in the nested result and email1 will be truncated. - join_spec->set_max_joined_child_count(2); - join_spec->set_parent_property_expression( - std::string(JoinProcessor::kQualifiedIdExpr)); - join_spec->set_child_property_expression("subjectId"); - JoinSpecProto::NestedSpecProto* nested_spec = - join_spec->mutable_nested_spec(); - SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec(); - nested_search_spec->set_term_match_type(TermMatchType::PREFIX); - nested_search_spec->set_query(""); - *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec(); - *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance(); - - // Parent ScoringSpec - ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); - - // Parent ResultSpec - ResultSpecProto result_spec; - result_spec.set_num_per_page(1); - - SearchResultProto expected_result1; - expected_result1.mutable_status()->set_code(StatusProto::OK); - SearchResultProto::ResultProto* result_proto1 = - expected_result1.mutable_results()->Add(); - *result_proto1->mutable_document() = person2; - - SearchResultProto expected_result2; - expected_result2.mutable_status()->set_code(StatusProto::OK); - SearchResultProto::ResultProto* result_proto2 = - expected_result2.mutable_results()->Add(); - *result_proto2->mutable_document() = person1; - - SearchResultProto result1 = - icing.Search(search_spec, scoring_spec, result_spec); - uint64_t next_page_token = result1.next_page_token(); - EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken)); - expected_result1.set_next_page_token(next_page_token); - EXPECT_THAT(result1, - EqualsSearchResultIgnoreStatsAndScores(expected_result1)); - - SearchResultProto result2 = icing.GetNextPage(next_page_token); - next_page_token = result2.next_page_token(); - EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken)); - EXPECT_THAT(result2, - EqualsSearchResultIgnoreStatsAndScores(expected_result2)); -} - TEST_F(IcingSearchEngineTest, NumericFilterAdvancedQuerySucceeds) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); diff --git a/icing/index/iterator/doc-hit-info-iterator-and_test.cc b/icing/index/iterator/doc-hit-info-iterator-and_test.cc index e4730fe..9b9f44b 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and_test.cc @@ -32,10 +32,8 @@ namespace lib { namespace { using ::testing::ElementsAre; -using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; -using ::testing::SizeIs; TEST(CreateAndIteratorTest, And) { // Basic test that we can create a working And iterator. Further testing of @@ -202,23 +200,24 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. - SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}}; + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2 = {{1, 2}, {2, 6}}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; @@ -240,29 +239,25 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4)); and_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies2)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hi", expected_section_ids_tf_map1), + EqualsTermMatchInfo("hello", expected_section_ids_tf_map2))); EXPECT_FALSE(and_iter.Advance().ok()); } { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. - SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ - 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{0, 1}, {2, 2}}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1}; @@ -284,11 +279,8 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4)); and_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "hi", expected_section_ids_tf_map1))); EXPECT_FALSE(and_iter.Advance().ok()); } @@ -470,37 +462,34 @@ TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) { // Arbitrary section ids/term frequencies for the documents in the // DocHitInfoIterators. // For term "hi", document 10 and 8 - SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{ - 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10); doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_hi = {{0, 1}, {2, 2}, {6, 4}}; DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8); doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); // For term "hello", document 10 and 9 - SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{ - 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10); doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_hello = {{0, 2}, {3, 3}}; DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9); doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); // For term "ciao", document 10 and 9 - SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{ - 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(10); doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_ciao = {{0, 2}, {1, 3}}; DocHitInfoTermFrequencyPair doc_hit_info2_ciao = DocHitInfo(9); doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); @@ -534,19 +523,12 @@ TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) { EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(10)); and_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(3)); // 3 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1_hi)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi); - EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies1_hello)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello); - EXPECT_EQ(matched_terms_stats.at(2).term, "ciao"); - EXPECT_THAT(matched_terms_stats.at(2).term_frequencies, - ElementsAreArray(term_frequencies1_ciao)); - EXPECT_EQ(matched_terms_stats.at(2).section_ids_mask, section_id_mask1_ciao); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hi", expected_section_ids_tf_map1_hi), + EqualsTermMatchInfo("hello", expected_section_ids_tf_map1_hello), + EqualsTermMatchInfo("ciao", expected_section_ids_tf_map1_ciao))); EXPECT_FALSE(and_iter.Advance().ok()); } diff --git a/icing/index/iterator/doc-hit-info-iterator-or_test.cc b/icing/index/iterator/doc-hit-info-iterator-or_test.cc index 6e6872c..f487801 100644 --- a/icing/index/iterator/doc-hit-info-iterator-or_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-or_test.cc @@ -19,7 +19,6 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "icing/index/iterator/doc-hit-info-iterator-and.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/schema/section.h" @@ -32,10 +31,8 @@ namespace lib { namespace { using ::testing::ElementsAre; -using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; -using ::testing::SizeIs; TEST(CreateAndIteratorTest, Or) { // Basic test that we can create a working Or iterator. Further testing of @@ -182,22 +179,21 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. - SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}}; DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2 = {{1, 2}, {2, 6}}; std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; @@ -219,28 +215,23 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies2)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hi", expected_section_ids_tf_map1), + EqualsTermMatchInfo("hello", expected_section_ids_tf_map2))); EXPECT_FALSE(or_iter.Advance().ok()); } { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. - SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ - 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{0, 1}, {2, 2}}; std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1}; @@ -262,33 +253,28 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); - + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "hi", expected_section_ids_tf_map1))); EXPECT_FALSE(or_iter.Advance().ok()); } { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. - SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}}; DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(5); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2 = {{1, 2}, {2, 6}}; std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; @@ -310,22 +296,17 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(5)); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies2)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2); + EXPECT_THAT(matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", + expected_section_ids_tf_map2))); ICING_EXPECT_OK(or_iter.Advance()); EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); matched_terms_stats.clear(); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "hi", expected_section_ids_tf_map1))); EXPECT_FALSE(or_iter.Advance().ok()); } @@ -476,50 +457,44 @@ TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) { // Arbitrary section ids/term frequencies for the documents in the // DocHitInfoIterators. // For term "hi", document 10 and 8 - SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{ - 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10); doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_hi = {{0, 1}, {2, 2}, {6, 4}}; - SectionIdMask section_id_mask2_hi = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hi{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8); doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2_hi = {{1, 2}, {2, 6}}; // For term "hello", document 10 and 9 - SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{ - 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10); doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_hello = {{0, 2}, {3, 3}}; - SectionIdMask section_id_mask2_hello = 0b00001100; // hits in sections 2, 3 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hello{ - 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9); doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2_hello = {{2, 3}, {3, 2}}; // For term "ciao", document 9 and 8 - SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{ - 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(9); doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1_ciao = {{0, 2}, {1, 3}}; - SectionIdMask section_id_mask2_ciao = 0b00011000; // hits in sections 3, 4 - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_ciao{ - 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; DocHitInfoTermFrequencyPair doc_hit_info2_ciao = DocHitInfo(8); doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map2_ciao = {{3, 3}, {4, 2}}; std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1_hi, doc_hit_info2_hi}; @@ -549,45 +524,33 @@ TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) { EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(10)); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies1_hi)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi); - EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies1_hello)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hi", expected_section_ids_tf_map1_hi), + EqualsTermMatchInfo("hello", expected_section_ids_tf_map1_hello))); ICING_EXPECT_OK(or_iter.Advance()); EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(9)); matched_terms_stats.clear(); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies2_hello)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hello); - EXPECT_EQ(matched_terms_stats.at(1).term, "ciao"); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies1_ciao)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_ciao); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hello", expected_section_ids_tf_map2_hello), + EqualsTermMatchInfo("ciao", expected_section_ids_tf_map1_ciao))); ICING_EXPECT_OK(or_iter.Advance()); EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(8)); matched_terms_stats.clear(); or_iter.PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies2_hi)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hi); - EXPECT_EQ(matched_terms_stats.at(1).term, "ciao"); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies2_ciao)); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2_ciao); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("hi", expected_section_ids_tf_map2_hi), + EqualsTermMatchInfo("ciao", expected_section_ids_tf_map2_ciao))); EXPECT_FALSE(or_iter.Advance().ok()); } diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc index e80d8f0..6d41e90 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc @@ -44,7 +44,6 @@ namespace lib { namespace { using ::testing::ElementsAre; -using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; @@ -143,13 +142,10 @@ TEST_F(DocHitInfoIteratorSectionRestrictTest, expected_section_id_mask); section_restrict_iterator.PopulateMatchedTermsStats(&matched_terms_stats); - EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - std::array<Hit::TermFrequency, kTotalNumSections> expected_term_frequencies{ - 1}; - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(expected_term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, - expected_section_id_mask); + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{0, 1}}; + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "hi", expected_section_ids_tf_map))); EXPECT_FALSE(section_restrict_iterator.Advance().ok()); } diff --git a/icing/index/lite/lite-index_test.cc b/icing/index/lite/lite-index_test.cc index 7858b47..2c29640 100644 --- a/icing/index/lite/lite-index_test.cc +++ b/icing/index/lite/lite-index_test.cc @@ -20,7 +20,6 @@ #include "gtest/gtest.h" #include "icing/index/lite/doc-hit-info-iterator-term-lite.h" #include "icing/index/term-id-codec.h" -#include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/schema/section.h" #include "icing/store/suggestion-result-checker.h" #include "icing/testing/always-false-suggestion-result-checker-impl.h" @@ -32,7 +31,7 @@ namespace lib { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::SizeIs; @@ -112,20 +111,25 @@ TEST_F(LiteIndexTest, LiteIndexIterator) { lite_index_->InsertTerm(term, TermMatchType::PREFIX, kNamespace0)); ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); - Hit doc_hit0(/*section_id=*/0, /*document_id=*/0, 3, - /*is_in_prefix_section=*/false); - Hit doc_hit1(/*section_id=*/1, /*document_id=*/0, 5, - /*is_in_prefix_section=*/false); - Hit::TermFrequencyArray doc0_term_frequencies{3, 5}; - Hit doc_hit2(/*section_id=*/1, /*document_id=*/1, 7, - /*is_in_prefix_section=*/false); - Hit doc_hit3(/*section_id=*/2, /*document_id=*/1, 11, - /*is_in_prefix_section=*/false); - Hit::TermFrequencyArray doc1_term_frequencies{0, 7, 11}; - ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0)); - ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit1)); - ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit2)); - ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit3)); + Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3, + /*is_in_prefix_section=*/false); + Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5, + /*is_in_prefix_section=*/false); + SectionIdMask doc0_section_id_mask = 0b11; + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map0 = {{0, 3}, {1, 5}}; + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc0_hit0)); + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc0_hit1)); + + Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7, + /*is_in_prefix_section=*/false); + Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11, + /*is_in_prefix_section=*/false); + SectionIdMask doc1_section_id_mask = 0b110; + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map1 = {{1, 7}, {2, 11}}; + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit1)); + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit2)); std::unique_ptr<DocHitInfoIteratorTermLiteExact> iter = std::make_unique<DocHitInfoIteratorTermLiteExact>( @@ -134,25 +138,22 @@ TEST_F(LiteIndexTest, LiteIndexIterator) { ASSERT_THAT(iter->Advance(), IsOk()); EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(1)); - EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b110)); + EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), + Eq(doc1_section_id_mask)); + std::vector<TermMatchInfo> matched_terms_stats; iter->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); - EXPECT_EQ(matched_terms_stats.at(0).term, term); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b110); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(doc1_term_frequencies)); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + term, expected_section_ids_tf_map1))); ASSERT_THAT(iter->Advance(), IsOk()); EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(0)); - EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b11)); + EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), + Eq(doc0_section_id_mask)); matched_terms_stats.clear(); iter->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); - EXPECT_EQ(matched_terms_stats.at(0).term, term); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b11); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(doc0_term_frequencies)); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + term, expected_section_ids_tf_map0))); } } // namespace diff --git a/icing/index/numeric/integer-index-storage.h b/icing/index/numeric/integer-index-storage.h new file mode 100644 index 0000000..2048e76 --- /dev/null +++ b/icing/index/numeric/integer-index-storage.h @@ -0,0 +1,186 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_ +#define ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +// IntegerIndexStorage: a class for indexing (persistent storage) and searching +// contents of integer type sections in documents. +// - Accepts new integer contents (a.k.a keys) and adds records (BasicHit, key) +// into the integer index. +// - Stores records (BasicHit, key) in posting lists and compresses them. +// - Bucketizes these records by key to make range query more efficient and +// manages them with the corresponding posting lists. +// - When a posting list reaches the max size and is full, the mechanism of +// PostingListAccessor is to create another new (max-size) posting list and +// chain them together. +// - It will be inefficient if we store all records in the same PL chain. E.g. +// small range query needs to iterate through the whole PL chain but skips a +// lot of non-relevant records (whose keys don't belong to the query range). +// - Therefore, we implement the splitting mechanism to split a full max-size +// posting list. Also adjust range of the original bucket and add new +// buckets. +// - Ranges of all buckets are disjoint and the union of them is [INT64_MIN, +// INT64_MAX]. +// - Buckets should be sorted, so we can do binary search to find the desired +// bucket(s). However, we may split a bucket into several buckets, and the +// cost to insert newly created buckets is high. +// - Thus, we introduce an unsorted bucket array for newly created buckets, +// and merge unsorted buckets into the sorted bucket array only if length of +// the unsorted bucket array exceeds the threshold. This mechanism will +// reduce # of merging events and amortize the overall cost for bucket order +// maintenance. +// Note: some tree data structures (e.g. segment tree, B+ tree) maintain the +// bucket order more efficiently than the sorted/unsorted bucket array +// mechanism, but the implementation is more complicated and doesn't improve +// the performance too much according to our analysis, so currently we +// choose sorted/unsorted bucket array. +// - Then we do binary search on the sorted bucket array and sequential search +// on the unsorted bucket array. +class IntegerIndexStorage { + public: + // Crcs and Info will be written into the metadata file. + // File layout: <Crcs><Info> + // Crcs + struct Crcs { + static constexpr int32_t kFileOffset = 0; + + struct ComponentCrcs { + uint32_t info_crc; + uint32_t sorted_buckets_crc; + uint32_t unsorted_buckets_crc; + uint32_t flash_index_storage_crc; + + bool operator==(const ComponentCrcs& other) const { + return info_crc == other.info_crc && + sorted_buckets_crc == other.sorted_buckets_crc && + unsorted_buckets_crc == other.unsorted_buckets_crc && + flash_index_storage_crc == other.flash_index_storage_crc; + } + + Crc32 ComputeChecksum() const { + return Crc32(std::string_view(reinterpret_cast<const char*>(this), + sizeof(ComponentCrcs))); + } + } __attribute__((packed)); + + bool operator==(const Crcs& other) const { + return all_crc == other.all_crc && component_crcs == other.component_crcs; + } + + uint32_t all_crc; + ComponentCrcs component_crcs; + } __attribute__((packed)); + static_assert(sizeof(Crcs) == 20, ""); + + // Info + struct Info { + static constexpr int32_t kFileOffset = static_cast<int32_t>(sizeof(Crcs)); + static constexpr int32_t kMagic = 0xc4bf0ccc; + + int32_t magic; + int32_t num_keys; + + Crc32 ComputeChecksum() const { + return Crc32( + std::string_view(reinterpret_cast<const char*>(this), sizeof(Info))); + } + } __attribute__((packed)); + static_assert(sizeof(Info) == 8, ""); + + // Bucket + class Bucket { + public: + // Absolute max # of buckets allowed. Since the absolute max file size of + // FileBackedVector on 32-bit platform is ~2^28, we can at most have ~13.4M + // buckets. To make it power of 2, round it down to 2^23. Also since we're + // using FileBackedVector to store buckets, add some static_asserts to + // ensure numbers here are compatible with FileBackedVector. + static constexpr int32_t kMaxNumBuckets = 1 << 23; + + explicit Bucket(int64_t key_lower, int64_t key_upper, + PostingListIdentifier posting_list_identifier) + : key_lower_(key_lower), + key_upper_(key_upper), + posting_list_identifier_(posting_list_identifier) {} + + // For FileBackedVector + bool operator==(const Bucket& other) const { + return key_lower_ == other.key_lower_ && key_upper_ == other.key_upper_ && + posting_list_identifier_ == other.posting_list_identifier_; + } + + PostingListIdentifier posting_list_identifier() const { + return posting_list_identifier_; + } + void set_posting_list_identifier( + PostingListIdentifier posting_list_identifier) { + posting_list_identifier_ = posting_list_identifier; + } + + private: + int64_t key_lower_; + int64_t key_upper_; + PostingListIdentifier posting_list_identifier_; + } __attribute__((packed)); + static_assert(sizeof(Bucket) == 20, ""); + static_assert(sizeof(Bucket) == FileBackedVector<Bucket>::kElementTypeSize, + "Bucket type size is inconsistent with FileBackedVector " + "element type size"); + static_assert(Bucket::kMaxNumBuckets <= + (FileBackedVector<Bucket>::kMaxFileSize - + FileBackedVector<Bucket>::Header::kHeaderSize) / + FileBackedVector<Bucket>::kElementTypeSize, + "Max # of buckets cannot fit into FileBackedVector"); + + private: + explicit IntegerIndexStorage( + const Filesystem& filesystem, std::string_view base_dir, + PostingListUsedIntegerIndexDataSerializer* serializer, + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file, + std::unique_ptr<FileBackedVector<Bucket>> sorted_buckets, + std::unique_ptr<FileBackedVector<Bucket>> unsorted_buckets, + std::unique_ptr<FlashIndexStorage> flash_index_storage); + + const Filesystem& filesystem_; + std::string base_dir_; + + PostingListUsedIntegerIndexDataSerializer* serializer_; // Does not own. + + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; + std::unique_ptr<FileBackedVector<Bucket>> sorted_buckets_; + std::unique_ptr<FileBackedVector<Bucket>> unsorted_buckets_; + std::unique_ptr<FlashIndexStorage> flash_index_storage_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_ diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index 9a7df38..51f3106 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -51,7 +51,8 @@ jbyteArray SerializeProtoToJniByteArray( int size = protobuf.ByteSizeLong(); jbyteArray ret = env->NewByteArray(size); if (ret == nullptr) { - ICING_LOG(ERROR) << "Failed to allocated bytes for jni protobuf"; + ICING_LOG(icing::lib::ERROR) + << "Failed to allocated bytes for jni protobuf"; return nullptr; } @@ -75,7 +76,7 @@ extern "C" { jint JNI_OnLoad(JavaVM* vm, void* reserved) { JNIEnv* env; if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) { - ICING_LOG(ERROR) << "ERROR: GetEnv failed"; + ICING_LOG(icing::lib::ERROR) << "ERROR: GetEnv failed"; return JNI_ERR; } @@ -88,7 +89,7 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate( icing::lib::IcingSearchEngineOptions options; if (!ParseProtoFromJniByteArray(env, icing_search_engine_options_bytes, &options)) { - ICING_LOG(ERROR) + ICING_LOG(icing::lib::ERROR) << "Failed to parse IcingSearchEngineOptions in nativeCreate"; return 0; } @@ -131,7 +132,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema( icing::lib::SchemaProto schema_proto; if (!ParseProtoFromJniByteArray(env, schema_bytes, &schema_proto)) { - ICING_LOG(ERROR) << "Failed to parse SchemaProto in nativeSetSchema"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse SchemaProto in nativeSetSchema"; return nullptr; } @@ -173,7 +175,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativePut( icing::lib::DocumentProto document_proto; if (!ParseProtoFromJniByteArray(env, document_bytes, &document_proto)) { - ICING_LOG(ERROR) << "Failed to parse DocumentProto in nativePut"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse DocumentProto in nativePut"; return nullptr; } @@ -192,7 +195,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet( icing::lib::GetResultSpecProto get_result_spec; if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &get_result_spec)) { - ICING_LOG(ERROR) << "Failed to parse GetResultSpecProto in nativeGet"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse GetResultSpecProto in nativeGet"; return nullptr; } icing::lib::ScopedUtfChars scoped_name_space_chars(env, name_space); @@ -212,7 +216,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage( icing::lib::UsageReport usage_report; if (!ParseProtoFromJniByteArray(env, usage_report_bytes, &usage_report)) { - ICING_LOG(ERROR) << "Failed to parse UsageReport in nativeReportUsage"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse UsageReport in nativeReportUsage"; return nullptr; } @@ -279,20 +284,23 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch( icing::lib::SearchSpecProto search_spec_proto; if (!ParseProtoFromJniByteArray(env, search_spec_bytes, &search_spec_proto)) { - ICING_LOG(ERROR) << "Failed to parse SearchSpecProto in nativeSearch"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse SearchSpecProto in nativeSearch"; return nullptr; } icing::lib::ScoringSpecProto scoring_spec_proto; if (!ParseProtoFromJniByteArray(env, scoring_spec_bytes, &scoring_spec_proto)) { - ICING_LOG(ERROR) << "Failed to parse ScoringSpecProto in nativeSearch"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse ScoringSpecProto in nativeSearch"; return nullptr; } icing::lib::ResultSpecProto result_spec_proto; if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &result_spec_proto)) { - ICING_LOG(ERROR) << "Failed to parse ResultSpecProto in nativeSearch"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse ResultSpecProto in nativeSearch"; return nullptr; } @@ -363,7 +371,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery( icing::lib::SearchSpecProto search_spec_proto; if (!ParseProtoFromJniByteArray(env, search_spec_bytes, &search_spec_proto)) { - ICING_LOG(ERROR) << "Failed to parse SearchSpecProto in nativeSearch"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse SearchSpecProto in nativeSearch"; return nullptr; } icing::lib::DeleteByQueryResultProto delete_result_proto = @@ -379,8 +388,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk( GetIcingSearchEnginePointer(env, object); if (!icing::lib::PersistType::Code_IsValid(persist_type_code)) { - ICING_LOG(ERROR) << persist_type_code - << " is an invalid value for PersistType::Code"; + ICING_LOG(icing::lib::ERROR) + << persist_type_code << " is an invalid value for PersistType::Code"; return nullptr; } icing::lib::PersistType::Code persist_type_code_enum = @@ -447,7 +456,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions( icing::lib::SuggestionSpecProto suggestion_spec_proto; if (!ParseProtoFromJniByteArray(env, suggestion_spec_bytes, &suggestion_spec_proto)) { - ICING_LOG(ERROR) << "Failed to parse SuggestionSpecProto in nativeSearch"; + ICING_LOG(icing::lib::ERROR) + << "Failed to parse SuggestionSpecProto in nativeSearch"; return nullptr; } icing::lib::SuggestionResponse suggestionResponse = @@ -463,7 +473,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo( GetIcingSearchEnginePointer(env, object); if (!icing::lib::DebugInfoVerbosity::Code_IsValid(verbosity)) { - ICING_LOG(ERROR) << "Invalid value for Debug Info verbosity: " << verbosity; + ICING_LOG(icing::lib::ERROR) + << "Invalid value for Debug Info verbosity: " << verbosity; return nullptr; } @@ -478,7 +489,8 @@ JNIEXPORT jboolean JNICALL Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog( JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { if (!icing::lib::LogSeverity::Code_IsValid(severity)) { - ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; + ICING_LOG(icing::lib::ERROR) + << "Invalid value for logging severity: " << severity; return false; } return icing::lib::ShouldLog( @@ -489,7 +501,8 @@ JNIEXPORT jboolean JNICALL Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel( JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { if (!icing::lib::LogSeverity::Code_IsValid(severity)) { - ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; + ICING_LOG(icing::lib::ERROR) + << "Invalid value for logging severity: " << severity; return false; } return icing::lib::SetLoggingLevel( diff --git a/icing/join/aggregate-scorer.cc b/icing/join/aggregation-scorer.cc index 7b17482..3dee3dd 100644 --- a/icing/join/aggregate-scorer.cc +++ b/icing/join/aggregation-scorer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/join/aggregate-scorer.h" +#include "icing/join/aggregation-scorer.h" #include <algorithm> #include <memory> @@ -25,24 +25,26 @@ namespace icing { namespace lib { -class MinAggregateScorer : public AggregateScorer { +class CountAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return std::min_element(children.begin(), children.end(), - [](const ScoredDocumentHit& lhs, - const ScoredDocumentHit& rhs) -> bool { - return lhs.score() < rhs.score(); - }) - ->score(); + return children.size(); } }; -class MaxAggregateScorer : public AggregateScorer { +class MinAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return std::max_element(children.begin(), children.end(), + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } + return std::min_element(children.begin(), children.end(), [](const ScoredDocumentHit& lhs, const ScoredDocumentHit& rhs) -> bool { return lhs.score() < rhs.score(); @@ -51,41 +53,59 @@ class MaxAggregateScorer : public AggregateScorer { } }; -class AverageAggregateScorer : public AggregateScorer { +class AverageAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - if (children.empty()) return 0.0; + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } return std::reduce( children.begin(), children.end(), 0.0, - [](const double& prev, const ScoredDocumentHit& item) -> double { + [](double prev, const ScoredDocumentHit& item) -> double { return prev + item.score(); }) / children.size(); } }; -class CountAggregateScorer : public AggregateScorer { +class MaxAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return children.size(); + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } + return std::max_element(children.begin(), children.end(), + [](const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) -> bool { + return lhs.score() < rhs.score(); + }) + ->score(); } }; -class SumAggregateScorer : public AggregateScorer { +class SumAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { return std::reduce( children.begin(), children.end(), 0.0, - [](const double& prev, const ScoredDocumentHit& item) -> double { + [](double prev, const ScoredDocumentHit& item) -> double { return prev + item.score(); }); } }; -class DefaultAggregateScorer : public AggregateScorer { +class DefaultAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { @@ -93,23 +113,25 @@ class DefaultAggregateScorer : public AggregateScorer { } }; -std::unique_ptr<AggregateScorer> AggregateScorer::Create( +std::unique_ptr<AggregationScorer> AggregationScorer::Create( const JoinSpecProto& join_spec) { - switch (join_spec.aggregation_score_strategy()) { - case JoinSpecProto_AggregationScore_MIN: - return std::make_unique<MinAggregateScorer>(); - case JoinSpecProto_AggregationScore_MAX: - return std::make_unique<MaxAggregateScorer>(); - case JoinSpecProto_AggregationScore_COUNT: - return std::make_unique<CountAggregateScorer>(); - case JoinSpecProto_AggregationScore_AVG: - return std::make_unique<AverageAggregateScorer>(); - case JoinSpecProto_AggregationScore_SUM: - return std::make_unique<SumAggregateScorer>(); - case JoinSpecProto_AggregationScore_UNDEFINED: + switch (join_spec.aggregation_scoring_strategy()) { + case JoinSpecProto::AggregationScoringStrategy::COUNT: + return std::make_unique<CountAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::MIN: + return std::make_unique<MinAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::AVG: + return std::make_unique<AverageAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::MAX: + return std::make_unique<MaxAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::SUM: + return std::make_unique<SumAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::NONE: + // No aggregation strategy means using parent document score, so fall + // through to return DefaultAggregationScorer. [[fallthrough]]; default: - return std::make_unique<DefaultAggregateScorer>(); + return std::make_unique<DefaultAggregationScorer>(); } } diff --git a/icing/join/aggregate-scorer.h b/icing/join/aggregation-scorer.h index 27731b9..3d38cf0 100644 --- a/icing/join/aggregate-scorer.h +++ b/icing/join/aggregation-scorer.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef ICING_JOIN_AGGREGATE_SCORER_H_ -#define ICING_JOIN_AGGREGATE_SCORER_H_ +#ifndef ICING_JOIN_AGGREGATION_SCORER_H_ +#define ICING_JOIN_AGGREGATION_SCORER_H_ #include <memory> #include <vector> @@ -24,12 +24,12 @@ namespace icing { namespace lib { -class AggregateScorer { +class AggregationScorer { public: - static std::unique_ptr<AggregateScorer> Create( + static std::unique_ptr<AggregationScorer> Create( const JoinSpecProto& join_spec); - virtual ~AggregateScorer() = default; + virtual ~AggregationScorer() = default; virtual double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) = 0; @@ -38,4 +38,4 @@ class AggregateScorer { } // namespace lib } // namespace icing -#endif // ICING_JOIN_AGGREGATE_SCORER_H_ +#endif // ICING_JOIN_AGGREGATION_SCORER_H_ diff --git a/icing/join/aggregation-scorer_test.cc b/icing/join/aggregation-scorer_test.cc new file mode 100644 index 0000000..19a7239 --- /dev/null +++ b/icing/join/aggregation-scorer_test.cc @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/join/aggregation-scorer.h" + +#include <algorithm> +#include <iterator> +#include <memory> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/search.pb.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::DoubleEq; + +struct AggregationScorerTestParam { + double ans; + JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy; + double parent_score; + std::vector<double> child_scores; + + explicit AggregationScorerTestParam( + double ans_in, + JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy_in, + double parent_score_in, std::vector<double> child_scores_in) + : ans(ans_in), + scoring_strategy(scoring_strategy_in), + parent_score(std::move(parent_score_in)), + child_scores(std::move(child_scores_in)) {} +}; + +class AggregationScorerTest + : public ::testing::TestWithParam<AggregationScorerTestParam> {}; + +TEST_P(AggregationScorerTest, GetScore) { + static constexpr DocumentId kDefaultDocumentId = 0; + + const AggregationScorerTestParam& param = GetParam(); + // Test AggregationScorer by creating some ScoredDocumentHits for parent and + // child documents. DocumentId and SectionIdMask won't affect the aggregation + // score calculation, so just simply set default values. + // Parent document + ScoredDocumentHit parent_scored_document_hit( + kDefaultDocumentId, kSectionIdMaskNone, param.parent_score); + // Child documents + std::vector<ScoredDocumentHit> child_scored_document_hits; + child_scored_document_hits.reserve(param.child_scores.size()); + std::transform(param.child_scores.cbegin(), param.child_scores.cend(), + std::back_inserter(child_scored_document_hits), + [](double score) -> ScoredDocumentHit { + return ScoredDocumentHit(kDefaultDocumentId, + kSectionIdMaskNone, score); + }); + + JoinSpecProto join_spec; + join_spec.set_aggregation_scoring_strategy(param.scoring_strategy); + std::unique_ptr<AggregationScorer> scorer = + AggregationScorer::Create(join_spec); + EXPECT_THAT( + scorer->GetScore(parent_scored_document_hit, child_scored_document_hits), + DoubleEq(param.ans)); +} + +INSTANTIATE_TEST_SUITE_P( + CountAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/5, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + MinAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + AverageAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/4.6, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + MaxAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/8, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + SumAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/23, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/0, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + DefaultAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc index 9b17396..7700397 100644 --- a/icing/join/join-processor.cc +++ b/icing/join/join-processor.cc @@ -23,6 +23,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/join/aggregation-scorer.h" #include "icing/join/qualified-id.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" @@ -45,8 +46,6 @@ JoinProcessor::Join( ScoringSpecProto::Order::DESC)); // TODO(b/256022027): - // - Aggregate scoring - // - Calculate the aggregated score if strategy is AGGREGATION_SCORING. // - Optimization // - Cache property to speed up property retrieval. // - If there is no cache, then we still have the flexibility to fetch it @@ -93,6 +92,9 @@ JoinProcessor::Join( } } + std::unique_ptr<AggregationScorer> aggregation_scorer = + AggregationScorer::Create(join_spec); + std::vector<JoinedScoredDocumentHit> joined_scored_document_hits; joined_scored_document_hits.reserve(parent_scored_document_hits.size()); @@ -110,14 +112,15 @@ JoinProcessor::Join( "Parent property expression must be ", kQualifiedIdExpr)); } - // TODO(b/256022027): Derive final score from - // parent_id_to_child_map[parent_doc_id] and - // join_spec.aggregation_score_strategy() - double final_score = parent.score(); - joined_scored_document_hits.emplace_back( - final_score, std::move(parent), - std::vector<ScoredDocumentHit>( - std::move(parent_id_to_child_map[parent_doc_id]))); + std::vector<ScoredDocumentHit> children; + if (auto iter = parent_id_to_child_map.find(parent_doc_id); + iter != parent_id_to_child_map.end()) { + children = std::move(iter->second); + } + + double final_score = aggregation_scorer->GetScore(parent, children); + joined_scored_document_hits.emplace_back(final_score, std::move(parent), + std::move(children)); } return joined_scored_document_hits; diff --git a/icing/join/join-processor_test.cc b/icing/join/join-processor_test.cc new file mode 100644 index 0000000..70eaf3f --- /dev/null +++ b/icing/join/join-processor_test.cc @@ -0,0 +1,624 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/join/join-processor.h" + +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; + +class JoinProcessorTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/icing_join_processor_test"; + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("Person").AddProperty( + PropertyConfigBuilder() + .SetName("Name") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType( + SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeJoinableString( + JOINABLE_VALUE_TYPE_QUALIFIED_ID) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + doc_store_ = std::move(create_result.document_store); + } + + void TearDown() override { + doc_store_.reset(); + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + Filesystem filesystem_; + std::string test_dir_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> doc_store_; + FakeClock fake_clock_; +}; + +TEST_F(JoinProcessorTest, JoinByQualifiedId) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/3.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/4.0); + ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone, + /*score=*/5.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit2, scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit5, scored_doc_hit4, scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit4})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/2.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit5, scored_doc_hit3})))); +} + +TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithoutJoiningProperty) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/6.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit2, + scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Since Email2 doesn't have "sender" property, it should be ignored. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithInvalidQualifiedId) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty( + "sender", + "pkg$db/namespace#person2") // qualified id is invalid since + // person2 doesn't exist. + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", + R"(pkg$db/namespace\#person1)") // invalid format + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/0.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit2, scored_doc_hit3, scored_doc_hit4}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Email 2 and email 3 (document id 3 and 4) contain invalid qualified ids. + // Join processor should ignore them. + EXPECT_THAT(joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, LeftJoinShouldReturnParentWithoutChildren) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/3.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit2, scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Person1 has no child documents, but left join should also include it. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit3})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/0.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{})))); +} + +TEST_F(JoinProcessorTest, ShouldSortChildDocumentsByRankingStrategy) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/2.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/3.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit2, scored_doc_hit3, scored_doc_hit4}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Child documents should be sorted according to the (nested) ranking + // strategy. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/3.0, /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit3, scored_doc_hit4, scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, + ShouldTruncateByRankingStrategyIfExceedingMaxJoinedChildCount) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email4 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email4") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 4") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, doc_store_->Put(email4)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/2.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone, + /*score=*/3.0); + ScoredDocumentHit scored_doc_hit6(document_id6, kSectionIdMaskNone, + /*score=*/1.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1, scored_doc_hit2}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit3, scored_doc_hit4, scored_doc_hit5, scored_doc_hit6}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(2); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Since we set max_joind_child_count as 2 and use DESC as the (nested) + // ranking strategy, parent document with # of child documents more than 2 + // should only keep 2 child documents with higher scores and the rest should + // be truncated. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/2.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit4, scored_doc_hit5})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit6})))); +} + +TEST_F(JoinProcessorTest, ShouldAllowSelfJoining) { + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#email1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(email1)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit1}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + EXPECT_THAT(joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit1})))); +} + +// TODO(b/256022027): add unit tests for non-joinable property. If joinable +// value type is unset, then qualifed id join should not +// include the child document even if it contains a valid +// qualified id string. + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc index 910d722..fbd4504 100644 --- a/icing/query/advanced_query_parser/query-visitor.cc +++ b/icing/query/advanced_query_parser/query-visitor.cc @@ -366,5 +366,23 @@ void QueryVisitor::VisitNaryOperator(const NaryOperatorNode* node) { pending_values_.push(std::move(pending_value)); } +libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && { + if (has_pending_error()) { + return std::move(pending_error_); + } + if (pending_values_.size() != 1) { + return absl_ports::InvalidArgumentError( + "Visitor does not contain a single root iterator."); + } + auto iterator_or = RetrieveIterator(); + if (!iterator_or.ok()) { + return std::move(iterator_or).status(); + } + QueryResults results; + results.root_iterator = std::move(iterator_or).ValueOrDie(); + results.features_in_use = std::move(features_); + return results; +} + } // namespace lib } // namespace icing diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h index 8bba7ea..c6b7d8e 100644 --- a/icing/query/advanced_query_parser/query-visitor.h +++ b/icing/query/advanced_query_parser/query-visitor.h @@ -26,10 +26,11 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/numeric/numeric-index.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/query-features.h" +#include "icing/query/query-results.h" #include "icing/schema/schema-store.h" #include "icing/store/document-store.h" #include "icing/transform/normalizer.h" -#include "icing/query/query-features.h" namespace icing { namespace lib { @@ -60,27 +61,10 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { void VisitNaryOperator(const NaryOperatorNode* node) override; // RETURNS: - // - the DocHitInfoIterator that is the root of the query iterator tree + // - the QueryResults reflecting the AST that was visited // - INVALID_ARGUMENT if the AST does not conform to supported expressions // - NOT_FOUND if the AST refers to a property that does not exist - libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> root() && { - if (has_pending_error()) { - return pending_error_; - } - if (pending_values_.size() != 1) { - return absl_ports::InvalidArgumentError( - "Visitor does not contain a single root iterator."); - } - auto iterator_or = RetrieveIterator(); - if (!iterator_or.ok()) { - pending_error_ = std::move(iterator_or).status(); - return pending_error_; - } - return std::move(iterator_or).ValueOrDie(); - } - - // Returns the set of features used in the query. - const std::unordered_set<Feature>& features() const { return features_; } + libtextclassifier3::StatusOr<QueryResults> ConsumeResults() &&; private: // A holder for intermediate results when processing child nodes. @@ -185,15 +169,16 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor { std::stack<PendingValue> pending_values_; libtextclassifier3::Status pending_error_; + // Set of features invoked in the query. + std::unordered_set<Feature> features_; + Index& index_; // Does not own! const NumericIndex<int64_t>& numeric_index_; // Does not own! const DocumentStore& document_store_; // Does not own! const SchemaStore& schema_store_; // Does not own! const Normalizer& normalizer_; // Does not own! - TermMatchType::Code match_type_; - // Set of features invoked in the query. - std::unordered_set<Feature> features_; + TermMatchType::Code match_type_; }; } // namespace lib diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc index 2f816a8..2b5117b 100644 --- a/icing/query/advanced_query_parser/query-visitor_test.cc +++ b/icing/query/advanced_query_parser/query-visitor_test.cc @@ -146,10 +146,11 @@ TEST_F(QueryVisitorTest, SimpleLessThan) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -176,10 +177,11 @@ TEST_F(QueryVisitorTest, SimpleLessThanEq) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -206,10 +208,12 @@ TEST_F(QueryVisitorTest, SimpleEqual) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } TEST_F(QueryVisitorTest, SimpleGreaterThanEq) { @@ -235,10 +239,11 @@ TEST_F(QueryVisitorTest, SimpleGreaterThanEq) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId1)); } @@ -265,10 +270,12 @@ TEST_F(QueryVisitorTest, SimpleGreaterThan) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } // TODO(b/208654892) Properly handle negative numbers in query expressions. @@ -296,10 +303,12 @@ TEST_F(QueryVisitorTest, DISABLED_IntMinLessThanEqual) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); } TEST_F(QueryVisitorTest, IntMaxGreaterThanEqual) { @@ -326,10 +335,12 @@ TEST_F(QueryVisitorTest, IntMaxGreaterThanEqual) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); } TEST_F(QueryVisitorTest, NestedPropertyLessThan) { @@ -357,10 +368,11 @@ TEST_F(QueryVisitorTest, NestedPropertyLessThan) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kNumericSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -372,7 +384,7 @@ TEST_F(QueryVisitorTest, IntParsingError) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -384,7 +396,7 @@ TEST_F(QueryVisitorTest, NotEqualsUnsupported) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED)); } @@ -427,7 +439,7 @@ TEST_F(QueryVisitorTest, LessThanTooManyOperandsInvalid) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -449,7 +461,7 @@ TEST_F(QueryVisitorTest, LessThanTooFewOperandsInvalid) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -480,7 +492,7 @@ TEST_F(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } @@ -488,7 +500,7 @@ TEST_F(QueryVisitorTest, NeverVisitedReturnsInvalid) { QueryVisitor query_visitor(index_.get(), numeric_index_.get(), document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -517,7 +529,7 @@ TEST_F(QueryVisitorTest, DISABLED_IntMinLessThanInvalid) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -545,7 +557,7 @@ TEST_F(QueryVisitorTest, IntMaxGreaterThanInvalid) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(std::move(query_visitor).root(), + EXPECT_THAT(std::move(query_visitor).ConsumeResults(), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } @@ -574,9 +586,9 @@ TEST_F(QueryVisitorTest, SingleTerm) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -605,10 +617,11 @@ TEST_F(QueryVisitorTest, SingleVerbatimTerm) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -650,10 +663,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingQuote) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } // 2. How does a user represent a escape char (\) that immediately precedes the @@ -687,10 +702,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingEscape) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); } // 3. How do we handle other escaped chars? @@ -726,10 +743,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); // Issue a query for the verbatim token `foobar\y`. query = R"("foobar\\y")"; @@ -738,10 +757,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_two); - EXPECT_THAT(query_visitor_two.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_two).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_two).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } // This isn't a special case, but is fairly useful for demonstrating. There are @@ -778,10 +799,12 @@ TEST_F(QueryVisitorTest, VerbatimTermNewLine) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); // Now, issue a query for the verbatim token `foobar\n`. query = R"("foobar\\n")"; @@ -790,10 +813,12 @@ TEST_F(QueryVisitorTest, VerbatimTermNewLine) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_two); - EXPECT_THAT(query_visitor_two.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_two).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_two).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } TEST_F(QueryVisitorTest, VerbatimTermEscapingComplex) { @@ -824,10 +849,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingComplex) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature)); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, + ElementsAre(kVerbatimSearchFeature)); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); } TEST_F(QueryVisitorTest, SingleMinusTerm) { @@ -866,10 +893,11 @@ TEST_F(QueryVisitorTest, SingleMinusTerm) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } TEST_F(QueryVisitorTest, SingleNotTerm) { @@ -908,10 +936,11 @@ TEST_F(QueryVisitorTest, SingleNotTerm) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } TEST_F(QueryVisitorTest, ImplicitAndTerms) { Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1, @@ -937,10 +966,11 @@ TEST_F(QueryVisitorTest, ImplicitAndTerms) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); } TEST_F(QueryVisitorTest, ExplicitAndTerms) { @@ -967,10 +997,11 @@ TEST_F(QueryVisitorTest, ExplicitAndTerms) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId1)); } TEST_F(QueryVisitorTest, OrTerms) { @@ -997,10 +1028,10 @@ TEST_F(QueryVisitorTest, OrTerms) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId0)); } @@ -1030,10 +1061,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId1)); // Should be interpreted like `(bar OR baz) foo` @@ -1043,10 +1074,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_two); - EXPECT_THAT(query_visitor_two.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_two).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_two).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId1)); query = "(bar OR baz) foo"; @@ -1055,10 +1086,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_three); - EXPECT_THAT(query_visitor_three.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_three).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_three).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId1)); } @@ -1103,10 +1134,10 @@ TEST_F(QueryVisitorTest, AndOrNotPrecedence) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId2, kDocumentId0)); query = "foo NOT (bar OR baz)"; @@ -1115,10 +1146,11 @@ TEST_F(QueryVisitorTest, AndOrNotPrecedence) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_two); - EXPECT_THAT(query_visitor_two.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_two).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_two).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId0)); } TEST_F(QueryVisitorTest, PropertyFilter) { @@ -1169,10 +1201,10 @@ TEST_F(QueryVisitorTest, PropertyFilter) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -1224,10 +1256,10 @@ TEST_F(QueryVisitorTest, PropertyFilterWithGrouping) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), ElementsAre(kDocumentId1, kDocumentId0)); } @@ -1279,10 +1311,11 @@ TEST_F(QueryVisitorTest, PropertyFilterWithNot) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor); - EXPECT_THAT(query_visitor.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, - std::move(query_visitor).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results, + std::move(query_visitor).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); query = "NOT prop1:(foo OR bar)"; ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query)); @@ -1290,10 +1323,11 @@ TEST_F(QueryVisitorTest, PropertyFilterWithNot) { document_store_.get(), schema_store_.get(), normalizer_.get(), TERM_MATCH_PREFIX); root_node->Accept(&query_visitor_two); - EXPECT_THAT(query_visitor_two.features(), IsEmpty()); - ICING_ASSERT_OK_AND_ASSIGN(root_iterator, - std::move(query_visitor_two).root()); - EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); + ICING_ASSERT_OK_AND_ASSIGN(query_results, + std::move(query_visitor_two).ConsumeResults()); + EXPECT_THAT(query_results.features_in_use, IsEmpty()); + EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()), + ElementsAre(kDocumentId2)); } } // namespace diff --git a/icing/query/query-features.h b/icing/query/query-features.h index 56825f7..1471063 100644 --- a/icing/query/query-features.h +++ b/icing/query/query-features.h @@ -22,13 +22,17 @@ namespace icing { namespace lib { // A feature used in a query. -using Feature = std::string_view; - // All feature values here must be kept in sync with its counterpart in: // androidx-main/frameworks/support/appsearch/appsearch/src/main/java/androidx/appsearch/app/Features.java +using Feature = std::string_view; + +// This feature relates to the use of the numeric comparison operators in the +// advanced query language. Ex. `price < 10`. constexpr Feature kNumericSearchFeature = "NUMERIC_SEARCH"; // Features#NUMERIC_SEARCH +// This feature relates to the use of the STRING terminal in the advanced query +// language. Ex. `"foo?bar"` is treated as a single term - `foo?bar`. constexpr Feature kVerbatimSearchFeature = "VERBATIM_SEARCH"; // Features#VERBATIM_SEARCH diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 7860684..283d83d 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -207,10 +207,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery( &schema_store_, &normalizer_, search_spec.term_match_type()); tree_root->Accept(&query_visitor); - results.features_in_use = query_visitor.features(); - ICING_ASSIGN_OR_RETURN(results.root_iterator, - std::move(query_visitor).root()); - return results; + return std::move(query_visitor).ConsumeResults(); } // TODO(cassiewang): Collect query stats to populate the SearchResultsProto diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index 161f180..c22f6aa 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -59,7 +59,6 @@ namespace lib { namespace { using ::testing::ElementsAre; -using ::testing::ElementsAreArray; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::UnorderedElementsAre; @@ -293,7 +292,6 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -320,18 +318,14 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "world"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map), + EqualsTermMatchInfo("world", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); @@ -359,7 +353,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::PREFIX; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -383,14 +376,12 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "he"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "he", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he")); EXPECT_THAT(results.query_term_iterators, SizeIs(1)); @@ -442,14 +433,12 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "he"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "he", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he")); EXPECT_THAT(results.query_term_iterators, SizeIs(1)); @@ -476,7 +465,6 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -500,14 +488,13 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello")); EXPECT_THAT(results.query_term_iterators, SizeIs(1)); @@ -534,7 +521,6 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -554,18 +540,18 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) { EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), section_id_mask); + // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map))); } - ASSERT_FALSE(results.root_iterator->Advance().ok()); // TODO(b/208654892) Support Query Terms with advanced query @@ -597,7 +583,6 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -624,18 +609,14 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "world"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map), + EqualsTermMatchInfo("world", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); @@ -663,7 +644,6 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::PREFIX; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -687,13 +667,12 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "he"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "he", expected_section_ids_tf_map))); } ASSERT_FALSE(results.root_iterator->Advance().ok()); @@ -726,7 +705,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -755,18 +733,14 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "he"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "wo"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("he", expected_section_ids_tf_map), + EqualsTermMatchInfo("wo", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo")); EXPECT_THAT(results.query_term_iterators, SizeIs(2)); @@ -792,7 +766,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT(AddTokenToIndex(document_id, section_id, @@ -821,18 +794,14 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "wo"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map), + EqualsTermMatchInfo("wo", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo")); EXPECT_THAT(results.query_term_iterators, SizeIs(2)); @@ -863,7 +832,6 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; EXPECT_THAT( @@ -888,33 +856,34 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) { EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2); EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), section_id_mask); + // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "world"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("world", expected_section_ids_tf_map))); } ASSERT_THAT(results.root_iterator->Advance(), IsOk()); EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1); EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), section_id_mask); + // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); @@ -946,7 +915,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -975,13 +943,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "wo"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "wo", expected_section_ids_tf_map))); } ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -992,14 +959,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "he"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "he", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo")); EXPECT_THAT(results.query_term_iterators, SizeIs(2)); @@ -1030,7 +995,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT(AddTokenToIndex(document_id1, section_id, TermMatchType::EXACT_ONLY, "hello"), @@ -1058,13 +1022,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "wo"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo( + "wo", expected_section_ids_tf_map))); } ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -1075,14 +1038,13 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term - EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo")); EXPECT_THAT(results.query_term_iterators, SizeIs(2)); @@ -1112,7 +1074,6 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal puppy dog" @@ -1159,17 +1120,14 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "puppy"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "dog"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("puppy", expected_section_ids_tf_map), + EqualsTermMatchInfo("dog", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], @@ -1203,17 +1161,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "kitten"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("animal", expected_section_ids_tf_map), + EqualsTermMatchInfo("kitten", expected_section_ids_tf_map))); } ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -1225,18 +1181,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "puppy"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); - + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("animal", expected_section_ids_tf_map), + EqualsTermMatchInfo("puppy", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal", "puppy", "kitten")); @@ -1267,17 +1220,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 1}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "kitten"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "cat"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("kitten", expected_section_ids_tf_map), + EqualsTermMatchInfo("cat", expected_section_ids_tf_map))); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], @@ -1901,9 +1852,6 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { QueryResults results, query_processor_->ParseSearch(search_spec, ScoringSpecProto::RankingStrategy::NONE)); - // Since need_hit_term_frequency is false, the expected term frequencies - // should all be 0. - Hit::TermFrequencyArray exp_term_frequencies{0}; // Descending order of valid DocumentIds // The first Document to match (Document 2) matches on 'animal' AND 'kitten' @@ -1915,17 +1863,17 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + // Since need_hit_term_frequency is false, the expected term frequency for + // the section with the hit should be 0. + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 0}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(exp_term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "kitten"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(exp_term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre( + EqualsTermMatchInfo("animal", expected_section_ids_tf_map), + EqualsTermMatchInfo("kitten", expected_section_ids_tf_map))); } // The second Document to match (Document 1) matches on 'animal' AND 'puppy' @@ -1937,17 +1885,14 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + std::unordered_map<SectionId, Hit::TermFrequency> + expected_section_ids_tf_map = {{section_id, 0}}; std::vector<TermMatchInfo> matched_terms_stats; results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); - ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms - EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); - EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, - ElementsAreArray(exp_term_frequencies)); - EXPECT_EQ(matched_terms_stats.at(1).term, "puppy"); - EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); - EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, - ElementsAreArray(exp_term_frequencies)); + EXPECT_THAT( + matched_terms_stats, + ElementsAre(EqualsTermMatchInfo("animal", expected_section_ids_tf_map), + EqualsTermMatchInfo("puppy", expected_section_ids_tf_map))); // This should be empty because ranking_strategy != RELEVANCE_SCORE EXPECT_THAT(results.query_term_iterators, IsEmpty()); @@ -2009,6 +1954,7 @@ TEST_P(QueryProcessorTest, DeletedFilter) { expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); + // TODO(b/208654892) Support Query Terms with advanced query if (GetParam() != SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h index 077d734..763499b 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.h +++ b/icing/scoring/advanced_scoring/advanced-scorer.h @@ -60,13 +60,20 @@ class AdvancedScorer : public Scorer { bm25f_calculator_->PrepareToScore(query_term_iterators); } + bool is_constant() const { return score_expression_->is_constant_double(); } + private: explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, std::unique_ptr<Bm25fCalculator> bm25f_calculator, double default_score) : score_expression_(std::move(score_expression)), bm25f_calculator_(std::move(bm25f_calculator)), - default_score_(default_score) {} + default_score_(default_score) { + if (is_constant()) { + ICING_LOG(WARNING) + << "The advanced scoring expression will evaluate to a constant."; + } + } std::unique_ptr<ScoreExpression> score_expression_; std::unique_ptr<Bm25fCalculator> bm25f_calculator_; diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc index 36d38a2..b0b32e9 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc @@ -437,7 +437,7 @@ TEST_F(AdvancedScorerTest, ComplexExpression) { DocHitInfo docHitInfo = DocHitInfo(document_id); ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<Scorer> scorer, + std::unique_ptr<AdvancedScorer> scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec( "pow(sin(2), 2)" // This is this.usageCount(1) @@ -449,13 +449,14 @@ TEST_F(AdvancedScorerTest, ComplexExpression) { "+ this.relevanceScore()"), /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_FALSE(scorer->is_constant()); scorer->PrepareToScore(/*query_term_iterators=*/{}); ICING_ASSERT_OK(document_store_->ReportUsage( CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1))); ICING_ASSERT_OK(document_store_->ReportUsage( CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1))); - EXPECT_THAT(scorer->GetScore(docHitInfo), + EXPECT_THAT(scorer->GetScore(docHitInfo, /*query_it=*/nullptr), DoubleNear(pow(sin(2), 2) + 2 / 12.34 * (10 * pow(2 * 1, sin(2)) + @@ -464,6 +465,18 @@ TEST_F(AdvancedScorerTest, ComplexExpression) { kEps)); } +TEST_F(AdvancedScorerTest, ConstantExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<AdvancedScorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec( + "pow(sin(2), 2)" + "+ log(2, 122) / 12.34" + "* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_TRUE(scorer->is_constant()); +} + // Should be a parsing Error TEST_F(AdvancedScorerTest, EmptyExpression) { EXPECT_THAT( diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc index 08da1c5..a8749df 100644 --- a/icing/scoring/advanced_scoring/score-expression.cc +++ b/icing/scoring/advanced_scoring/score-expression.cc @@ -17,19 +17,23 @@ namespace icing { namespace lib { -libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> +libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> OperatorScoreExpression::Create( OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) { if (children.empty()) { return absl_ports::InvalidArgumentError( "OperatorScoreExpression must have at least one argument."); } + bool children_all_constant_double = true; for (const auto& child : children) { ICING_RETURN_ERROR_IF_NULL(child); if (child->is_document_type()) { return absl_ports::InvalidArgumentError( "Operators are not supported for document type."); } + if (!child->is_constant_double()) { + children_all_constant_double = false; + } } if (op == OperatorType::kNegative) { if (children.size() != 1) { @@ -37,8 +41,16 @@ OperatorScoreExpression::Create( "Negative operator must have only 1 argument."); } } - return std::unique_ptr<OperatorScoreExpression>( - new OperatorScoreExpression(op, std::move(children))); + std::unique_ptr<ScoreExpression> expression = + std::unique_ptr<OperatorScoreExpression>( + new OperatorScoreExpression(op, std::move(children))); + if (children_all_constant_double) { + // Because all of the children are constants, this expression does not + // depend on the DocHitInto or query_it that are passed into it. + return ConstantScoreExpression::Create( + expression->eval(DocHitInfo(), /*query_it=*/nullptr)); + } + return expression; } libtextclassifier3::StatusOr<double> OperatorScoreExpression::eval( @@ -85,7 +97,7 @@ const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType> {"sin", FunctionType::kSin}, {"cos", FunctionType::kCos}, {"tan", FunctionType::kTan}}; -libtextclassifier3::StatusOr<std::unique_ptr<MathFunctionScoreExpression>> +libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> MathFunctionScoreExpression::Create( FunctionType function_type, std::vector<std::unique_ptr<ScoreExpression>> children) { @@ -93,12 +105,16 @@ MathFunctionScoreExpression::Create( return absl_ports::InvalidArgumentError( "Math functions must have at least one argument."); } + bool children_all_constant_double = true; for (const auto& child : children) { ICING_RETURN_ERROR_IF_NULL(child); if (child->is_document_type()) { return absl_ports::InvalidArgumentError( "Math functions are not supported for document type."); } + if (!child->is_constant_double()) { + children_all_constant_double = false; + } } switch (function_type) { case FunctionType::kLog: @@ -143,8 +159,16 @@ MathFunctionScoreExpression::Create( case FunctionType::kMin: break; } - return std::unique_ptr<MathFunctionScoreExpression>( - new MathFunctionScoreExpression(function_type, std::move(children))); + std::unique_ptr<ScoreExpression> expression = + std::unique_ptr<MathFunctionScoreExpression>( + new MathFunctionScoreExpression(function_type, std::move(children))); + if (children_all_constant_double) { + // Because all of the children are constants, this expression does not + // depend on the DocHitInto or query_it that are passed into it. + return ConstantScoreExpression::Create( + expression->eval(DocHitInfo(), /*query_it=*/nullptr)); + } + return expression; } libtextclassifier3::StatusOr<double> MathFunctionScoreExpression::eval( diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h index 533ca52..f80da33 100644 --- a/icing/scoring/advanced_scoring/score-expression.h +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -31,8 +31,6 @@ namespace icing { namespace lib { -// TODO(b/261474063) Simplify every ScoreExpression node to -// ConstantScoreExpression if its evaluation does not depend on a document. class ScoreExpression { public: virtual ~ScoreExpression() = default; @@ -49,6 +47,10 @@ class ScoreExpression { // Indicate whether the current expression is of document type virtual bool is_document_type() const { return false; } + + // Indicate whether the current expression is a constant double. + // Returns true if and only if the object is of ConstantScoreExpression type. + virtual bool is_constant_double() const { return false; } }; class ThisExpression : public ScoreExpression { @@ -72,7 +74,8 @@ class ThisExpression : public ScoreExpression { class ConstantScoreExpression : public ScoreExpression { public: - static std::unique_ptr<ConstantScoreExpression> Create(double c) { + static std::unique_ptr<ConstantScoreExpression> Create( + libtextclassifier3::StatusOr<double> c) { return std::unique_ptr<ConstantScoreExpression>( new ConstantScoreExpression(c)); } @@ -82,10 +85,13 @@ class ConstantScoreExpression : public ScoreExpression { return c_; } + bool is_constant_double() const override { return true; } + private: - explicit ConstantScoreExpression(double c) : c_(c) {} + explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c) + : c_(c) {} - double c_; + libtextclassifier3::StatusOr<double> c_; }; class OperatorScoreExpression : public ScoreExpression { @@ -93,12 +99,12 @@ class OperatorScoreExpression : public ScoreExpression { enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv }; // RETURNS: - // - An OperatorScoreExpression instance on success. + // - An OperatorScoreExpression instance on success if not simplifiable. + // - A ConstantScoreExpression instance on success if simplifiable. // - FAILED_PRECONDITION on any null pointer in children. // - INVALID_ARGUMENT on type errors. - static libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> - Create(OperatorType op, - std::vector<std::unique_ptr<ScoreExpression>> children); + static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( + OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children); libtextclassifier3::StatusOr<double> eval( const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override; @@ -129,13 +135,13 @@ class MathFunctionScoreExpression : public ScoreExpression { static const std::unordered_map<std::string, FunctionType> kFunctionNames; // RETURNS: - // - A MathFunctionScoreExpression instance on success. + // - A MathFunctionScoreExpression instance on success if not simplifiable. + // - A ConstantScoreExpression instance on success if simplifiable. // - FAILED_PRECONDITION on any null pointer in children. // - INVALID_ARGUMENT on type errors. - static libtextclassifier3::StatusOr< - std::unique_ptr<MathFunctionScoreExpression>> - Create(FunctionType function_type, - std::vector<std::unique_ptr<ScoreExpression>> children); + static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create( + FunctionType function_type, + std::vector<std::unique_ptr<ScoreExpression>> children); libtextclassifier3::StatusOr<double> eval( const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override; diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc new file mode 100644 index 0000000..b49b658 --- /dev/null +++ b/icing/scoring/advanced_scoring/score-expression_test.cc @@ -0,0 +1,186 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/advanced_scoring/score-expression.h" + +#include <cmath> +#include <memory> +#include <string> +#include <string_view> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; + +class NonConstantScoreExpression : public ScoreExpression { + public: + static std::unique_ptr<NonConstantScoreExpression> Create() { + return std::make_unique<NonConstantScoreExpression>(); + } + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo&, const DocHitInfoIterator*) override { + return 0; + } + + bool is_constant_double() const override { return false; } +}; + +template <typename... Args> +std::vector<std::unique_ptr<ScoreExpression>> MakeChildren(Args... args) { + std::vector<std::unique_ptr<ScoreExpression>> children; + (children.push_back(std::move(args)), ...); + return children; +} + +TEST(ScoreExpressionTest, OperatorSimplification) { + // 1 + 1 = 2 + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoreExpression> expression, + OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kPlus, + MakeChildren(ConstantScoreExpression::Create(1), + ConstantScoreExpression::Create(1)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); + + // 1 - 2 - 3 = -4 + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kMinus, + MakeChildren(ConstantScoreExpression::Create(1), + ConstantScoreExpression::Create(2), + ConstantScoreExpression::Create(3)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4))); + + // 1 * 2 * 3 * 4 = 24 + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kTimes, + MakeChildren(ConstantScoreExpression::Create(1), + ConstantScoreExpression::Create(2), + ConstantScoreExpression::Create(3), + ConstantScoreExpression::Create(4)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24))); + + // 1 / 2 / 4 = 0.125 + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kDiv, + MakeChildren(ConstantScoreExpression::Create(1), + ConstantScoreExpression::Create(2), + ConstantScoreExpression::Create(4)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125))); + + // -(2) = -2 + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kNegative, + MakeChildren(ConstantScoreExpression::Create(2)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2))); +} + +TEST(ScoreExpressionTest, MathFunctionSimplification) { + // pow(2, 2) = 4 + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoreExpression> expression, + MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kPow, + MakeChildren(ConstantScoreExpression::Create(2), + ConstantScoreExpression::Create(2)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4))); + + // abs(-2) = 2 + ICING_ASSERT_OK_AND_ASSIGN( + expression, MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kAbs, + MakeChildren(ConstantScoreExpression::Create(-2)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2))); + + // log(e) = 1 + ICING_ASSERT_OK_AND_ASSIGN( + expression, MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kLog, + MakeChildren(ConstantScoreExpression::Create(M_E)))); + ASSERT_TRUE(expression->is_constant_double()); + EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1))); +} + +TEST(ScoreExpressionTest, CannotSimplifyNonConstant) { + // 1 + non_constant = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoreExpression> expression, + OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kPlus, + MakeChildren(ConstantScoreExpression::Create(1), + NonConstantScoreExpression::Create()))); + ASSERT_FALSE(expression->is_constant_double()); + + // non_constant * non_constant = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kTimes, + MakeChildren(NonConstantScoreExpression::Create(), + NonConstantScoreExpression::Create()))); + ASSERT_FALSE(expression->is_constant_double()); + + // -(non_constant) = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + expression, OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kNegative, + MakeChildren(NonConstantScoreExpression::Create()))); + ASSERT_FALSE(expression->is_constant_double()); + + // pow(non_constant, 2) = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + expression, MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kPow, + MakeChildren(NonConstantScoreExpression::Create(), + ConstantScoreExpression::Create(2)))); + ASSERT_FALSE(expression->is_constant_double()); + + // abs(non_constant) = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + expression, MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kAbs, + MakeChildren(NonConstantScoreExpression::Create()))); + ASSERT_FALSE(expression->is_constant_double()); + + // log(non_constant) = non_constant + ICING_ASSERT_OK_AND_ASSIGN( + expression, MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::FunctionType::kLog, + MakeChildren(NonConstantScoreExpression::Create()))); + ASSERT_FALSE(expression->is_constant_double()); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc index 1396dcc..ea2e190 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.cc +++ b/icing/scoring/advanced_scoring/scoring-visitor.cc @@ -132,8 +132,8 @@ void ScoringVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) { std::vector<std::unique_ptr<ScoreExpression>> children; children.push_back(pop_stack()); - libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> - expression = OperatorScoreExpression::Create( + libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression = + OperatorScoreExpression::Create( OperatorScoreExpression::OperatorType::kNegative, std::move(children)); if (!expression.ok()) { @@ -153,8 +153,8 @@ void ScoringVisitor::VisitNaryOperator(const NaryOperatorNode* node) { children.push_back(pop_stack()); } - libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> - expression = absl_ports::InvalidArgumentError( + libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression = + absl_ports::InvalidArgumentError( absl_ports::StrCat("Unknown Nary operator: ", node->operator_text())); if (node->operator_text() == "PLUS") { diff --git a/icing/scoring/scored-document-hit.h b/icing/scoring/scored-document-hit.h index 141049e..5fc2f3a 100644 --- a/icing/scoring/scored-document-hit.h +++ b/icing/scoring/scored-document-hit.h @@ -123,8 +123,8 @@ class JoinedScoredDocumentHit { }; explicit JoinedScoredDocumentHit( - double final_score, ScoredDocumentHit&& parent_scored_document_hit, - std::vector<ScoredDocumentHit>&& child_scored_document_hits) + double final_score, ScoredDocumentHit parent_scored_document_hit, + std::vector<ScoredDocumentHit> child_scored_document_hits) : final_score_(final_score), parent_scored_document_hit_(std::move(parent_scored_document_hit)), child_scored_document_hits_(std::move(child_scored_document_hits)) {} diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc index 600fe6b..f75b564 100644 --- a/icing/scoring/scorer-factory.cc +++ b/icing/scoring/scorer-factory.cc @@ -213,8 +213,9 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( return AdvancedScorer::Create(scoring_spec, default_score, document_store, schema_store); case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: - ICING_LOG(WARNING) - << "JOIN_AGGREGATE_SCORE not implemented, falling back to NoScorer"; + // Use join aggregate score to rank. Since the aggregation score is + // calculated by child documents after joining (in JoinProcessor), we can + // simply use NoScorer for parent documents. [[fallthrough]]; case ScoringSpecProto::RankingStrategy::NONE: return std::make_unique<NoScorer>(default_score); diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc index 3add705..9e79790 100644 --- a/icing/store/document-store.cc +++ b/icing/store/document-store.cc @@ -222,7 +222,7 @@ libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( DocumentProto&& document, int32_t num_tokens, PutDocumentStatsProto* put_document_stats) { document.mutable_internal_fields()->set_length_in_tokens(num_tokens); - return InternalPut(document, put_document_stats); + return InternalPut(std::move(document), put_document_stats); } DocumentStore::~DocumentStore() { @@ -840,7 +840,7 @@ libtextclassifier3::Status DocumentStore::UpdateHeader(const Crc32& checksum) { } libtextclassifier3::StatusOr<DocumentId> DocumentStore::InternalPut( - DocumentProto& document, PutDocumentStatsProto* put_document_stats) { + DocumentProto&& document, PutDocumentStatsProto* put_document_stats) { std::unique_ptr<Timer> put_timer = clock_.GetNewTimer(); ICING_RETURN_IF_ERROR(document_validator_.Validate(document)); @@ -1714,7 +1714,7 @@ DocumentStore::OptimizeInto(const std::string& new_directory, } // Guaranteed to have a document now. - DocumentProto document_to_keep = document_or.ValueOrDie(); + DocumentProto document_to_keep = std::move(document_or).ValueOrDie(); libtextclassifier3::StatusOr<DocumentId> new_document_id_or; if (document_to_keep.internal_fields().length_in_tokens() == 0) { @@ -1729,11 +1729,12 @@ DocumentStore::OptimizeInto(const std::string& new_directory, TokenizedDocument tokenized_document( std::move(tokenized_document_or).ValueOrDie()); new_document_id_or = new_doc_store->Put( - document_to_keep, tokenized_document.num_string_tokens()); + std::move(document_to_keep), tokenized_document.num_string_tokens()); } else { // TODO(b/144458732): Implement a more robust version of // TC_ASSIGN_OR_RETURN that can support error logging. - new_document_id_or = new_doc_store->InternalPut(document_to_keep); + new_document_id_or = + new_doc_store->InternalPut(std::move(document_to_keep)); } if (!new_document_id_or.ok()) { ICING_LOG(ERROR) << new_document_id_or.status().error_message() diff --git a/icing/store/document-store.h b/icing/store/document-store.h index 58977bf..bda351d 100644 --- a/icing/store/document-store.h +++ b/icing/store/document-store.h @@ -628,7 +628,7 @@ class DocumentStore { libtextclassifier3::Status UpdateHeader(const Crc32& checksum); libtextclassifier3::StatusOr<DocumentId> InternalPut( - DocumentProto& document, + DocumentProto&& document, PutDocumentStatsProto* put_document_stats = nullptr); // Helper function to do batch deletes. Documents with the given diff --git a/icing/testing/common-matchers.cc b/icing/testing/common-matchers.cc new file mode 100644 index 0000000..cd4e446 --- /dev/null +++ b/icing/testing/common-matchers.cc @@ -0,0 +1,124 @@ +// Copyright (C) 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +ExtractTermFrequenciesResult ExtractTermFrequencies( + const std::unordered_map<SectionId, Hit::TermFrequency>& + section_ids_tf_map) { + ExtractTermFrequenciesResult result; + for (const auto& [section_id, tf] : section_ids_tf_map) { + result.term_frequencies[section_id] = tf; + result.section_mask |= UINT64_C(1) << section_id; + } + return result; +} + +CheckTermFrequencyResult CheckTermFrequency( + const std::array<Hit::TermFrequency, kTotalNumSections>& + expected_term_frequencies, + const std::array<Hit::TermFrequency, kTotalNumSections>& + actual_term_frequencies) { + CheckTermFrequencyResult result; + for (SectionId section_id = 0; section_id < kTotalNumSections; ++section_id) { + if (expected_term_frequencies.at(section_id) != + actual_term_frequencies.at(section_id)) { + result.term_frequencies_match = false; + } + } + result.actual_term_frequencies_str = + absl_ports::StrCat("[", + absl_ports::StrJoin(actual_term_frequencies, ",", + absl_ports::NumberFormatter()), + "]"); + result.expected_term_frequencies_str = + absl_ports::StrCat("[", + absl_ports::StrJoin(expected_term_frequencies, ",", + absl_ports::NumberFormatter()), + "]"); + return result; +} + +std::string StatusCodeToString(libtextclassifier3::StatusCode code) { + switch (code) { + case libtextclassifier3::StatusCode::OK: + return "OK"; + case libtextclassifier3::StatusCode::CANCELLED: + return "CANCELLED"; + case libtextclassifier3::StatusCode::UNKNOWN: + return "UNKNOWN"; + case libtextclassifier3::StatusCode::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case libtextclassifier3::StatusCode::DEADLINE_EXCEEDED: + return "DEADLINE_EXCEEDED"; + case libtextclassifier3::StatusCode::NOT_FOUND: + return "NOT_FOUND"; + case libtextclassifier3::StatusCode::ALREADY_EXISTS: + return "ALREADY_EXISTS"; + case libtextclassifier3::StatusCode::PERMISSION_DENIED: + return "PERMISSION_DENIED"; + case libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED: + return "RESOURCE_EXHAUSTED"; + case libtextclassifier3::StatusCode::FAILED_PRECONDITION: + return "FAILED_PRECONDITION"; + case libtextclassifier3::StatusCode::ABORTED: + return "ABORTED"; + case libtextclassifier3::StatusCode::OUT_OF_RANGE: + return "OUT_OF_RANGE"; + case libtextclassifier3::StatusCode::UNIMPLEMENTED: + return "UNIMPLEMENTED"; + case libtextclassifier3::StatusCode::INTERNAL: + return "INTERNAL"; + case libtextclassifier3::StatusCode::UNAVAILABLE: + return "UNAVAILABLE"; + case libtextclassifier3::StatusCode::DATA_LOSS: + return "DATA_LOSS"; + case libtextclassifier3::StatusCode::UNAUTHENTICATED: + return "UNAUTHENTICATED"; + default: + return ""; + } +} + +std::string ProtoStatusCodeToString(StatusProto::Code code) { + switch (code) { + case StatusProto::OK: + return "OK"; + case StatusProto::UNKNOWN: + return "UNKNOWN"; + case StatusProto::INVALID_ARGUMENT: + return "INVALID_ARGUMENT"; + case StatusProto::NOT_FOUND: + return "NOT_FOUND"; + case StatusProto::ALREADY_EXISTS: + return "ALREADY_EXISTS"; + case StatusProto::OUT_OF_SPACE: + return "OUT_OF_SPACE"; + case StatusProto::FAILED_PRECONDITION: + return "FAILED_PRECONDITION"; + case StatusProto::ABORTED: + return "ABORTED"; + case StatusProto::INTERNAL: + return "INTERNAL"; + case StatusProto::WARNING_DATA_LOSS: + return "WARNING_DATA_LOSS"; + default: + return ""; + } +} + +} // namespace lib +} // namespace icing diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index e090800..db7b7ef 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -22,11 +22,11 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/status_macros.h" -#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_join.h" #include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/portable/equals-proto.h" @@ -35,7 +35,6 @@ #include "icing/schema/schema-store.h" #include "icing/schema/section.h" #include "icing/scoring/scored-document-hit.h" -#include "icing/util/status-macros.h" namespace icing { namespace lib { @@ -69,43 +68,85 @@ MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") { actual.hit_section_ids_mask() == section_mask; } +struct ExtractTermFrequenciesResult { + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies = {0}; + SectionIdMask section_mask = kSectionIdMaskNone; +}; +// Extracts the term frequencies represented by the section_ids_tf_map. +// Returns: +// - a SectionIdMask representing all sections that appears as entries in the +// map, even if they have an entry with term_frequency==0 +// - an array representing the term frequencies for each section. Sections not +// present in section_ids_tf_map have a term frequency of 0. +ExtractTermFrequenciesResult ExtractTermFrequencies( + const std::unordered_map<SectionId, Hit::TermFrequency>& + section_ids_tf_map); + +struct CheckTermFrequencyResult { + std::string expected_term_frequencies_str; + std::string actual_term_frequencies_str; + bool term_frequencies_match = true; +}; +// Checks that the term frequencies in actual_term_frequencies match those +// specified in expected_section_ids_tf_map. If there is no entry in +// expected_section_ids_tf_map, then it is assumed that the term frequency for +// that section is 0. +// Returns: +// - a bool indicating if the term frequencies match +// - debug strings representing the contents of the actual and expected term +// term frequency arrays. +CheckTermFrequencyResult CheckTermFrequency( + const std::array<Hit::TermFrequency, kTotalNumSections>& + expected_term_frequencies, + const std::array<Hit::TermFrequency, kTotalNumSections>& + actual_term_frequencies); + // Used to match a DocHitInfo MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, section_ids_to_term_frequencies_map, "") { const DocHitInfoTermFrequencyPair& actual = arg; - SectionIdMask section_mask = kSectionIdMaskNone; - - bool term_frequency_as_expected = true; - std::vector<Hit::TermFrequency> expected_tfs; - std::vector<Hit::TermFrequency> actual_tfs; - for (auto itr = section_ids_to_term_frequencies_map.begin(); - itr != section_ids_to_term_frequencies_map.end(); itr++) { - SectionId section_id = itr->first; - section_mask |= UINT64_C(1) << section_id; - expected_tfs.push_back(itr->second); - actual_tfs.push_back(actual.hit_term_frequency(section_id)); - if (actual.hit_term_frequency(section_id) != itr->second) { - term_frequency_as_expected = false; - } + std::array<Hit::TermFrequency, kTotalNumSections> actual_tf_array; + for (SectionId section_id = 0; section_id < kTotalNumSections; ++section_id) { + actual_tf_array[section_id] = actual.hit_term_frequency(section_id); } - std::string actual_term_frequencies = absl_ports::StrCat( - "[", absl_ports::StrJoin(actual_tfs, ",", absl_ports::NumberFormatter()), - "]"); - std::string expected_term_frequencies = absl_ports::StrCat( - "[", - absl_ports::StrJoin(expected_tfs, ",", absl_ports::NumberFormatter()), - "]"); + ExtractTermFrequenciesResult expected = + ExtractTermFrequencies(section_ids_to_term_frequencies_map); + CheckTermFrequencyResult check_tf_result = + CheckTermFrequency(expected.term_frequencies, actual_tf_array); + *result_listener << IcingStringUtil::StringPrintf( "(actual is {document_id=%d, section_mask=%" PRIu64 ", term_frequencies=%s}, but expected was " "{document_id=%d, section_mask=%" PRIu64 ", term_frequencies=%s}.)", actual.doc_hit_info().document_id(), actual.doc_hit_info().hit_section_ids_mask(), - actual_term_frequencies.c_str(), document_id, section_mask, - expected_term_frequencies.c_str()); + check_tf_result.actual_term_frequencies_str.c_str(), document_id, + expected.section_mask, + check_tf_result.expected_term_frequencies_str.c_str()); return actual.doc_hit_info().document_id() == document_id && - actual.doc_hit_info().hit_section_ids_mask() == section_mask && - term_frequency_as_expected; + actual.doc_hit_info().hit_section_ids_mask() == + expected.section_mask && + check_tf_result.term_frequencies_match; +} + +MATCHER_P2(EqualsTermMatchInfo, term, section_ids_to_term_frequencies_map, "") { + const TermMatchInfo& actual = arg; + std::string term_str(term); + ExtractTermFrequenciesResult expected = + ExtractTermFrequencies(section_ids_to_term_frequencies_map); + CheckTermFrequencyResult check_tf_result = + CheckTermFrequency(expected.term_frequencies, actual.term_frequencies); + *result_listener << IcingStringUtil::StringPrintf( + "(actual is {term=%s, section_mask=%" PRIu64 + ", term_frequencies=%s}, but expected was " + "{term=%s, section_mask=%" PRIu64 ", term_frequencies=%s}.)", + actual.term.data(), actual.section_ids_mask, + check_tf_result.actual_term_frequencies_str.c_str(), term_str.data(), + expected.section_mask, + check_tf_result.expected_term_frequencies_str.c_str()); + return actual.term == term && + actual.section_ids_mask == expected.section_mask && + check_tf_result.term_frequencies_match; } class ScoredDocumentHitFormatter { @@ -337,73 +378,9 @@ MATCHER_P(EqualsSetSchemaResult, expected, "") { return false; } -inline std::string StatusCodeToString(libtextclassifier3::StatusCode code) { - switch (code) { - case libtextclassifier3::StatusCode::OK: - return "OK"; - case libtextclassifier3::StatusCode::CANCELLED: - return "CANCELLED"; - case libtextclassifier3::StatusCode::UNKNOWN: - return "UNKNOWN"; - case libtextclassifier3::StatusCode::INVALID_ARGUMENT: - return "INVALID_ARGUMENT"; - case libtextclassifier3::StatusCode::DEADLINE_EXCEEDED: - return "DEADLINE_EXCEEDED"; - case libtextclassifier3::StatusCode::NOT_FOUND: - return "NOT_FOUND"; - case libtextclassifier3::StatusCode::ALREADY_EXISTS: - return "ALREADY_EXISTS"; - case libtextclassifier3::StatusCode::PERMISSION_DENIED: - return "PERMISSION_DENIED"; - case libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED: - return "RESOURCE_EXHAUSTED"; - case libtextclassifier3::StatusCode::FAILED_PRECONDITION: - return "FAILED_PRECONDITION"; - case libtextclassifier3::StatusCode::ABORTED: - return "ABORTED"; - case libtextclassifier3::StatusCode::OUT_OF_RANGE: - return "OUT_OF_RANGE"; - case libtextclassifier3::StatusCode::UNIMPLEMENTED: - return "UNIMPLEMENTED"; - case libtextclassifier3::StatusCode::INTERNAL: - return "INTERNAL"; - case libtextclassifier3::StatusCode::UNAVAILABLE: - return "UNAVAILABLE"; - case libtextclassifier3::StatusCode::DATA_LOSS: - return "DATA_LOSS"; - case libtextclassifier3::StatusCode::UNAUTHENTICATED: - return "UNAUTHENTICATED"; - default: - return ""; - } -} +std::string StatusCodeToString(libtextclassifier3::StatusCode code); -inline std::string ProtoStatusCodeToString(StatusProto::Code code) { - switch (code) { - case StatusProto::OK: - return "OK"; - case StatusProto::UNKNOWN: - return "UNKNOWN"; - case StatusProto::INVALID_ARGUMENT: - return "INVALID_ARGUMENT"; - case StatusProto::NOT_FOUND: - return "NOT_FOUND"; - case StatusProto::ALREADY_EXISTS: - return "ALREADY_EXISTS"; - case StatusProto::OUT_OF_SPACE: - return "OUT_OF_SPACE"; - case StatusProto::FAILED_PRECONDITION: - return "FAILED_PRECONDITION"; - case StatusProto::ABORTED: - return "ABORTED"; - case StatusProto::INTERNAL: - return "INTERNAL"; - case StatusProto::WARNING_DATA_LOSS: - return "WARNING_DATA_LOSS"; - default: - return ""; - } -} +std::string ProtoStatusCodeToString(StatusProto::Code code); MATCHER(IsOk, "") { libtextclassifier3::StatusAdapter adapter(arg); diff --git a/icing/util/logging.h b/icing/util/logging.h index 7742302..23280dc 100644 --- a/icing/util/logging.h +++ b/icing/util/logging.h @@ -131,13 +131,29 @@ class LogMessage { inline constexpr char kIcingLoggingTag[] = "AppSearchIcing"; -#define ICING_VLOG(verbose_level) \ - ::icing::lib::LogMessage(::icing::lib::LogSeverity::VERBOSE, verbose_level, \ - __FILE__, __LINE__) \ - .stream() -#define ICING_LOG(severity) \ - ::icing::lib::LogMessage(::icing::lib::LogSeverity::severity, \ - /*verbosity=*/0, __FILE__, __LINE__) \ +// Define consts to make it easier to refer to log severities in code. +constexpr ::icing::lib::LogSeverity::Code VERBOSE = + ::icing::lib::LogSeverity::VERBOSE; + +constexpr ::icing::lib::LogSeverity::Code DBG = ::icing::lib::LogSeverity::DBG; + +constexpr ::icing::lib::LogSeverity::Code INFO = + ::icing::lib::LogSeverity::INFO; + +constexpr ::icing::lib::LogSeverity::Code WARNING = + ::icing::lib::LogSeverity::WARNING; + +constexpr ::icing::lib::LogSeverity::Code ERROR = + ::icing::lib::LogSeverity::ERROR; + +constexpr ::icing::lib::LogSeverity::Code FATAL = + ::icing::lib::LogSeverity::FATAL; + +#define ICING_VLOG(verbose_level) \ + ::icing::lib::LogMessage(VERBOSE, verbose_level, __FILE__, __LINE__).stream() + +#define ICING_LOG(severity) \ + ::icing::lib::LogMessage(severity, /*verbosity=*/0, __FILE__, __LINE__) \ .stream() } // namespace lib diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index b2435d0..c9e2b1d 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -505,13 +505,16 @@ message JoinSpecProto { // taken on it. If JOIN_AGGREGATE_SCORE is used in the base SearchSpecProto, // the COUNT value will rank entity documents based on the number of child // documents. - enum AggregationScore { - UNDEFINED = 0; - COUNT = 1; - MIN = 2; - AVG = 3; - MAX = 4; - SUM = 5; + message AggregationScoringStrategy { + enum Code { + NONE = 0; // No aggregation strategy for child documents and use parent + // document score. + COUNT = 1; + MIN = 2; + AVG = 3; + MAX = 4; + SUM = 5; + } } - optional AggregationScore aggregation_score_strategy = 5 [default = COUNT]; + optional AggregationScoringStrategy.Code aggregation_scoring_strategy = 5; } diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 00194ce..a4f3a30 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=496703735) +set(synced_AOSP_CL_number=500254546) |