aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlex Saveliev <alexsav@google.com>2023-01-10 19:23:34 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2023-01-10 19:23:34 +0000
commit6e15dd0d337c65739900d7b95f6408d7413c8196 (patch)
tree0ba91ee8775e34738340187614b97b6c5ffcbc8c
parent48b8f6943906165ec50ebecb9551497ac6faa450 (diff)
parent947f3d55bb1871285790facda2aa76e02c27a289 (diff)
downloadicing-6e15dd0d337c65739900d7b95f6408d7413c8196.tar.gz
Merge remote-tracking branch 'aosp/upstream-master' into androidx-main am: 947f3d55bb
Original change: https://android-review.googlesource.com/c/platform/external/icing/+/2381052 Change-Id: I1e1a7398fee34eb07cc5e2ed7dd521ac60dae9bb Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--icing/file/posting_list/posting-list-identifier.h3
-rw-r--r--icing/icing-search-engine.cc10
-rw-r--r--icing/icing-search-engine_test.cc155
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-and_test.cc80
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-or_test.cc145
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc12
-rw-r--r--icing/index/lite/lite-index_test.cc57
-rw-r--r--icing/index/numeric/integer-index-storage.h186
-rw-r--r--icing/jni/icing-search-engine-jni.cc47
-rw-r--r--icing/join/aggregation-scorer.cc (renamed from icing/join/aggregate-scorer.cc)86
-rw-r--r--icing/join/aggregation-scorer.h (renamed from icing/join/aggregate-scorer.h)12
-rw-r--r--icing/join/aggregation-scorer_test.cc215
-rw-r--r--icing/join/join-processor.cc23
-rw-r--r--icing/join/join-processor_test.cc624
-rw-r--r--icing/query/advanced_query_parser/query-visitor.cc18
-rw-r--r--icing/query/advanced_query_parser/query-visitor.h31
-rw-r--r--icing/query/advanced_query_parser/query-visitor_test.cc296
-rw-r--r--icing/query/query-features.h8
-rw-r--r--icing/query/query-processor.cc5
-rw-r--r--icing/query/query-processor_test.cc292
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer.h9
-rw-r--r--icing/scoring/advanced_scoring/advanced-scorer_test.cc17
-rw-r--r--icing/scoring/advanced_scoring/score-expression.cc36
-rw-r--r--icing/scoring/advanced_scoring/score-expression.h34
-rw-r--r--icing/scoring/advanced_scoring/score-expression_test.cc186
-rw-r--r--icing/scoring/advanced_scoring/scoring-visitor.cc8
-rw-r--r--icing/scoring/scored-document-hit.h4
-rw-r--r--icing/scoring/scorer-factory.cc5
-rw-r--r--icing/store/document-store.cc11
-rw-r--r--icing/store/document-store.h2
-rw-r--r--icing/testing/common-matchers.cc124
-rw-r--r--icing/testing/common-matchers.h163
-rw-r--r--icing/util/logging.h30
-rw-r--r--proto/icing/proto/search.proto19
-rw-r--r--synced_AOSP_CL_number.txt2
35 files changed, 2097 insertions, 858 deletions
diff --git a/icing/file/posting_list/posting-list-identifier.h b/icing/file/posting_list/posting-list-identifier.h
index 05c7ce5..54b2888 100644
--- a/icing/file/posting_list/posting-list-identifier.h
+++ b/icing/file/posting_list/posting-list-identifier.h
@@ -109,7 +109,8 @@ class PostingListIdentifier {
private:
uint32_t val_;
-};
+} __attribute__((packed));
+static_assert(sizeof(PostingListIdentifier) == 4, "");
} // namespace lib
} // namespace icing
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index 33e2ca1..bf9c102 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -1191,9 +1191,13 @@ DeleteResultProto IcingSearchEngine::Delete(const std::string_view name_space,
// that can support error logging.
libtextclassifier3::Status status = document_store_->Delete(name_space, uri);
if (!status.ok()) {
- ICING_LOG(ERROR) << status.error_message()
- << "Failed to delete Document. namespace: " << name_space
- << ", uri: " << uri;
+ LogSeverity::Code severity = ERROR;
+ if (absl_ports::IsNotFound(status)) {
+ severity = DBG;
+ }
+ ICING_LOG(severity) << status.error_message()
+ << "Failed to delete Document. namespace: "
+ << name_space << ", uri: " << uri;
TransformStatus(status, result_status);
return result_proto;
}
diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc
index e7158ad..dff38fb 100644
--- a/icing/icing-search-engine_test.cc
+++ b/icing/icing-search-engine_test.cc
@@ -10581,6 +10581,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("lastName", "last1")
.AddStringProperty("emailAddress", "email1@gmail.com")
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(1)
.Build();
DocumentProto person2 =
DocumentBuilder()
@@ -10590,6 +10591,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("lastName", "last2")
.AddStringProperty("emailAddress", "email2@gmail.com")
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(2)
.Build();
DocumentProto person3 =
DocumentBuilder()
@@ -10599,6 +10601,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("lastName", "last3")
.AddStringProperty("emailAddress", "email3@gmail.com")
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(3)
.Build();
DocumentProto email1 =
@@ -10608,6 +10611,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("subject", "test subject 1")
.AddStringProperty("personQualifiedId", "pkg$db/namespace#person1")
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(3)
.Build();
DocumentProto email2 =
DocumentBuilder()
@@ -10616,6 +10620,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("subject", "test subject 2")
.AddStringProperty("personQualifiedId", "pkg$db/namespace#person2")
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(2)
.Build();
DocumentProto email3 =
DocumentBuilder()
@@ -10625,6 +10630,7 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
.AddStringProperty("personQualifiedId",
R"(pkg$db/name\#space\\\\#person3)") // escaped
.SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .SetScore(1)
.Build();
IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
@@ -10644,12 +10650,12 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
// JoinSpec
JoinSpecProto* join_spec = search_spec.mutable_join_spec();
- // Set max_joined_child_count as 2, so only email 3, email2 will be included
- // in the nested result and email1 will be truncated.
- join_spec->set_max_joined_child_count(2);
+ join_spec->set_max_joined_child_count(100);
join_spec->set_parent_property_expression(
std::string(JoinProcessor::kQualifiedIdExpr));
join_spec->set_child_property_expression("personQualifiedId");
+ join_spec->set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::MAX);
JoinSpecProto::NestedSpecProto* nested_spec =
join_spec->mutable_nested_spec();
SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec();
@@ -10665,12 +10671,20 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
ResultSpecProto result_spec;
result_spec.set_num_per_page(1);
+ // Since we:
+ // - Use MAX for aggregation scoring strategy.
+ // - (Default) use DOCUMENT_SCORE to score child documents.
+ // - (Default) use DESC as the ranking order.
+ //
+ // person1 + email1 should have the highest aggregated score (3) and be
+ // returned first. person2 + email2 (aggregated score = 2) should be the
+ // second, and person3 + email3 (aggregated score = 1) should be the last.
SearchResultProto expected_result1;
expected_result1.mutable_status()->set_code(StatusProto::OK);
SearchResultProto::ResultProto* result_proto1 =
expected_result1.mutable_results()->Add();
- *result_proto1->mutable_document() = person3;
- *result_proto1->mutable_joined_results()->Add()->mutable_document() = email3;
+ *result_proto1->mutable_document() = person1;
+ *result_proto1->mutable_joined_results()->Add()->mutable_document() = email1;
SearchResultProto expected_result2;
expected_result2.mutable_status()->set_code(StatusProto::OK);
@@ -10683,8 +10697,8 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
expected_result3.mutable_status()->set_code(StatusProto::OK);
SearchResultProto::ResultProto* result_proto3 =
expected_result3.mutable_results()->Add();
- *result_proto3->mutable_document() = person1;
- *result_proto3->mutable_joined_results()->Add()->mutable_document() = email1;
+ *result_proto3->mutable_document() = person3;
+ *result_proto3->mutable_joined_results()->Add()->mutable_document() = email3;
SearchResultProto result1 =
icing.Search(search_spec, scoring_spec, result_spec);
@@ -10708,133 +10722,6 @@ TEST_F(IcingSearchEngineTest, JoinByQualifiedId) {
EqualsSearchResultIgnoreStatsAndScores(expected_result3));
}
-TEST_F(IcingSearchEngineTest, InvalidJoins) {
- SchemaProto schema =
- SchemaBuilder()
- .AddType(SchemaTypeConfigBuilder()
- .SetType("Person")
- .AddProperty(PropertyConfigBuilder()
- .SetName("firstName")
- .SetDataTypeString(TERM_MATCH_PREFIX,
- TOKENIZER_PLAIN)
- .SetCardinality(CARDINALITY_OPTIONAL))
- .AddProperty(PropertyConfigBuilder()
- .SetName("lastName")
- .SetDataTypeString(TERM_MATCH_PREFIX,
- TOKENIZER_PLAIN)
- .SetCardinality(CARDINALITY_OPTIONAL))
- .AddProperty(PropertyConfigBuilder()
- .SetName("emailAddress")
- .SetDataTypeString(TERM_MATCH_PREFIX,
- TOKENIZER_PLAIN)
- .SetCardinality(CARDINALITY_OPTIONAL)))
- .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty(
- PropertyConfigBuilder()
- .SetName("subjectId")
- .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
- .SetCardinality(CARDINALITY_OPTIONAL)))
- .Build();
-
- DocumentProto person1 =
- DocumentBuilder()
- .SetKey("pkg$db/namespace", "person1")
- .SetSchema("Person")
- .AddStringProperty("firstName", "first1")
- .AddStringProperty("lastName", "last1")
- .AddStringProperty("emailAddress", "email1@gmail.com")
- .SetCreationTimestampMs(kDefaultCreationTimestampMs)
- .Build();
- DocumentProto person2 =
- DocumentBuilder()
- .SetKey("pkg$db/namespace\\", "person2")
- .SetSchema("Person")
- .AddStringProperty("firstName", "first2")
- .AddStringProperty("lastName", "last2")
- .AddStringProperty("emailAddress", "email2@gmail.com")
- .SetCreationTimestampMs(kDefaultCreationTimestampMs)
- .Build();
-
- // "invalid format" does not refer to any document, so it will not be joined
- // to any document.
- DocumentProto email1 =
- DocumentBuilder()
- .SetKey("namespace", "email1")
- .SetSchema("Email")
- .AddStringProperty("subjectId", "invalid format")
- .SetCreationTimestampMs(kDefaultCreationTimestampMs)
- .Build();
- // This will not be joined because the # in the subjectId is escaped.
- DocumentProto email2 =
- DocumentBuilder()
- .SetKey("namespace", "email2")
- .SetSchema("Email")
- .AddStringProperty("subjectId", "pkg$db/namespace\\#person2")
- .SetCreationTimestampMs(kDefaultCreationTimestampMs)
- .Build();
-
- IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
- ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
- ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
- ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk());
- ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk());
- ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk());
- ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk());
-
- // Parent SearchSpec
- SearchSpecProto search_spec;
- search_spec.set_term_match_type(TermMatchType::PREFIX);
- search_spec.set_query("first");
-
- // JoinSpec
- JoinSpecProto* join_spec = search_spec.mutable_join_spec();
- // Set max_joined_child_count as 2, so only email 3, email2 will be included
- // in the nested result and email1 will be truncated.
- join_spec->set_max_joined_child_count(2);
- join_spec->set_parent_property_expression(
- std::string(JoinProcessor::kQualifiedIdExpr));
- join_spec->set_child_property_expression("subjectId");
- JoinSpecProto::NestedSpecProto* nested_spec =
- join_spec->mutable_nested_spec();
- SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec();
- nested_search_spec->set_term_match_type(TermMatchType::PREFIX);
- nested_search_spec->set_query("");
- *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec();
- *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance();
-
- // Parent ScoringSpec
- ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
-
- // Parent ResultSpec
- ResultSpecProto result_spec;
- result_spec.set_num_per_page(1);
-
- SearchResultProto expected_result1;
- expected_result1.mutable_status()->set_code(StatusProto::OK);
- SearchResultProto::ResultProto* result_proto1 =
- expected_result1.mutable_results()->Add();
- *result_proto1->mutable_document() = person2;
-
- SearchResultProto expected_result2;
- expected_result2.mutable_status()->set_code(StatusProto::OK);
- SearchResultProto::ResultProto* result_proto2 =
- expected_result2.mutable_results()->Add();
- *result_proto2->mutable_document() = person1;
-
- SearchResultProto result1 =
- icing.Search(search_spec, scoring_spec, result_spec);
- uint64_t next_page_token = result1.next_page_token();
- EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken));
- expected_result1.set_next_page_token(next_page_token);
- EXPECT_THAT(result1,
- EqualsSearchResultIgnoreStatsAndScores(expected_result1));
-
- SearchResultProto result2 = icing.GetNextPage(next_page_token);
- next_page_token = result2.next_page_token();
- EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken));
- EXPECT_THAT(result2,
- EqualsSearchResultIgnoreStatsAndScores(expected_result2));
-}
-
TEST_F(IcingSearchEngineTest, NumericFilterAdvancedQuerySucceeds) {
IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
diff --git a/icing/index/iterator/doc-hit-info-iterator-and_test.cc b/icing/index/iterator/doc-hit-info-iterator-and_test.cc
index e4730fe..9b9f44b 100644
--- a/icing/index/iterator/doc-hit-info-iterator-and_test.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-and_test.cc
@@ -32,10 +32,8 @@ namespace lib {
namespace {
using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::IsEmpty;
-using ::testing::SizeIs;
TEST(CreateAndIteratorTest, And) {
// Basic test that we can create a working And iterator. Further testing of
@@ -202,23 +200,24 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) {
{
// Arbitrary section ids for the documents in the DocHitInfoIterators.
// Created to test correct section_id_mask behavior.
- SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{
- 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
- SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{
- 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4);
doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3);
doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4);
+ SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}};
+
DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4);
doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2);
doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6);
+ SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2 = {{1, 2}, {2, 6}};
+
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1};
std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2};
@@ -240,29 +239,25 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) {
EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4));
and_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_EQ(matched_terms_stats.at(1).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies2));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hi", expected_section_ids_tf_map1),
+ EqualsTermMatchInfo("hello", expected_section_ids_tf_map2)));
EXPECT_FALSE(and_iter.Advance().ok());
}
{
// Arbitrary section ids for the documents in the DocHitInfoIterators.
// Created to test correct section_id_mask behavior.
- SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{
- 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4);
doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
+ SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{0, 1}, {2, 2}};
+
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1};
std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1};
@@ -284,11 +279,8 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) {
EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4));
and_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1);
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "hi", expected_section_ids_tf_map1)));
EXPECT_FALSE(and_iter.Advance().ok());
}
@@ -470,37 +462,34 @@ TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) {
// Arbitrary section ids/term frequencies for the documents in the
// DocHitInfoIterators.
// For term "hi", document 10 and 8
- SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{
- 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10);
doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_hi = {{0, 1}, {2, 2}, {6, 4}};
DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8);
doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2);
doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6);
// For term "hello", document 10 and 9
- SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{
- 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10);
doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2);
doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_hello = {{0, 2}, {3, 3}};
DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9);
doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3);
doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2);
// For term "ciao", document 10 and 9
- SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{
- 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(10);
doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2);
doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_ciao = {{0, 2}, {1, 3}};
DocHitInfoTermFrequencyPair doc_hit_info2_ciao = DocHitInfo(9);
doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3);
@@ -534,19 +523,12 @@ TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) {
EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(10));
and_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(3)); // 3 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1_hi));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi);
- EXPECT_EQ(matched_terms_stats.at(1).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies1_hello));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello);
- EXPECT_EQ(matched_terms_stats.at(2).term, "ciao");
- EXPECT_THAT(matched_terms_stats.at(2).term_frequencies,
- ElementsAreArray(term_frequencies1_ciao));
- EXPECT_EQ(matched_terms_stats.at(2).section_ids_mask, section_id_mask1_ciao);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hi", expected_section_ids_tf_map1_hi),
+ EqualsTermMatchInfo("hello", expected_section_ids_tf_map1_hello),
+ EqualsTermMatchInfo("ciao", expected_section_ids_tf_map1_ciao)));
EXPECT_FALSE(and_iter.Advance().ok());
}
diff --git a/icing/index/iterator/doc-hit-info-iterator-or_test.cc b/icing/index/iterator/doc-hit-info-iterator-or_test.cc
index 6e6872c..f487801 100644
--- a/icing/index/iterator/doc-hit-info-iterator-or_test.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-or_test.cc
@@ -19,7 +19,6 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "icing/index/iterator/doc-hit-info-iterator-and.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/schema/section.h"
@@ -32,10 +31,8 @@ namespace lib {
namespace {
using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::IsEmpty;
-using ::testing::SizeIs;
TEST(CreateAndIteratorTest, Or) {
// Basic test that we can create a working Or iterator. Further testing of
@@ -182,22 +179,21 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) {
{
// Arbitrary section ids for the documents in the DocHitInfoIterators.
// Created to test correct section_id_mask behavior.
- SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{
- 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
- SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{
- 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4);
doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3);
doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4);
+ SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}};
DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4);
doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2);
doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6);
+ SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2 = {{1, 2}, {2, 6}};
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1};
std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2};
@@ -219,28 +215,23 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) {
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4));
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_EQ(matched_terms_stats.at(1).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies2));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hi", expected_section_ids_tf_map1),
+ EqualsTermMatchInfo("hello", expected_section_ids_tf_map2)));
EXPECT_FALSE(or_iter.Advance().ok());
}
{
// Arbitrary section ids for the documents in the DocHitInfoIterators.
// Created to test correct section_id_mask behavior.
- SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{
- 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4);
doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
+ SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{0, 1}, {2, 2}};
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1};
std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1};
@@ -262,33 +253,28 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) {
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4));
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1);
-
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "hi", expected_section_ids_tf_map1)));
EXPECT_FALSE(or_iter.Advance().ok());
}
{
// Arbitrary section ids for the documents in the DocHitInfoIterators.
// Created to test correct section_id_mask behavior.
- SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{
- 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
- SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{
- 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
-
DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4);
doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3);
doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4);
+ SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{0, 1}, {2, 2}, {4, 3}, {6, 4}};
DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(5);
doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2);
doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6);
+ SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2 = {{1, 2}, {2, 6}};
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1};
std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2};
@@ -310,22 +296,17 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) {
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(5));
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies2));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2);
+ EXPECT_THAT(matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello",
+ expected_section_ids_tf_map2)));
ICING_EXPECT_OK(or_iter.Advance());
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4));
matched_terms_stats.clear();
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1);
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "hi", expected_section_ids_tf_map1)));
EXPECT_FALSE(or_iter.Advance().ok());
}
@@ -476,50 +457,44 @@ TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) {
// Arbitrary section ids/term frequencies for the documents in the
// DocHitInfoIterators.
// For term "hi", document 10 and 8
- SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{
- 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10);
doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);
doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2);
doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_hi = {{0, 1}, {2, 2}, {6, 4}};
- SectionIdMask section_id_mask2_hi = 0b00000110; // hits in sections 1, 2
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hi{
- 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8);
doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2);
doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2_hi = {{1, 2}, {2, 6}};
// For term "hello", document 10 and 9
- SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{
- 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10);
doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2);
doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_hello = {{0, 2}, {3, 3}};
- SectionIdMask section_id_mask2_hello = 0b00001100; // hits in sections 2, 3
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hello{
- 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9);
doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3);
doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2_hello = {{2, 3}, {3, 2}};
// For term "ciao", document 9 and 8
- SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{
- 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(9);
doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2);
doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1_ciao = {{0, 2}, {1, 3}};
- SectionIdMask section_id_mask2_ciao = 0b00011000; // hits in sections 3, 4
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_ciao{
- 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
DocHitInfoTermFrequencyPair doc_hit_info2_ciao = DocHitInfo(8);
doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3);
doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map2_ciao = {{3, 3}, {4, 2}};
std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1_hi,
doc_hit_info2_hi};
@@ -549,45 +524,33 @@ TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) {
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(10));
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies1_hi));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi);
- EXPECT_EQ(matched_terms_stats.at(1).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies1_hello));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hi", expected_section_ids_tf_map1_hi),
+ EqualsTermMatchInfo("hello", expected_section_ids_tf_map1_hello)));
ICING_EXPECT_OK(or_iter.Advance());
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(9));
matched_terms_stats.clear();
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies2_hello));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hello);
- EXPECT_EQ(matched_terms_stats.at(1).term, "ciao");
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies1_ciao));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_ciao);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hello", expected_section_ids_tf_map2_hello),
+ EqualsTermMatchInfo("ciao", expected_section_ids_tf_map1_ciao)));
ICING_EXPECT_OK(or_iter.Advance());
EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(8));
matched_terms_stats.clear();
or_iter.PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies2_hi));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hi);
- EXPECT_EQ(matched_terms_stats.at(1).term, "ciao");
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies2_ciao));
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2_ciao);
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("hi", expected_section_ids_tf_map2_hi),
+ EqualsTermMatchInfo("ciao", expected_section_ids_tf_map2_ciao)));
EXPECT_FALSE(or_iter.Advance().ok());
}
diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc
index e80d8f0..6d41e90 100644
--- a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc
+++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc
@@ -44,7 +44,6 @@ namespace lib {
namespace {
using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::IsEmpty;
@@ -143,13 +142,10 @@ TEST_F(DocHitInfoIteratorSectionRestrictTest,
expected_section_id_mask);
section_restrict_iterator.PopulateMatchedTermsStats(&matched_terms_stats);
- EXPECT_EQ(matched_terms_stats.at(0).term, "hi");
- std::array<Hit::TermFrequency, kTotalNumSections> expected_term_frequencies{
- 1};
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(expected_term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask,
- expected_section_id_mask);
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{0, 1}};
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "hi", expected_section_ids_tf_map)));
EXPECT_FALSE(section_restrict_iterator.Advance().ok());
}
diff --git a/icing/index/lite/lite-index_test.cc b/icing/index/lite/lite-index_test.cc
index 7858b47..2c29640 100644
--- a/icing/index/lite/lite-index_test.cc
+++ b/icing/index/lite/lite-index_test.cc
@@ -20,7 +20,6 @@
#include "gtest/gtest.h"
#include "icing/index/lite/doc-hit-info-iterator-term-lite.h"
#include "icing/index/term-id-codec.h"
-#include "icing/legacy/index/icing-mock-filesystem.h"
#include "icing/schema/section.h"
#include "icing/store/suggestion-result-checker.h"
#include "icing/testing/always-false-suggestion-result-checker-impl.h"
@@ -32,7 +31,7 @@ namespace lib {
namespace {
-using ::testing::ElementsAreArray;
+using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::SizeIs;
@@ -112,20 +111,25 @@ TEST_F(LiteIndexTest, LiteIndexIterator) {
lite_index_->InsertTerm(term, TermMatchType::PREFIX, kNamespace0));
ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id,
term_id_codec_->EncodeTvi(tvi, TviType::LITE));
- Hit doc_hit0(/*section_id=*/0, /*document_id=*/0, 3,
- /*is_in_prefix_section=*/false);
- Hit doc_hit1(/*section_id=*/1, /*document_id=*/0, 5,
- /*is_in_prefix_section=*/false);
- Hit::TermFrequencyArray doc0_term_frequencies{3, 5};
- Hit doc_hit2(/*section_id=*/1, /*document_id=*/1, 7,
- /*is_in_prefix_section=*/false);
- Hit doc_hit3(/*section_id=*/2, /*document_id=*/1, 11,
- /*is_in_prefix_section=*/false);
- Hit::TermFrequencyArray doc1_term_frequencies{0, 7, 11};
- ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0));
- ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit1));
- ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit2));
- ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit3));
+ Hit doc0_hit0(/*section_id=*/0, /*document_id=*/0, /*term_frequency=*/3,
+ /*is_in_prefix_section=*/false);
+ Hit doc0_hit1(/*section_id=*/1, /*document_id=*/0, /*term_frequency=*/5,
+ /*is_in_prefix_section=*/false);
+ SectionIdMask doc0_section_id_mask = 0b11;
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map0 = {{0, 3}, {1, 5}};
+ ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc0_hit0));
+ ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc0_hit1));
+
+ Hit doc1_hit1(/*section_id=*/1, /*document_id=*/1, /*term_frequency=*/7,
+ /*is_in_prefix_section=*/false);
+ Hit doc1_hit2(/*section_id=*/2, /*document_id=*/1, /*term_frequency=*/11,
+ /*is_in_prefix_section=*/false);
+ SectionIdMask doc1_section_id_mask = 0b110;
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map1 = {{1, 7}, {2, 11}};
+ ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit1));
+ ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc1_hit2));
std::unique_ptr<DocHitInfoIteratorTermLiteExact> iter =
std::make_unique<DocHitInfoIteratorTermLiteExact>(
@@ -134,25 +138,22 @@ TEST_F(LiteIndexTest, LiteIndexIterator) {
ASSERT_THAT(iter->Advance(), IsOk());
EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(1));
- EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b110));
+ EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(),
+ Eq(doc1_section_id_mask));
+
std::vector<TermMatchInfo> matched_terms_stats;
iter->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1));
- EXPECT_EQ(matched_terms_stats.at(0).term, term);
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b110);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(doc1_term_frequencies));
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ term, expected_section_ids_tf_map1)));
ASSERT_THAT(iter->Advance(), IsOk());
EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(0));
- EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b11));
+ EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(),
+ Eq(doc0_section_id_mask));
matched_terms_stats.clear();
iter->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1));
- EXPECT_EQ(matched_terms_stats.at(0).term, term);
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b11);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(doc0_term_frequencies));
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ term, expected_section_ids_tf_map0)));
}
} // namespace
diff --git a/icing/index/numeric/integer-index-storage.h b/icing/index/numeric/integer-index-storage.h
new file mode 100644
index 0000000..2048e76
--- /dev/null
+++ b/icing/index/numeric/integer-index-storage.h
@@ -0,0 +1,186 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_
+#define ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include "icing/file/file-backed-vector.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/memory-mapped-file.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+#include "icing/util/crc32.h"
+
+namespace icing {
+namespace lib {
+
+// IntegerIndexStorage: a class for indexing (persistent storage) and searching
+// contents of integer type sections in documents.
+// - Accepts new integer contents (a.k.a keys) and adds records (BasicHit, key)
+// into the integer index.
+// - Stores records (BasicHit, key) in posting lists and compresses them.
+// - Bucketizes these records by key to make range query more efficient and
+// manages them with the corresponding posting lists.
+// - When a posting list reaches the max size and is full, the mechanism of
+// PostingListAccessor is to create another new (max-size) posting list and
+// chain them together.
+// - It will be inefficient if we store all records in the same PL chain. E.g.
+// small range query needs to iterate through the whole PL chain but skips a
+// lot of non-relevant records (whose keys don't belong to the query range).
+// - Therefore, we implement the splitting mechanism to split a full max-size
+// posting list. Also adjust range of the original bucket and add new
+// buckets.
+// - Ranges of all buckets are disjoint and the union of them is [INT64_MIN,
+// INT64_MAX].
+// - Buckets should be sorted, so we can do binary search to find the desired
+// bucket(s). However, we may split a bucket into several buckets, and the
+// cost to insert newly created buckets is high.
+// - Thus, we introduce an unsorted bucket array for newly created buckets,
+// and merge unsorted buckets into the sorted bucket array only if length of
+// the unsorted bucket array exceeds the threshold. This mechanism will
+// reduce # of merging events and amortize the overall cost for bucket order
+// maintenance.
+// Note: some tree data structures (e.g. segment tree, B+ tree) maintain the
+// bucket order more efficiently than the sorted/unsorted bucket array
+// mechanism, but the implementation is more complicated and doesn't improve
+// the performance too much according to our analysis, so currently we
+// choose sorted/unsorted bucket array.
+// - Then we do binary search on the sorted bucket array and sequential search
+// on the unsorted bucket array.
+class IntegerIndexStorage {
+ public:
+ // Crcs and Info will be written into the metadata file.
+ // File layout: <Crcs><Info>
+ // Crcs
+ struct Crcs {
+ static constexpr int32_t kFileOffset = 0;
+
+ struct ComponentCrcs {
+ uint32_t info_crc;
+ uint32_t sorted_buckets_crc;
+ uint32_t unsorted_buckets_crc;
+ uint32_t flash_index_storage_crc;
+
+ bool operator==(const ComponentCrcs& other) const {
+ return info_crc == other.info_crc &&
+ sorted_buckets_crc == other.sorted_buckets_crc &&
+ unsorted_buckets_crc == other.unsorted_buckets_crc &&
+ flash_index_storage_crc == other.flash_index_storage_crc;
+ }
+
+ Crc32 ComputeChecksum() const {
+ return Crc32(std::string_view(reinterpret_cast<const char*>(this),
+ sizeof(ComponentCrcs)));
+ }
+ } __attribute__((packed));
+
+ bool operator==(const Crcs& other) const {
+ return all_crc == other.all_crc && component_crcs == other.component_crcs;
+ }
+
+ uint32_t all_crc;
+ ComponentCrcs component_crcs;
+ } __attribute__((packed));
+ static_assert(sizeof(Crcs) == 20, "");
+
+ // Info
+ struct Info {
+ static constexpr int32_t kFileOffset = static_cast<int32_t>(sizeof(Crcs));
+ static constexpr int32_t kMagic = 0xc4bf0ccc;
+
+ int32_t magic;
+ int32_t num_keys;
+
+ Crc32 ComputeChecksum() const {
+ return Crc32(
+ std::string_view(reinterpret_cast<const char*>(this), sizeof(Info)));
+ }
+ } __attribute__((packed));
+ static_assert(sizeof(Info) == 8, "");
+
+ // Bucket
+ class Bucket {
+ public:
+ // Absolute max # of buckets allowed. Since the absolute max file size of
+ // FileBackedVector on 32-bit platform is ~2^28, we can at most have ~13.4M
+ // buckets. To make it power of 2, round it down to 2^23. Also since we're
+ // using FileBackedVector to store buckets, add some static_asserts to
+ // ensure numbers here are compatible with FileBackedVector.
+ static constexpr int32_t kMaxNumBuckets = 1 << 23;
+
+ explicit Bucket(int64_t key_lower, int64_t key_upper,
+ PostingListIdentifier posting_list_identifier)
+ : key_lower_(key_lower),
+ key_upper_(key_upper),
+ posting_list_identifier_(posting_list_identifier) {}
+
+ // For FileBackedVector
+ bool operator==(const Bucket& other) const {
+ return key_lower_ == other.key_lower_ && key_upper_ == other.key_upper_ &&
+ posting_list_identifier_ == other.posting_list_identifier_;
+ }
+
+ PostingListIdentifier posting_list_identifier() const {
+ return posting_list_identifier_;
+ }
+ void set_posting_list_identifier(
+ PostingListIdentifier posting_list_identifier) {
+ posting_list_identifier_ = posting_list_identifier;
+ }
+
+ private:
+ int64_t key_lower_;
+ int64_t key_upper_;
+ PostingListIdentifier posting_list_identifier_;
+ } __attribute__((packed));
+ static_assert(sizeof(Bucket) == 20, "");
+ static_assert(sizeof(Bucket) == FileBackedVector<Bucket>::kElementTypeSize,
+ "Bucket type size is inconsistent with FileBackedVector "
+ "element type size");
+ static_assert(Bucket::kMaxNumBuckets <=
+ (FileBackedVector<Bucket>::kMaxFileSize -
+ FileBackedVector<Bucket>::Header::kHeaderSize) /
+ FileBackedVector<Bucket>::kElementTypeSize,
+ "Max # of buckets cannot fit into FileBackedVector");
+
+ private:
+ explicit IntegerIndexStorage(
+ const Filesystem& filesystem, std::string_view base_dir,
+ PostingListUsedIntegerIndexDataSerializer* serializer,
+ std::unique_ptr<MemoryMappedFile> metadata_mmapped_file,
+ std::unique_ptr<FileBackedVector<Bucket>> sorted_buckets,
+ std::unique_ptr<FileBackedVector<Bucket>> unsorted_buckets,
+ std::unique_ptr<FlashIndexStorage> flash_index_storage);
+
+ const Filesystem& filesystem_;
+ std::string base_dir_;
+
+ PostingListUsedIntegerIndexDataSerializer* serializer_; // Does not own.
+
+ std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_;
+ std::unique_ptr<FileBackedVector<Bucket>> sorted_buckets_;
+ std::unique_ptr<FileBackedVector<Bucket>> unsorted_buckets_;
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_INTEGER_INDEX_STORAGE_H_
diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc
index 9a7df38..51f3106 100644
--- a/icing/jni/icing-search-engine-jni.cc
+++ b/icing/jni/icing-search-engine-jni.cc
@@ -51,7 +51,8 @@ jbyteArray SerializeProtoToJniByteArray(
int size = protobuf.ByteSizeLong();
jbyteArray ret = env->NewByteArray(size);
if (ret == nullptr) {
- ICING_LOG(ERROR) << "Failed to allocated bytes for jni protobuf";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to allocated bytes for jni protobuf";
return nullptr;
}
@@ -75,7 +76,7 @@ extern "C" {
jint JNI_OnLoad(JavaVM* vm, void* reserved) {
JNIEnv* env;
if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
- ICING_LOG(ERROR) << "ERROR: GetEnv failed";
+ ICING_LOG(icing::lib::ERROR) << "ERROR: GetEnv failed";
return JNI_ERR;
}
@@ -88,7 +89,7 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate(
icing::lib::IcingSearchEngineOptions options;
if (!ParseProtoFromJniByteArray(env, icing_search_engine_options_bytes,
&options)) {
- ICING_LOG(ERROR)
+ ICING_LOG(icing::lib::ERROR)
<< "Failed to parse IcingSearchEngineOptions in nativeCreate";
return 0;
}
@@ -131,7 +132,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema(
icing::lib::SchemaProto schema_proto;
if (!ParseProtoFromJniByteArray(env, schema_bytes, &schema_proto)) {
- ICING_LOG(ERROR) << "Failed to parse SchemaProto in nativeSetSchema";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse SchemaProto in nativeSetSchema";
return nullptr;
}
@@ -173,7 +175,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativePut(
icing::lib::DocumentProto document_proto;
if (!ParseProtoFromJniByteArray(env, document_bytes, &document_proto)) {
- ICING_LOG(ERROR) << "Failed to parse DocumentProto in nativePut";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse DocumentProto in nativePut";
return nullptr;
}
@@ -192,7 +195,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet(
icing::lib::GetResultSpecProto get_result_spec;
if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &get_result_spec)) {
- ICING_LOG(ERROR) << "Failed to parse GetResultSpecProto in nativeGet";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse GetResultSpecProto in nativeGet";
return nullptr;
}
icing::lib::ScopedUtfChars scoped_name_space_chars(env, name_space);
@@ -212,7 +216,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage(
icing::lib::UsageReport usage_report;
if (!ParseProtoFromJniByteArray(env, usage_report_bytes, &usage_report)) {
- ICING_LOG(ERROR) << "Failed to parse UsageReport in nativeReportUsage";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse UsageReport in nativeReportUsage";
return nullptr;
}
@@ -279,20 +284,23 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch(
icing::lib::SearchSpecProto search_spec_proto;
if (!ParseProtoFromJniByteArray(env, search_spec_bytes, &search_spec_proto)) {
- ICING_LOG(ERROR) << "Failed to parse SearchSpecProto in nativeSearch";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse SearchSpecProto in nativeSearch";
return nullptr;
}
icing::lib::ScoringSpecProto scoring_spec_proto;
if (!ParseProtoFromJniByteArray(env, scoring_spec_bytes,
&scoring_spec_proto)) {
- ICING_LOG(ERROR) << "Failed to parse ScoringSpecProto in nativeSearch";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse ScoringSpecProto in nativeSearch";
return nullptr;
}
icing::lib::ResultSpecProto result_spec_proto;
if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &result_spec_proto)) {
- ICING_LOG(ERROR) << "Failed to parse ResultSpecProto in nativeSearch";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse ResultSpecProto in nativeSearch";
return nullptr;
}
@@ -363,7 +371,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery(
icing::lib::SearchSpecProto search_spec_proto;
if (!ParseProtoFromJniByteArray(env, search_spec_bytes, &search_spec_proto)) {
- ICING_LOG(ERROR) << "Failed to parse SearchSpecProto in nativeSearch";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse SearchSpecProto in nativeSearch";
return nullptr;
}
icing::lib::DeleteByQueryResultProto delete_result_proto =
@@ -379,8 +388,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk(
GetIcingSearchEnginePointer(env, object);
if (!icing::lib::PersistType::Code_IsValid(persist_type_code)) {
- ICING_LOG(ERROR) << persist_type_code
- << " is an invalid value for PersistType::Code";
+ ICING_LOG(icing::lib::ERROR)
+ << persist_type_code << " is an invalid value for PersistType::Code";
return nullptr;
}
icing::lib::PersistType::Code persist_type_code_enum =
@@ -447,7 +456,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions(
icing::lib::SuggestionSpecProto suggestion_spec_proto;
if (!ParseProtoFromJniByteArray(env, suggestion_spec_bytes,
&suggestion_spec_proto)) {
- ICING_LOG(ERROR) << "Failed to parse SuggestionSpecProto in nativeSearch";
+ ICING_LOG(icing::lib::ERROR)
+ << "Failed to parse SuggestionSpecProto in nativeSearch";
return nullptr;
}
icing::lib::SuggestionResponse suggestionResponse =
@@ -463,7 +473,8 @@ Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo(
GetIcingSearchEnginePointer(env, object);
if (!icing::lib::DebugInfoVerbosity::Code_IsValid(verbosity)) {
- ICING_LOG(ERROR) << "Invalid value for Debug Info verbosity: " << verbosity;
+ ICING_LOG(icing::lib::ERROR)
+ << "Invalid value for Debug Info verbosity: " << verbosity;
return nullptr;
}
@@ -478,7 +489,8 @@ JNIEXPORT jboolean JNICALL
Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog(
JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
if (!icing::lib::LogSeverity::Code_IsValid(severity)) {
- ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity;
+ ICING_LOG(icing::lib::ERROR)
+ << "Invalid value for logging severity: " << severity;
return false;
}
return icing::lib::ShouldLog(
@@ -489,7 +501,8 @@ JNIEXPORT jboolean JNICALL
Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel(
JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
if (!icing::lib::LogSeverity::Code_IsValid(severity)) {
- ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity;
+ ICING_LOG(icing::lib::ERROR)
+ << "Invalid value for logging severity: " << severity;
return false;
}
return icing::lib::SetLoggingLevel(
diff --git a/icing/join/aggregate-scorer.cc b/icing/join/aggregation-scorer.cc
index 7b17482..3dee3dd 100644
--- a/icing/join/aggregate-scorer.cc
+++ b/icing/join/aggregation-scorer.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "icing/join/aggregate-scorer.h"
+#include "icing/join/aggregation-scorer.h"
#include <algorithm>
#include <memory>
@@ -25,24 +25,26 @@
namespace icing {
namespace lib {
-class MinAggregateScorer : public AggregateScorer {
+class CountAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
- return std::min_element(children.begin(), children.end(),
- [](const ScoredDocumentHit& lhs,
- const ScoredDocumentHit& rhs) -> bool {
- return lhs.score() < rhs.score();
- })
- ->score();
+ return children.size();
}
};
-class MaxAggregateScorer : public AggregateScorer {
+class MinAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
- return std::max_element(children.begin(), children.end(),
+ if (children.empty()) {
+ // Return 0 if there is no child document.
+ // For non-empty children with negative scores, they are considered "worse
+ // than" 0, so it is correct to return 0 for empty children to assign it a
+ // rank higher than non-empty children with negative scores.
+ return 0.0;
+ }
+ return std::min_element(children.begin(), children.end(),
[](const ScoredDocumentHit& lhs,
const ScoredDocumentHit& rhs) -> bool {
return lhs.score() < rhs.score();
@@ -51,41 +53,59 @@ class MaxAggregateScorer : public AggregateScorer {
}
};
-class AverageAggregateScorer : public AggregateScorer {
+class AverageAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
- if (children.empty()) return 0.0;
+ if (children.empty()) {
+ // Return 0 if there is no child document.
+ // For non-empty children with negative scores, they are considered "worse
+ // than" 0, so it is correct to return 0 for empty children to assign it a
+ // rank higher than non-empty children with negative scores.
+ return 0.0;
+ }
return std::reduce(
children.begin(), children.end(), 0.0,
- [](const double& prev, const ScoredDocumentHit& item) -> double {
+ [](double prev, const ScoredDocumentHit& item) -> double {
return prev + item.score();
}) /
children.size();
}
};
-class CountAggregateScorer : public AggregateScorer {
+class MaxAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
- return children.size();
+ if (children.empty()) {
+ // Return 0 if there is no child document.
+ // For non-empty children with negative scores, they are considered "worse
+ // than" 0, so it is correct to return 0 for empty children to assign it a
+ // rank higher than non-empty children with negative scores.
+ return 0.0;
+ }
+ return std::max_element(children.begin(), children.end(),
+ [](const ScoredDocumentHit& lhs,
+ const ScoredDocumentHit& rhs) -> bool {
+ return lhs.score() < rhs.score();
+ })
+ ->score();
}
};
-class SumAggregateScorer : public AggregateScorer {
+class SumAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
return std::reduce(
children.begin(), children.end(), 0.0,
- [](const double& prev, const ScoredDocumentHit& item) -> double {
+ [](double prev, const ScoredDocumentHit& item) -> double {
return prev + item.score();
});
}
};
-class DefaultAggregateScorer : public AggregateScorer {
+class DefaultAggregationScorer : public AggregationScorer {
public:
double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) override {
@@ -93,23 +113,25 @@ class DefaultAggregateScorer : public AggregateScorer {
}
};
-std::unique_ptr<AggregateScorer> AggregateScorer::Create(
+std::unique_ptr<AggregationScorer> AggregationScorer::Create(
const JoinSpecProto& join_spec) {
- switch (join_spec.aggregation_score_strategy()) {
- case JoinSpecProto_AggregationScore_MIN:
- return std::make_unique<MinAggregateScorer>();
- case JoinSpecProto_AggregationScore_MAX:
- return std::make_unique<MaxAggregateScorer>();
- case JoinSpecProto_AggregationScore_COUNT:
- return std::make_unique<CountAggregateScorer>();
- case JoinSpecProto_AggregationScore_AVG:
- return std::make_unique<AverageAggregateScorer>();
- case JoinSpecProto_AggregationScore_SUM:
- return std::make_unique<SumAggregateScorer>();
- case JoinSpecProto_AggregationScore_UNDEFINED:
+ switch (join_spec.aggregation_scoring_strategy()) {
+ case JoinSpecProto::AggregationScoringStrategy::COUNT:
+ return std::make_unique<CountAggregationScorer>();
+ case JoinSpecProto::AggregationScoringStrategy::MIN:
+ return std::make_unique<MinAggregationScorer>();
+ case JoinSpecProto::AggregationScoringStrategy::AVG:
+ return std::make_unique<AverageAggregationScorer>();
+ case JoinSpecProto::AggregationScoringStrategy::MAX:
+ return std::make_unique<MaxAggregationScorer>();
+ case JoinSpecProto::AggregationScoringStrategy::SUM:
+ return std::make_unique<SumAggregationScorer>();
+ case JoinSpecProto::AggregationScoringStrategy::NONE:
+ // No aggregation strategy means using parent document score, so fall
+ // through to return DefaultAggregationScorer.
[[fallthrough]];
default:
- return std::make_unique<DefaultAggregateScorer>();
+ return std::make_unique<DefaultAggregationScorer>();
}
}
diff --git a/icing/join/aggregate-scorer.h b/icing/join/aggregation-scorer.h
index 27731b9..3d38cf0 100644
--- a/icing/join/aggregate-scorer.h
+++ b/icing/join/aggregation-scorer.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef ICING_JOIN_AGGREGATE_SCORER_H_
-#define ICING_JOIN_AGGREGATE_SCORER_H_
+#ifndef ICING_JOIN_AGGREGATION_SCORER_H_
+#define ICING_JOIN_AGGREGATION_SCORER_H_
#include <memory>
#include <vector>
@@ -24,12 +24,12 @@
namespace icing {
namespace lib {
-class AggregateScorer {
+class AggregationScorer {
public:
- static std::unique_ptr<AggregateScorer> Create(
+ static std::unique_ptr<AggregationScorer> Create(
const JoinSpecProto& join_spec);
- virtual ~AggregateScorer() = default;
+ virtual ~AggregationScorer() = default;
virtual double GetScore(const ScoredDocumentHit& parent,
const std::vector<ScoredDocumentHit>& children) = 0;
@@ -38,4 +38,4 @@ class AggregateScorer {
} // namespace lib
} // namespace icing
-#endif // ICING_JOIN_AGGREGATE_SCORER_H_
+#endif // ICING_JOIN_AGGREGATION_SCORER_H_
diff --git a/icing/join/aggregation-scorer_test.cc b/icing/join/aggregation-scorer_test.cc
new file mode 100644
index 0000000..19a7239
--- /dev/null
+++ b/icing/join/aggregation-scorer_test.cc
@@ -0,0 +1,215 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/join/aggregation-scorer.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::DoubleEq;
+
+struct AggregationScorerTestParam {
+ double ans;
+ JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy;
+ double parent_score;
+ std::vector<double> child_scores;
+
+ explicit AggregationScorerTestParam(
+ double ans_in,
+ JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy_in,
+ double parent_score_in, std::vector<double> child_scores_in)
+ : ans(ans_in),
+ scoring_strategy(scoring_strategy_in),
+ parent_score(std::move(parent_score_in)),
+ child_scores(std::move(child_scores_in)) {}
+};
+
+class AggregationScorerTest
+ : public ::testing::TestWithParam<AggregationScorerTestParam> {};
+
+TEST_P(AggregationScorerTest, GetScore) {
+ static constexpr DocumentId kDefaultDocumentId = 0;
+
+ const AggregationScorerTestParam& param = GetParam();
+ // Test AggregationScorer by creating some ScoredDocumentHits for parent and
+ // child documents. DocumentId and SectionIdMask won't affect the aggregation
+ // score calculation, so just simply set default values.
+ // Parent document
+ ScoredDocumentHit parent_scored_document_hit(
+ kDefaultDocumentId, kSectionIdMaskNone, param.parent_score);
+ // Child documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits;
+ child_scored_document_hits.reserve(param.child_scores.size());
+ std::transform(param.child_scores.cbegin(), param.child_scores.cend(),
+ std::back_inserter(child_scored_document_hits),
+ [](double score) -> ScoredDocumentHit {
+ return ScoredDocumentHit(kDefaultDocumentId,
+ kSectionIdMaskNone, score);
+ });
+
+ JoinSpecProto join_spec;
+ join_spec.set_aggregation_scoring_strategy(param.scoring_strategy);
+ std::unique_ptr<AggregationScorer> scorer =
+ AggregationScorer::Create(join_spec);
+ EXPECT_THAT(
+ scorer->GetScore(parent_scored_document_hit, child_scored_document_hits),
+ DoubleEq(param.ans));
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ CountAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/5, JoinSpecProto::AggregationScoringStrategy::COUNT,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child.
+ AggregationScorerTestParam(
+ /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::COUNT,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::COUNT,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{})));
+
+INSTANTIATE_TEST_SUITE_P(
+ MinAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::MIN,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child, greater than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MIN,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // Only one child, smaller than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MIN,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{50}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MIN,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{})));
+
+INSTANTIATE_TEST_SUITE_P(
+ AverageAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/4.6, JoinSpecProto::AggregationScoringStrategy::AVG,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child.
+ AggregationScorerTestParam(
+ /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::AVG,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::AVG,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{})));
+
+INSTANTIATE_TEST_SUITE_P(
+ MaxAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/8, JoinSpecProto::AggregationScoringStrategy::MAX,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child, greater than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MAX,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // Only one child, smaller than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MAX,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{50}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MAX,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{})));
+
+INSTANTIATE_TEST_SUITE_P(
+ SumAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/23, JoinSpecProto::AggregationScoringStrategy::SUM,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child.
+ AggregationScorerTestParam(
+ /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::SUM,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::SUM,
+ /*parent_score_in=*/0,
+ /*child_scores_in=*/{})));
+
+INSTANTIATE_TEST_SUITE_P(
+ DefaultAggregationScorerTest, AggregationScorerTest,
+ testing::Values(
+ // General case.
+ AggregationScorerTestParam(
+ /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{8, 3, 1, 4, 7}),
+ // Only one child, greater than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{123}),
+ // Only one child, smaller than parent.
+ AggregationScorerTestParam(
+ /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{50}),
+ // No child.
+ AggregationScorerTestParam(
+ /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE,
+ /*parent_score_in=*/98,
+ /*child_scores_in=*/{})));
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc
index 9b17396..7700397 100644
--- a/icing/join/join-processor.cc
+++ b/icing/join/join-processor.cc
@@ -23,6 +23,7 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
+#include "icing/join/aggregation-scorer.h"
#include "icing/join/qualified-id.h"
#include "icing/proto/scoring.pb.h"
#include "icing/proto/search.pb.h"
@@ -45,8 +46,6 @@ JoinProcessor::Join(
ScoringSpecProto::Order::DESC));
// TODO(b/256022027):
- // - Aggregate scoring
- // - Calculate the aggregated score if strategy is AGGREGATION_SCORING.
// - Optimization
// - Cache property to speed up property retrieval.
// - If there is no cache, then we still have the flexibility to fetch it
@@ -93,6 +92,9 @@ JoinProcessor::Join(
}
}
+ std::unique_ptr<AggregationScorer> aggregation_scorer =
+ AggregationScorer::Create(join_spec);
+
std::vector<JoinedScoredDocumentHit> joined_scored_document_hits;
joined_scored_document_hits.reserve(parent_scored_document_hits.size());
@@ -110,14 +112,15 @@ JoinProcessor::Join(
"Parent property expression must be ", kQualifiedIdExpr));
}
- // TODO(b/256022027): Derive final score from
- // parent_id_to_child_map[parent_doc_id] and
- // join_spec.aggregation_score_strategy()
- double final_score = parent.score();
- joined_scored_document_hits.emplace_back(
- final_score, std::move(parent),
- std::vector<ScoredDocumentHit>(
- std::move(parent_id_to_child_map[parent_doc_id])));
+ std::vector<ScoredDocumentHit> children;
+ if (auto iter = parent_id_to_child_map.find(parent_doc_id);
+ iter != parent_id_to_child_map.end()) {
+ children = std::move(iter->second);
+ }
+
+ double final_score = aggregation_scorer->GetScore(parent, children);
+ joined_scored_document_hits.emplace_back(final_score, std::move(parent),
+ std::move(children));
}
return joined_scored_document_hits;
diff --git a/icing/join/join-processor_test.cc b/icing/join/join-processor_test.cc
new file mode 100644
index 0000000..70eaf3f
--- /dev/null
+++ b/icing/join/join-processor_test.cc
@@ -0,0 +1,624 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/join/join-processor.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/proto/document.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/proto/scoring.pb.h"
+#include "icing/proto/search.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+
+class JoinProcessorTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ test_dir_ = GetTestTempDir() + "/icing_join_processor_test";
+ filesystem_.CreateDirectoryRecursively(test_dir_.c_str());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_));
+
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder().SetType("Person").AddProperty(
+ PropertyConfigBuilder()
+ .SetName("Name")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .AddType(
+ SchemaTypeConfigBuilder()
+ .SetType("Email")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("subject")
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("sender")
+ .SetDataTypeJoinableString(
+ JOINABLE_VALUE_TYPE_QUALIFIED_ID)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult create_result,
+ DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_,
+ schema_store_.get()));
+ doc_store_ = std::move(create_result.document_store);
+ }
+
+ void TearDown() override {
+ doc_store_.reset();
+ schema_store_.reset();
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ }
+
+ Filesystem filesystem_;
+ std::string test_dir_;
+ std::unique_ptr<SchemaStore> schema_store_;
+ std::unique_ptr<DocumentStore> doc_store_;
+ FakeClock fake_clock_;
+};
+
+TEST_F(JoinProcessorTest, JoinByQualifiedId) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+ DocumentProto person2 = DocumentBuilder()
+ .SetKey(R"(pkg$db/name#space\\)", "person2")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Bob")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 2")
+ .AddStringProperty("sender",
+ R"(pkg$db/name\#space\\\\#person2)") // escaped
+ .Build();
+ DocumentProto email3 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email3")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 3")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/3.0);
+ ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone,
+ /*score=*/4.0);
+ ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone,
+ /*score=*/5.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit2, scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {
+ scored_doc_hit5, scored_doc_hit4, scored_doc_hit3};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ EXPECT_THAT(
+ joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0,
+ /*parent_scored_document_hit=*/scored_doc_hit2,
+ /*child_scored_document_hits=*/{scored_doc_hit4})),
+ EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/2.0,
+ /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/
+ {scored_doc_hit5, scored_doc_hit3}))));
+}
+
+TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithoutJoiningProperty) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email2 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 2")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/5.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/6.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit2,
+ scored_doc_hit3};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ // Since Email2 doesn't have "sender" property, it should be ignored.
+ EXPECT_THAT(
+ joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0, /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/{scored_doc_hit2}))));
+}
+
+TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithInvalidQualifiedId) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 2")
+ .AddStringProperty(
+ "sender",
+ "pkg$db/namespace#person2") // qualified id is invalid since
+ // person2 doesn't exist.
+ .Build();
+ DocumentProto email3 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email3")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 3")
+ .AddStringProperty("sender",
+ R"(pkg$db/namespace\#person1)") // invalid format
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone,
+ /*score=*/0.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {
+ scored_doc_hit2, scored_doc_hit3, scored_doc_hit4};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ // Email 2 and email 3 (document id 3 and 4) contain invalid qualified ids.
+ // Join processor should ignore them.
+ EXPECT_THAT(joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0,
+ /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/{scored_doc_hit2}))));
+}
+
+TEST_F(JoinProcessorTest, LeftJoinShouldReturnParentWithoutChildren) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+ DocumentProto person2 = DocumentBuilder()
+ .SetKey(R"(pkg$db/name#space\\)", "person2")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Bob")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender",
+ R"(pkg$db/name\#space\\\\#person2)") // escaped
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/3.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit2, scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit3};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ // Person1 has no child documents, but left join should also include it.
+ EXPECT_THAT(
+ joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0,
+ /*parent_scored_document_hit=*/scored_doc_hit2,
+ /*child_scored_document_hits=*/{scored_doc_hit3})),
+ EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/0.0,
+ /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/{}))));
+}
+
+TEST_F(JoinProcessorTest, ShouldSortChildDocumentsByRankingStrategy) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 2")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email3 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email3")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 3")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/2.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/5.0);
+ ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone,
+ /*score=*/3.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {
+ scored_doc_hit2, scored_doc_hit3, scored_doc_hit4};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ // Child documents should be sorted according to the (nested) ranking
+ // strategy.
+ EXPECT_THAT(
+ joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/3.0, /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/
+ {scored_doc_hit3, scored_doc_hit4, scored_doc_hit2}))));
+}
+
+TEST_F(JoinProcessorTest,
+ ShouldTruncateByRankingStrategyIfExceedingMaxJoinedChildCount) {
+ DocumentProto person1 = DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Alice")
+ .Build();
+ DocumentProto person2 = DocumentBuilder()
+ .SetKey(R"(pkg$db/name#space\\)", "person2")
+ .SetSchema("Person")
+ .AddStringProperty("Name", "Bob")
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 2")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email3 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email3")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 3")
+ .AddStringProperty("sender", "pkg$db/namespace#person1")
+ .Build();
+ DocumentProto email4 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email4")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 4")
+ .AddStringProperty("sender",
+ R"(pkg$db/name\#space\\\\#person2)") // escaped
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3));
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, doc_store_->Put(email4));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone,
+ /*score=*/0.0);
+ ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone,
+ /*score=*/2.0);
+ ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone,
+ /*score=*/5.0);
+ ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone,
+ /*score=*/3.0);
+ ScoredDocumentHit scored_doc_hit6(document_id6, kSectionIdMaskNone,
+ /*score=*/1.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit1, scored_doc_hit2};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {
+ scored_doc_hit3, scored_doc_hit4, scored_doc_hit5, scored_doc_hit6};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(2);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ // Since we set max_joind_child_count as 2 and use DESC as the (nested)
+ // ranking strategy, parent document with # of child documents more than 2
+ // should only keep 2 child documents with higher scores and the rest should
+ // be truncated.
+ EXPECT_THAT(
+ joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/2.0,
+ /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/
+ {scored_doc_hit4, scored_doc_hit5})),
+ EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0,
+ /*parent_scored_document_hit=*/scored_doc_hit2,
+ /*child_scored_document_hits=*/{scored_doc_hit6}))));
+}
+
+TEST_F(JoinProcessorTest, ShouldAllowSelfJoining) {
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subject", "test subject 1")
+ .AddStringProperty("sender", "pkg$db/namespace#email1")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(email1));
+
+ ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone,
+ /*score=*/0.0);
+
+ // Parent ScoredDocumentHits: all Person documents
+ std::vector<ScoredDocumentHit> parent_scored_document_hits = {
+ scored_doc_hit1};
+
+ // Child ScoredDocumentHits: all Email documents
+ std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit1};
+
+ JoinSpecProto join_spec;
+ join_spec.set_max_joined_child_count(100);
+ join_spec.set_parent_property_expression(
+ std::string(JoinProcessor::kQualifiedIdExpr));
+ join_spec.set_child_property_expression("sender");
+ join_spec.set_aggregation_scoring_strategy(
+ JoinSpecProto::AggregationScoringStrategy::COUNT);
+ join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by(
+ ScoringSpecProto::Order::DESC);
+
+ JoinProcessor join_processor(doc_store_.get());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits,
+ join_processor.Join(join_spec, std::move(parent_scored_document_hits),
+ std::move(child_scored_document_hits)));
+ EXPECT_THAT(joined_result_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit(
+ /*final_score=*/1.0,
+ /*parent_scored_document_hit=*/scored_doc_hit1,
+ /*child_scored_document_hits=*/{scored_doc_hit1}))));
+}
+
+// TODO(b/256022027): add unit tests for non-joinable property. If joinable
+// value type is unset, then qualifed id join should not
+// include the child document even if it contains a valid
+// qualified id string.
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc
index 910d722..fbd4504 100644
--- a/icing/query/advanced_query_parser/query-visitor.cc
+++ b/icing/query/advanced_query_parser/query-visitor.cc
@@ -366,5 +366,23 @@ void QueryVisitor::VisitNaryOperator(const NaryOperatorNode* node) {
pending_values_.push(std::move(pending_value));
}
+libtextclassifier3::StatusOr<QueryResults> QueryVisitor::ConsumeResults() && {
+ if (has_pending_error()) {
+ return std::move(pending_error_);
+ }
+ if (pending_values_.size() != 1) {
+ return absl_ports::InvalidArgumentError(
+ "Visitor does not contain a single root iterator.");
+ }
+ auto iterator_or = RetrieveIterator();
+ if (!iterator_or.ok()) {
+ return std::move(iterator_or).status();
+ }
+ QueryResults results;
+ results.root_iterator = std::move(iterator_or).ValueOrDie();
+ results.features_in_use = std::move(features_);
+ return results;
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h
index 8bba7ea..c6b7d8e 100644
--- a/icing/query/advanced_query_parser/query-visitor.h
+++ b/icing/query/advanced_query_parser/query-visitor.h
@@ -26,10 +26,11 @@
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/numeric/numeric-index.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/query-features.h"
+#include "icing/query/query-results.h"
#include "icing/schema/schema-store.h"
#include "icing/store/document-store.h"
#include "icing/transform/normalizer.h"
-#include "icing/query/query-features.h"
namespace icing {
namespace lib {
@@ -60,27 +61,10 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
void VisitNaryOperator(const NaryOperatorNode* node) override;
// RETURNS:
- // - the DocHitInfoIterator that is the root of the query iterator tree
+ // - the QueryResults reflecting the AST that was visited
// - INVALID_ARGUMENT if the AST does not conform to supported expressions
// - NOT_FOUND if the AST refers to a property that does not exist
- libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> root() && {
- if (has_pending_error()) {
- return pending_error_;
- }
- if (pending_values_.size() != 1) {
- return absl_ports::InvalidArgumentError(
- "Visitor does not contain a single root iterator.");
- }
- auto iterator_or = RetrieveIterator();
- if (!iterator_or.ok()) {
- pending_error_ = std::move(iterator_or).status();
- return pending_error_;
- }
- return std::move(iterator_or).ValueOrDie();
- }
-
- // Returns the set of features used in the query.
- const std::unordered_set<Feature>& features() const { return features_; }
+ libtextclassifier3::StatusOr<QueryResults> ConsumeResults() &&;
private:
// A holder for intermediate results when processing child nodes.
@@ -185,15 +169,16 @@ class QueryVisitor : public AbstractSyntaxTreeVisitor {
std::stack<PendingValue> pending_values_;
libtextclassifier3::Status pending_error_;
+ // Set of features invoked in the query.
+ std::unordered_set<Feature> features_;
+
Index& index_; // Does not own!
const NumericIndex<int64_t>& numeric_index_; // Does not own!
const DocumentStore& document_store_; // Does not own!
const SchemaStore& schema_store_; // Does not own!
const Normalizer& normalizer_; // Does not own!
- TermMatchType::Code match_type_;
- // Set of features invoked in the query.
- std::unordered_set<Feature> features_;
+ TermMatchType::Code match_type_;
};
} // namespace lib
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc
index 2f816a8..2b5117b 100644
--- a/icing/query/advanced_query_parser/query-visitor_test.cc
+++ b/icing/query/advanced_query_parser/query-visitor_test.cc
@@ -146,10 +146,11 @@ TEST_F(QueryVisitorTest, SimpleLessThan) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -176,10 +177,11 @@ TEST_F(QueryVisitorTest, SimpleLessThanEq) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -206,10 +208,12 @@ TEST_F(QueryVisitorTest, SimpleEqual) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
TEST_F(QueryVisitorTest, SimpleGreaterThanEq) {
@@ -235,10 +239,11 @@ TEST_F(QueryVisitorTest, SimpleGreaterThanEq) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId1));
}
@@ -265,10 +270,12 @@ TEST_F(QueryVisitorTest, SimpleGreaterThan) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
// TODO(b/208654892) Properly handle negative numbers in query expressions.
@@ -296,10 +303,12 @@ TEST_F(QueryVisitorTest, DISABLED_IntMinLessThanEqual) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
}
TEST_F(QueryVisitorTest, IntMaxGreaterThanEqual) {
@@ -326,10 +335,12 @@ TEST_F(QueryVisitorTest, IntMaxGreaterThanEqual) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
}
TEST_F(QueryVisitorTest, NestedPropertyLessThan) {
@@ -357,10 +368,11 @@ TEST_F(QueryVisitorTest, NestedPropertyLessThan) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kNumericSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kNumericSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -372,7 +384,7 @@ TEST_F(QueryVisitorTest, IntParsingError) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -384,7 +396,7 @@ TEST_F(QueryVisitorTest, NotEqualsUnsupported) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED));
}
@@ -427,7 +439,7 @@ TEST_F(QueryVisitorTest, LessThanTooManyOperandsInvalid) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -449,7 +461,7 @@ TEST_F(QueryVisitorTest, LessThanTooFewOperandsInvalid) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -480,7 +492,7 @@ TEST_F(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
}
@@ -488,7 +500,7 @@ TEST_F(QueryVisitorTest, NeverVisitedReturnsInvalid) {
QueryVisitor query_visitor(index_.get(), numeric_index_.get(),
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -517,7 +529,7 @@ TEST_F(QueryVisitorTest, DISABLED_IntMinLessThanInvalid) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -545,7 +557,7 @@ TEST_F(QueryVisitorTest, IntMaxGreaterThanInvalid) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(std::move(query_visitor).root(),
+ EXPECT_THAT(std::move(query_visitor).ConsumeResults(),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
@@ -574,9 +586,9 @@ TEST_F(QueryVisitorTest, SingleTerm) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -605,10 +617,11 @@ TEST_F(QueryVisitorTest, SingleVerbatimTerm) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -650,10 +663,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingQuote) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
// 2. How does a user represent a escape char (\) that immediately precedes the
@@ -687,10 +702,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingEscape) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
}
// 3. How do we handle other escaped chars?
@@ -726,10 +743,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
// Issue a query for the verbatim token `foobar\y`.
query = R"("foobar\\y")";
@@ -738,10 +757,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingNonSpecialChar) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_two);
- EXPECT_THAT(query_visitor_two.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_two).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_two).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
// This isn't a special case, but is fairly useful for demonstrating. There are
@@ -778,10 +799,12 @@ TEST_F(QueryVisitorTest, VerbatimTermNewLine) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
// Now, issue a query for the verbatim token `foobar\n`.
query = R"("foobar\\n")";
@@ -790,10 +813,12 @@ TEST_F(QueryVisitorTest, VerbatimTermNewLine) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_two);
- EXPECT_THAT(query_visitor_two.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_two).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_two).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
TEST_F(QueryVisitorTest, VerbatimTermEscapingComplex) {
@@ -824,10 +849,12 @@ TEST_F(QueryVisitorTest, VerbatimTermEscapingComplex) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), ElementsAre(kVerbatimSearchFeature));
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use,
+ ElementsAre(kVerbatimSearchFeature));
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
}
TEST_F(QueryVisitorTest, SingleMinusTerm) {
@@ -866,10 +893,11 @@ TEST_F(QueryVisitorTest, SingleMinusTerm) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
TEST_F(QueryVisitorTest, SingleNotTerm) {
@@ -908,10 +936,11 @@ TEST_F(QueryVisitorTest, SingleNotTerm) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
TEST_F(QueryVisitorTest, ImplicitAndTerms) {
Index::Editor editor = index_->Edit(kDocumentId0, kSectionId1,
@@ -937,10 +966,11 @@ TEST_F(QueryVisitorTest, ImplicitAndTerms) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
}
TEST_F(QueryVisitorTest, ExplicitAndTerms) {
@@ -967,10 +997,11 @@ TEST_F(QueryVisitorTest, ExplicitAndTerms) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId1));
}
TEST_F(QueryVisitorTest, OrTerms) {
@@ -997,10 +1028,10 @@ TEST_F(QueryVisitorTest, OrTerms) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId0));
}
@@ -1030,10 +1061,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId1));
// Should be interpreted like `(bar OR baz) foo`
@@ -1043,10 +1074,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_two);
- EXPECT_THAT(query_visitor_two.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_two).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_two).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId1));
query = "(bar OR baz) foo";
@@ -1055,10 +1086,10 @@ TEST_F(QueryVisitorTest, AndOrTermPrecedence) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_three);
- EXPECT_THAT(query_visitor_three.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_three).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_three).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId1));
}
@@ -1103,10 +1134,10 @@ TEST_F(QueryVisitorTest, AndOrNotPrecedence) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId2, kDocumentId0));
query = "foo NOT (bar OR baz)";
@@ -1115,10 +1146,11 @@ TEST_F(QueryVisitorTest, AndOrNotPrecedence) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_two);
- EXPECT_THAT(query_visitor_two.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_two).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_two).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId0));
}
TEST_F(QueryVisitorTest, PropertyFilter) {
@@ -1169,10 +1201,10 @@ TEST_F(QueryVisitorTest, PropertyFilter) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -1224,10 +1256,10 @@ TEST_F(QueryVisitorTest, PropertyFilterWithGrouping) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
ElementsAre(kDocumentId1, kDocumentId0));
}
@@ -1279,10 +1311,11 @@ TEST_F(QueryVisitorTest, PropertyFilterWithNot) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor);
- EXPECT_THAT(query_visitor.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
- std::move(query_visitor).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(QueryResults query_results,
+ std::move(query_visitor).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
query = "NOT prop1:(foo OR bar)";
ICING_ASSERT_OK_AND_ASSIGN(root_node, ParseQueryHelper(query));
@@ -1290,10 +1323,11 @@ TEST_F(QueryVisitorTest, PropertyFilterWithNot) {
document_store_.get(), schema_store_.get(),
normalizer_.get(), TERM_MATCH_PREFIX);
root_node->Accept(&query_visitor_two);
- EXPECT_THAT(query_visitor_two.features(), IsEmpty());
- ICING_ASSERT_OK_AND_ASSIGN(root_iterator,
- std::move(query_visitor_two).root());
- EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+ ICING_ASSERT_OK_AND_ASSIGN(query_results,
+ std::move(query_visitor_two).ConsumeResults());
+ EXPECT_THAT(query_results.features_in_use, IsEmpty());
+ EXPECT_THAT(GetDocumentIds(query_results.root_iterator.get()),
+ ElementsAre(kDocumentId2));
}
} // namespace
diff --git a/icing/query/query-features.h b/icing/query/query-features.h
index 56825f7..1471063 100644
--- a/icing/query/query-features.h
+++ b/icing/query/query-features.h
@@ -22,13 +22,17 @@ namespace icing {
namespace lib {
// A feature used in a query.
-using Feature = std::string_view;
-
// All feature values here must be kept in sync with its counterpart in:
// androidx-main/frameworks/support/appsearch/appsearch/src/main/java/androidx/appsearch/app/Features.java
+using Feature = std::string_view;
+
+// This feature relates to the use of the numeric comparison operators in the
+// advanced query language. Ex. `price < 10`.
constexpr Feature kNumericSearchFeature =
"NUMERIC_SEARCH"; // Features#NUMERIC_SEARCH
+// This feature relates to the use of the STRING terminal in the advanced query
+// language. Ex. `"foo?bar"` is treated as a single term - `foo?bar`.
constexpr Feature kVerbatimSearchFeature =
"VERBATIM_SEARCH"; // Features#VERBATIM_SEARCH
diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc
index 7860684..283d83d 100644
--- a/icing/query/query-processor.cc
+++ b/icing/query/query-processor.cc
@@ -207,10 +207,7 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery(
&schema_store_, &normalizer_,
search_spec.term_match_type());
tree_root->Accept(&query_visitor);
- results.features_in_use = query_visitor.features();
- ICING_ASSIGN_OR_RETURN(results.root_iterator,
- std::move(query_visitor).root());
- return results;
+ return std::move(query_visitor).ConsumeResults();
}
// TODO(cassiewang): Collect query stats to populate the SearchResultsProto
diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc
index 161f180..c22f6aa 100644
--- a/icing/query/query-processor_test.cc
+++ b/icing/query/query-processor_test.cc
@@ -59,7 +59,6 @@ namespace lib {
namespace {
using ::testing::ElementsAre;
-using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
@@ -293,7 +292,6 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -320,18 +318,14 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "world");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("world", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
UnorderedElementsAre("hello", "world"));
@@ -359,7 +353,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::PREFIX;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -383,14 +376,12 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "he");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "he", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he"));
EXPECT_THAT(results.query_term_iterators, SizeIs(1));
@@ -442,14 +433,12 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "he");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "he", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he"));
EXPECT_THAT(results.query_term_iterators, SizeIs(1));
@@ -476,7 +465,6 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -500,14 +488,13 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello"));
EXPECT_THAT(results.query_term_iterators, SizeIs(1));
@@ -534,7 +521,6 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -554,18 +540,18 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) {
EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id);
EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(),
section_id_mask);
+
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map)));
}
-
ASSERT_FALSE(results.root_iterator->Advance().ok());
// TODO(b/208654892) Support Query Terms with advanced query
@@ -597,7 +583,6 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -624,18 +609,14 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "world");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("world", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
UnorderedElementsAre("hello", "world"));
@@ -663,7 +644,6 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) {
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
TermMatchType::Code term_match_type = TermMatchType::PREFIX;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
@@ -687,13 +667,12 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "he");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "he", expected_section_ids_tf_map)));
}
ASSERT_FALSE(results.root_iterator->Advance().ok());
@@ -726,7 +705,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
TermMatchType::Code term_match_type = TermMatchType::PREFIX;
EXPECT_THAT(
@@ -755,18 +733,14 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "he");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "wo");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("he", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("wo", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo"));
EXPECT_THAT(results.query_term_iterators, SizeIs(2));
@@ -792,7 +766,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
TermMatchType::Code term_match_type = TermMatchType::PREFIX;
EXPECT_THAT(AddTokenToIndex(document_id, section_id,
@@ -821,18 +794,14 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "wo");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("wo", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo"));
EXPECT_THAT(results.query_term_iterators, SizeIs(2));
@@ -863,7 +832,6 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
EXPECT_THAT(
@@ -888,33 +856,34 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) {
EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2);
EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(),
section_id_mask);
+
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "world");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("world", expected_section_ids_tf_map)));
}
ASSERT_THAT(results.root_iterator->Advance(), IsOk());
EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1);
EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(),
section_id_mask);
+
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
UnorderedElementsAre("hello", "world"));
@@ -946,7 +915,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
TermMatchType::Code term_match_type = TermMatchType::PREFIX;
EXPECT_THAT(
@@ -975,13 +943,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "wo");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "wo", expected_section_ids_tf_map)));
}
ASSERT_THAT(results.root_iterator->Advance(), IsOk());
@@ -992,14 +959,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "he");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "he", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo"));
EXPECT_THAT(results.query_term_iterators, SizeIs(2));
@@ -1030,7 +995,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
EXPECT_THAT(AddTokenToIndex(document_id1, section_id,
TermMatchType::EXACT_ONLY, "hello"),
@@ -1058,13 +1022,12 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "wo");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(matched_terms_stats, ElementsAre(EqualsTermMatchInfo(
+ "wo", expected_section_ids_tf_map)));
}
ASSERT_THAT(results.root_iterator->Advance(), IsOk());
@@ -1075,14 +1038,13 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term
- EXPECT_EQ(matched_terms_stats.at(0).term, "hello");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("hello", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo"));
EXPECT_THAT(results.query_term_iterators, SizeIs(2));
@@ -1112,7 +1074,6 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
// Populate the index
SectionId section_id = 0;
SectionIdMask section_id_mask = 1U << section_id;
- std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1};
TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY;
// Document 1 has content "animal puppy dog"
@@ -1159,17 +1120,14 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "puppy");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "dog");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("puppy", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("dog", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
@@ -1203,17 +1161,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "animal");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "kitten");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("animal", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("kitten", expected_section_ids_tf_map)));
}
ASSERT_THAT(results.root_iterator->Advance(), IsOk());
@@ -1225,18 +1181,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "animal");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "puppy");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
-
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("animal", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("puppy", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
UnorderedElementsAre("animal", "puppy", "kitten"));
@@ -1267,17 +1220,15 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 1}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "kitten");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "cat");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("kitten", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("cat", expected_section_ids_tf_map)));
EXPECT_THAT(results.query_terms, SizeIs(1));
EXPECT_THAT(results.query_terms[""],
@@ -1901,9 +1852,6 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
QueryResults results,
query_processor_->ParseSearch(search_spec,
ScoringSpecProto::RankingStrategy::NONE));
- // Since need_hit_term_frequency is false, the expected term frequencies
- // should all be 0.
- Hit::TermFrequencyArray exp_term_frequencies{0};
// Descending order of valid DocumentIds
// The first Document to match (Document 2) matches on 'animal' AND 'kitten'
@@ -1915,17 +1863,17 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ // Since need_hit_term_frequency is false, the expected term frequency for
+ // the section with the hit should be 0.
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 0}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "animal");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(exp_term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "kitten");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(exp_term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(
+ EqualsTermMatchInfo("animal", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("kitten", expected_section_ids_tf_map)));
}
// The second Document to match (Document 1) matches on 'animal' AND 'puppy'
@@ -1937,17 +1885,14 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ std::unordered_map<SectionId, Hit::TermFrequency>
+ expected_section_ids_tf_map = {{section_id, 0}};
std::vector<TermMatchInfo> matched_terms_stats;
results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
- ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms
- EXPECT_EQ(matched_terms_stats.at(0).term, "animal");
- EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(0).term_frequencies,
- ElementsAreArray(exp_term_frequencies));
- EXPECT_EQ(matched_terms_stats.at(1).term, "puppy");
- EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask);
- EXPECT_THAT(matched_terms_stats.at(1).term_frequencies,
- ElementsAreArray(exp_term_frequencies));
+ EXPECT_THAT(
+ matched_terms_stats,
+ ElementsAre(EqualsTermMatchInfo("animal", expected_section_ids_tf_map),
+ EqualsTermMatchInfo("puppy", expected_section_ids_tf_map)));
// This should be empty because ranking_strategy != RELEVANCE_SCORE
EXPECT_THAT(results.query_term_iterators, IsEmpty());
@@ -2009,6 +1954,7 @@ TEST_P(QueryProcessorTest, DeletedFilter) {
expectedDocHitInfo.UpdateSection(/*section_id=*/0);
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()),
ElementsAre(expectedDocHitInfo));
+
// TODO(b/208654892) Support Query Terms with advanced query
if (GetParam() !=
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h
index 077d734..763499b 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer.h
+++ b/icing/scoring/advanced_scoring/advanced-scorer.h
@@ -60,13 +60,20 @@ class AdvancedScorer : public Scorer {
bm25f_calculator_->PrepareToScore(query_term_iterators);
}
+ bool is_constant() const { return score_expression_->is_constant_double(); }
+
private:
explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,
std::unique_ptr<Bm25fCalculator> bm25f_calculator,
double default_score)
: score_expression_(std::move(score_expression)),
bm25f_calculator_(std::move(bm25f_calculator)),
- default_score_(default_score) {}
+ default_score_(default_score) {
+ if (is_constant()) {
+ ICING_LOG(WARNING)
+ << "The advanced scoring expression will evaluate to a constant.";
+ }
+ }
std::unique_ptr<ScoreExpression> score_expression_;
std::unique_ptr<Bm25fCalculator> bm25f_calculator_;
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
index 36d38a2..b0b32e9 100644
--- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc
+++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
@@ -437,7 +437,7 @@ TEST_F(AdvancedScorerTest, ComplexExpression) {
DocHitInfo docHitInfo = DocHitInfo(document_id);
ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<Scorer> scorer,
+ std::unique_ptr<AdvancedScorer> scorer,
AdvancedScorer::Create(CreateAdvancedScoringSpec(
"pow(sin(2), 2)"
// This is this.usageCount(1)
@@ -449,13 +449,14 @@ TEST_F(AdvancedScorerTest, ComplexExpression) {
"+ this.relevanceScore()"),
/*default_score=*/10, document_store_.get(),
schema_store_.get()));
+ EXPECT_FALSE(scorer->is_constant());
scorer->PrepareToScore(/*query_term_iterators=*/{});
ICING_ASSERT_OK(document_store_->ReportUsage(
CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1)));
ICING_ASSERT_OK(document_store_->ReportUsage(
CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1)));
- EXPECT_THAT(scorer->GetScore(docHitInfo),
+ EXPECT_THAT(scorer->GetScore(docHitInfo, /*query_it=*/nullptr),
DoubleNear(pow(sin(2), 2) +
2 / 12.34 *
(10 * pow(2 * 1, sin(2)) +
@@ -464,6 +465,18 @@ TEST_F(AdvancedScorerTest, ComplexExpression) {
kEps));
}
+TEST_F(AdvancedScorerTest, ConstantExpression) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<AdvancedScorer> scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec(
+ "pow(sin(2), 2)"
+ "+ log(2, 122) / 12.34"
+ "* (10 * pow(2 * 1, sin(2)) + 10 * (2 + 10))"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_TRUE(scorer->is_constant());
+}
+
// Should be a parsing Error
TEST_F(AdvancedScorerTest, EmptyExpression) {
EXPECT_THAT(
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
index 08da1c5..a8749df 100644
--- a/icing/scoring/advanced_scoring/score-expression.cc
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -17,19 +17,23 @@
namespace icing {
namespace lib {
-libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
OperatorScoreExpression::Create(
OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) {
if (children.empty()) {
return absl_ports::InvalidArgumentError(
"OperatorScoreExpression must have at least one argument.");
}
+ bool children_all_constant_double = true;
for (const auto& child : children) {
ICING_RETURN_ERROR_IF_NULL(child);
if (child->is_document_type()) {
return absl_ports::InvalidArgumentError(
"Operators are not supported for document type.");
}
+ if (!child->is_constant_double()) {
+ children_all_constant_double = false;
+ }
}
if (op == OperatorType::kNegative) {
if (children.size() != 1) {
@@ -37,8 +41,16 @@ OperatorScoreExpression::Create(
"Negative operator must have only 1 argument.");
}
}
- return std::unique_ptr<OperatorScoreExpression>(
- new OperatorScoreExpression(op, std::move(children)));
+ std::unique_ptr<ScoreExpression> expression =
+ std::unique_ptr<OperatorScoreExpression>(
+ new OperatorScoreExpression(op, std::move(children)));
+ if (children_all_constant_double) {
+ // Because all of the children are constants, this expression does not
+ // depend on the DocHitInto or query_it that are passed into it.
+ return ConstantScoreExpression::Create(
+ expression->eval(DocHitInfo(), /*query_it=*/nullptr));
+ }
+ return expression;
}
libtextclassifier3::StatusOr<double> OperatorScoreExpression::eval(
@@ -85,7 +97,7 @@ const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType>
{"sin", FunctionType::kSin}, {"cos", FunctionType::kCos},
{"tan", FunctionType::kTan}};
-libtextclassifier3::StatusOr<std::unique_ptr<MathFunctionScoreExpression>>
+libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
MathFunctionScoreExpression::Create(
FunctionType function_type,
std::vector<std::unique_ptr<ScoreExpression>> children) {
@@ -93,12 +105,16 @@ MathFunctionScoreExpression::Create(
return absl_ports::InvalidArgumentError(
"Math functions must have at least one argument.");
}
+ bool children_all_constant_double = true;
for (const auto& child : children) {
ICING_RETURN_ERROR_IF_NULL(child);
if (child->is_document_type()) {
return absl_ports::InvalidArgumentError(
"Math functions are not supported for document type.");
}
+ if (!child->is_constant_double()) {
+ children_all_constant_double = false;
+ }
}
switch (function_type) {
case FunctionType::kLog:
@@ -143,8 +159,16 @@ MathFunctionScoreExpression::Create(
case FunctionType::kMin:
break;
}
- return std::unique_ptr<MathFunctionScoreExpression>(
- new MathFunctionScoreExpression(function_type, std::move(children)));
+ std::unique_ptr<ScoreExpression> expression =
+ std::unique_ptr<MathFunctionScoreExpression>(
+ new MathFunctionScoreExpression(function_type, std::move(children)));
+ if (children_all_constant_double) {
+ // Because all of the children are constants, this expression does not
+ // depend on the DocHitInto or query_it that are passed into it.
+ return ConstantScoreExpression::Create(
+ expression->eval(DocHitInfo(), /*query_it=*/nullptr));
+ }
+ return expression;
}
libtextclassifier3::StatusOr<double> MathFunctionScoreExpression::eval(
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h
index 533ca52..f80da33 100644
--- a/icing/scoring/advanced_scoring/score-expression.h
+++ b/icing/scoring/advanced_scoring/score-expression.h
@@ -31,8 +31,6 @@
namespace icing {
namespace lib {
-// TODO(b/261474063) Simplify every ScoreExpression node to
-// ConstantScoreExpression if its evaluation does not depend on a document.
class ScoreExpression {
public:
virtual ~ScoreExpression() = default;
@@ -49,6 +47,10 @@ class ScoreExpression {
// Indicate whether the current expression is of document type
virtual bool is_document_type() const { return false; }
+
+ // Indicate whether the current expression is a constant double.
+ // Returns true if and only if the object is of ConstantScoreExpression type.
+ virtual bool is_constant_double() const { return false; }
};
class ThisExpression : public ScoreExpression {
@@ -72,7 +74,8 @@ class ThisExpression : public ScoreExpression {
class ConstantScoreExpression : public ScoreExpression {
public:
- static std::unique_ptr<ConstantScoreExpression> Create(double c) {
+ static std::unique_ptr<ConstantScoreExpression> Create(
+ libtextclassifier3::StatusOr<double> c) {
return std::unique_ptr<ConstantScoreExpression>(
new ConstantScoreExpression(c));
}
@@ -82,10 +85,13 @@ class ConstantScoreExpression : public ScoreExpression {
return c_;
}
+ bool is_constant_double() const override { return true; }
+
private:
- explicit ConstantScoreExpression(double c) : c_(c) {}
+ explicit ConstantScoreExpression(libtextclassifier3::StatusOr<double> c)
+ : c_(c) {}
- double c_;
+ libtextclassifier3::StatusOr<double> c_;
};
class OperatorScoreExpression : public ScoreExpression {
@@ -93,12 +99,12 @@ class OperatorScoreExpression : public ScoreExpression {
enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv };
// RETURNS:
- // - An OperatorScoreExpression instance on success.
+ // - An OperatorScoreExpression instance on success if not simplifiable.
+ // - A ConstantScoreExpression instance on success if simplifiable.
// - FAILED_PRECONDITION on any null pointer in children.
// - INVALID_ARGUMENT on type errors.
- static libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
- Create(OperatorType op,
- std::vector<std::unique_ptr<ScoreExpression>> children);
+ static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
+ OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children);
libtextclassifier3::StatusOr<double> eval(
const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;
@@ -129,13 +135,13 @@ class MathFunctionScoreExpression : public ScoreExpression {
static const std::unordered_map<std::string, FunctionType> kFunctionNames;
// RETURNS:
- // - A MathFunctionScoreExpression instance on success.
+ // - A MathFunctionScoreExpression instance on success if not simplifiable.
+ // - A ConstantScoreExpression instance on success if simplifiable.
// - FAILED_PRECONDITION on any null pointer in children.
// - INVALID_ARGUMENT on type errors.
- static libtextclassifier3::StatusOr<
- std::unique_ptr<MathFunctionScoreExpression>>
- Create(FunctionType function_type,
- std::vector<std::unique_ptr<ScoreExpression>> children);
+ static libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Create(
+ FunctionType function_type,
+ std::vector<std::unique_ptr<ScoreExpression>> children);
libtextclassifier3::StatusOr<double> eval(
const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;
diff --git a/icing/scoring/advanced_scoring/score-expression_test.cc b/icing/scoring/advanced_scoring/score-expression_test.cc
new file mode 100644
index 0000000..b49b658
--- /dev/null
+++ b/icing/scoring/advanced_scoring/score-expression_test.cc
@@ -0,0 +1,186 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/advanced_scoring/score-expression.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::Eq;
+
+class NonConstantScoreExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<NonConstantScoreExpression> Create() {
+ return std::make_unique<NonConstantScoreExpression>();
+ }
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo&, const DocHitInfoIterator*) override {
+ return 0;
+ }
+
+ bool is_constant_double() const override { return false; }
+};
+
+template <typename... Args>
+std::vector<std::unique_ptr<ScoreExpression>> MakeChildren(Args... args) {
+ std::vector<std::unique_ptr<ScoreExpression>> children;
+ (children.push_back(std::move(args)), ...);
+ return children;
+}
+
+TEST(ScoreExpressionTest, OperatorSimplification) {
+ // 1 + 1 = 2
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoreExpression> expression,
+ OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kPlus,
+ MakeChildren(ConstantScoreExpression::Create(1),
+ ConstantScoreExpression::Create(1))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
+
+ // 1 - 2 - 3 = -4
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kMinus,
+ MakeChildren(ConstantScoreExpression::Create(1),
+ ConstantScoreExpression::Create(2),
+ ConstantScoreExpression::Create(3))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-4)));
+
+ // 1 * 2 * 3 * 4 = 24
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kTimes,
+ MakeChildren(ConstantScoreExpression::Create(1),
+ ConstantScoreExpression::Create(2),
+ ConstantScoreExpression::Create(3),
+ ConstantScoreExpression::Create(4))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(24)));
+
+ // 1 / 2 / 4 = 0.125
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kDiv,
+ MakeChildren(ConstantScoreExpression::Create(1),
+ ConstantScoreExpression::Create(2),
+ ConstantScoreExpression::Create(4))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(0.125)));
+
+ // -(2) = -2
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kNegative,
+ MakeChildren(ConstantScoreExpression::Create(2))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(-2)));
+}
+
+TEST(ScoreExpressionTest, MathFunctionSimplification) {
+ // pow(2, 2) = 4
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoreExpression> expression,
+ MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kPow,
+ MakeChildren(ConstantScoreExpression::Create(2),
+ ConstantScoreExpression::Create(2))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(4)));
+
+ // abs(-2) = 2
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kAbs,
+ MakeChildren(ConstantScoreExpression::Create(-2))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(2)));
+
+ // log(e) = 1
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kLog,
+ MakeChildren(ConstantScoreExpression::Create(M_E))));
+ ASSERT_TRUE(expression->is_constant_double());
+ EXPECT_THAT(expression->eval(DocHitInfo(), nullptr), IsOkAndHolds(Eq(1)));
+}
+
+TEST(ScoreExpressionTest, CannotSimplifyNonConstant) {
+ // 1 + non_constant = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<ScoreExpression> expression,
+ OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kPlus,
+ MakeChildren(ConstantScoreExpression::Create(1),
+ NonConstantScoreExpression::Create())));
+ ASSERT_FALSE(expression->is_constant_double());
+
+ // non_constant * non_constant = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kTimes,
+ MakeChildren(NonConstantScoreExpression::Create(),
+ NonConstantScoreExpression::Create())));
+ ASSERT_FALSE(expression->is_constant_double());
+
+ // -(non_constant) = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kNegative,
+ MakeChildren(NonConstantScoreExpression::Create())));
+ ASSERT_FALSE(expression->is_constant_double());
+
+ // pow(non_constant, 2) = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kPow,
+ MakeChildren(NonConstantScoreExpression::Create(),
+ ConstantScoreExpression::Create(2))));
+ ASSERT_FALSE(expression->is_constant_double());
+
+ // abs(non_constant) = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kAbs,
+ MakeChildren(NonConstantScoreExpression::Create())));
+ ASSERT_FALSE(expression->is_constant_double());
+
+ // log(non_constant) = non_constant
+ ICING_ASSERT_OK_AND_ASSIGN(
+ expression, MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::FunctionType::kLog,
+ MakeChildren(NonConstantScoreExpression::Create())));
+ ASSERT_FALSE(expression->is_constant_double());
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc
index 1396dcc..ea2e190 100644
--- a/icing/scoring/advanced_scoring/scoring-visitor.cc
+++ b/icing/scoring/advanced_scoring/scoring-visitor.cc
@@ -132,8 +132,8 @@ void ScoringVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) {
std::vector<std::unique_ptr<ScoreExpression>> children;
children.push_back(pop_stack());
- libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
- expression = OperatorScoreExpression::Create(
+ libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression =
+ OperatorScoreExpression::Create(
OperatorScoreExpression::OperatorType::kNegative,
std::move(children));
if (!expression.ok()) {
@@ -153,8 +153,8 @@ void ScoringVisitor::VisitNaryOperator(const NaryOperatorNode* node) {
children.push_back(pop_stack());
}
- libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
- expression = absl_ports::InvalidArgumentError(
+ libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression =
+ absl_ports::InvalidArgumentError(
absl_ports::StrCat("Unknown Nary operator: ", node->operator_text()));
if (node->operator_text() == "PLUS") {
diff --git a/icing/scoring/scored-document-hit.h b/icing/scoring/scored-document-hit.h
index 141049e..5fc2f3a 100644
--- a/icing/scoring/scored-document-hit.h
+++ b/icing/scoring/scored-document-hit.h
@@ -123,8 +123,8 @@ class JoinedScoredDocumentHit {
};
explicit JoinedScoredDocumentHit(
- double final_score, ScoredDocumentHit&& parent_scored_document_hit,
- std::vector<ScoredDocumentHit>&& child_scored_document_hits)
+ double final_score, ScoredDocumentHit parent_scored_document_hit,
+ std::vector<ScoredDocumentHit> child_scored_document_hits)
: final_score_(final_score),
parent_scored_document_hit_(std::move(parent_scored_document_hit)),
child_scored_document_hits_(std::move(child_scored_document_hits)) {}
diff --git a/icing/scoring/scorer-factory.cc b/icing/scoring/scorer-factory.cc
index 600fe6b..f75b564 100644
--- a/icing/scoring/scorer-factory.cc
+++ b/icing/scoring/scorer-factory.cc
@@ -213,8 +213,9 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
return AdvancedScorer::Create(scoring_spec, default_score, document_store,
schema_store);
case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
- ICING_LOG(WARNING)
- << "JOIN_AGGREGATE_SCORE not implemented, falling back to NoScorer";
+ // Use join aggregate score to rank. Since the aggregation score is
+ // calculated by child documents after joining (in JoinProcessor), we can
+ // simply use NoScorer for parent documents.
[[fallthrough]];
case ScoringSpecProto::RankingStrategy::NONE:
return std::make_unique<NoScorer>(default_score);
diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc
index 3add705..9e79790 100644
--- a/icing/store/document-store.cc
+++ b/icing/store/document-store.cc
@@ -222,7 +222,7 @@ libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put(
DocumentProto&& document, int32_t num_tokens,
PutDocumentStatsProto* put_document_stats) {
document.mutable_internal_fields()->set_length_in_tokens(num_tokens);
- return InternalPut(document, put_document_stats);
+ return InternalPut(std::move(document), put_document_stats);
}
DocumentStore::~DocumentStore() {
@@ -840,7 +840,7 @@ libtextclassifier3::Status DocumentStore::UpdateHeader(const Crc32& checksum) {
}
libtextclassifier3::StatusOr<DocumentId> DocumentStore::InternalPut(
- DocumentProto& document, PutDocumentStatsProto* put_document_stats) {
+ DocumentProto&& document, PutDocumentStatsProto* put_document_stats) {
std::unique_ptr<Timer> put_timer = clock_.GetNewTimer();
ICING_RETURN_IF_ERROR(document_validator_.Validate(document));
@@ -1714,7 +1714,7 @@ DocumentStore::OptimizeInto(const std::string& new_directory,
}
// Guaranteed to have a document now.
- DocumentProto document_to_keep = document_or.ValueOrDie();
+ DocumentProto document_to_keep = std::move(document_or).ValueOrDie();
libtextclassifier3::StatusOr<DocumentId> new_document_id_or;
if (document_to_keep.internal_fields().length_in_tokens() == 0) {
@@ -1729,11 +1729,12 @@ DocumentStore::OptimizeInto(const std::string& new_directory,
TokenizedDocument tokenized_document(
std::move(tokenized_document_or).ValueOrDie());
new_document_id_or = new_doc_store->Put(
- document_to_keep, tokenized_document.num_string_tokens());
+ std::move(document_to_keep), tokenized_document.num_string_tokens());
} else {
// TODO(b/144458732): Implement a more robust version of
// TC_ASSIGN_OR_RETURN that can support error logging.
- new_document_id_or = new_doc_store->InternalPut(document_to_keep);
+ new_document_id_or =
+ new_doc_store->InternalPut(std::move(document_to_keep));
}
if (!new_document_id_or.ok()) {
ICING_LOG(ERROR) << new_document_id_or.status().error_message()
diff --git a/icing/store/document-store.h b/icing/store/document-store.h
index 58977bf..bda351d 100644
--- a/icing/store/document-store.h
+++ b/icing/store/document-store.h
@@ -628,7 +628,7 @@ class DocumentStore {
libtextclassifier3::Status UpdateHeader(const Crc32& checksum);
libtextclassifier3::StatusOr<DocumentId> InternalPut(
- DocumentProto& document,
+ DocumentProto&& document,
PutDocumentStatsProto* put_document_stats = nullptr);
// Helper function to do batch deletes. Documents with the given
diff --git a/icing/testing/common-matchers.cc b/icing/testing/common-matchers.cc
new file mode 100644
index 0000000..cd4e446
--- /dev/null
+++ b/icing/testing/common-matchers.cc
@@ -0,0 +1,124 @@
+// Copyright (C) 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+ExtractTermFrequenciesResult ExtractTermFrequencies(
+ const std::unordered_map<SectionId, Hit::TermFrequency>&
+ section_ids_tf_map) {
+ ExtractTermFrequenciesResult result;
+ for (const auto& [section_id, tf] : section_ids_tf_map) {
+ result.term_frequencies[section_id] = tf;
+ result.section_mask |= UINT64_C(1) << section_id;
+ }
+ return result;
+}
+
+CheckTermFrequencyResult CheckTermFrequency(
+ const std::array<Hit::TermFrequency, kTotalNumSections>&
+ expected_term_frequencies,
+ const std::array<Hit::TermFrequency, kTotalNumSections>&
+ actual_term_frequencies) {
+ CheckTermFrequencyResult result;
+ for (SectionId section_id = 0; section_id < kTotalNumSections; ++section_id) {
+ if (expected_term_frequencies.at(section_id) !=
+ actual_term_frequencies.at(section_id)) {
+ result.term_frequencies_match = false;
+ }
+ }
+ result.actual_term_frequencies_str =
+ absl_ports::StrCat("[",
+ absl_ports::StrJoin(actual_term_frequencies, ",",
+ absl_ports::NumberFormatter()),
+ "]");
+ result.expected_term_frequencies_str =
+ absl_ports::StrCat("[",
+ absl_ports::StrJoin(expected_term_frequencies, ",",
+ absl_ports::NumberFormatter()),
+ "]");
+ return result;
+}
+
+std::string StatusCodeToString(libtextclassifier3::StatusCode code) {
+ switch (code) {
+ case libtextclassifier3::StatusCode::OK:
+ return "OK";
+ case libtextclassifier3::StatusCode::CANCELLED:
+ return "CANCELLED";
+ case libtextclassifier3::StatusCode::UNKNOWN:
+ return "UNKNOWN";
+ case libtextclassifier3::StatusCode::INVALID_ARGUMENT:
+ return "INVALID_ARGUMENT";
+ case libtextclassifier3::StatusCode::DEADLINE_EXCEEDED:
+ return "DEADLINE_EXCEEDED";
+ case libtextclassifier3::StatusCode::NOT_FOUND:
+ return "NOT_FOUND";
+ case libtextclassifier3::StatusCode::ALREADY_EXISTS:
+ return "ALREADY_EXISTS";
+ case libtextclassifier3::StatusCode::PERMISSION_DENIED:
+ return "PERMISSION_DENIED";
+ case libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED:
+ return "RESOURCE_EXHAUSTED";
+ case libtextclassifier3::StatusCode::FAILED_PRECONDITION:
+ return "FAILED_PRECONDITION";
+ case libtextclassifier3::StatusCode::ABORTED:
+ return "ABORTED";
+ case libtextclassifier3::StatusCode::OUT_OF_RANGE:
+ return "OUT_OF_RANGE";
+ case libtextclassifier3::StatusCode::UNIMPLEMENTED:
+ return "UNIMPLEMENTED";
+ case libtextclassifier3::StatusCode::INTERNAL:
+ return "INTERNAL";
+ case libtextclassifier3::StatusCode::UNAVAILABLE:
+ return "UNAVAILABLE";
+ case libtextclassifier3::StatusCode::DATA_LOSS:
+ return "DATA_LOSS";
+ case libtextclassifier3::StatusCode::UNAUTHENTICATED:
+ return "UNAUTHENTICATED";
+ default:
+ return "";
+ }
+}
+
+std::string ProtoStatusCodeToString(StatusProto::Code code) {
+ switch (code) {
+ case StatusProto::OK:
+ return "OK";
+ case StatusProto::UNKNOWN:
+ return "UNKNOWN";
+ case StatusProto::INVALID_ARGUMENT:
+ return "INVALID_ARGUMENT";
+ case StatusProto::NOT_FOUND:
+ return "NOT_FOUND";
+ case StatusProto::ALREADY_EXISTS:
+ return "ALREADY_EXISTS";
+ case StatusProto::OUT_OF_SPACE:
+ return "OUT_OF_SPACE";
+ case StatusProto::FAILED_PRECONDITION:
+ return "FAILED_PRECONDITION";
+ case StatusProto::ABORTED:
+ return "ABORTED";
+ case StatusProto::INTERNAL:
+ return "INTERNAL";
+ case StatusProto::WARNING_DATA_LOSS:
+ return "WARNING_DATA_LOSS";
+ default:
+ return "";
+ }
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h
index e090800..db7b7ef 100644
--- a/icing/testing/common-matchers.h
+++ b/icing/testing/common-matchers.h
@@ -22,11 +22,11 @@
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/status_macros.h"
-#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "icing/absl_ports/str_join.h"
#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/portable/equals-proto.h"
@@ -35,7 +35,6 @@
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
#include "icing/scoring/scored-document-hit.h"
-#include "icing/util/status-macros.h"
namespace icing {
namespace lib {
@@ -69,43 +68,85 @@ MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") {
actual.hit_section_ids_mask() == section_mask;
}
+struct ExtractTermFrequenciesResult {
+ std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies = {0};
+ SectionIdMask section_mask = kSectionIdMaskNone;
+};
+// Extracts the term frequencies represented by the section_ids_tf_map.
+// Returns:
+// - a SectionIdMask representing all sections that appears as entries in the
+// map, even if they have an entry with term_frequency==0
+// - an array representing the term frequencies for each section. Sections not
+// present in section_ids_tf_map have a term frequency of 0.
+ExtractTermFrequenciesResult ExtractTermFrequencies(
+ const std::unordered_map<SectionId, Hit::TermFrequency>&
+ section_ids_tf_map);
+
+struct CheckTermFrequencyResult {
+ std::string expected_term_frequencies_str;
+ std::string actual_term_frequencies_str;
+ bool term_frequencies_match = true;
+};
+// Checks that the term frequencies in actual_term_frequencies match those
+// specified in expected_section_ids_tf_map. If there is no entry in
+// expected_section_ids_tf_map, then it is assumed that the term frequency for
+// that section is 0.
+// Returns:
+// - a bool indicating if the term frequencies match
+// - debug strings representing the contents of the actual and expected term
+// term frequency arrays.
+CheckTermFrequencyResult CheckTermFrequency(
+ const std::array<Hit::TermFrequency, kTotalNumSections>&
+ expected_term_frequencies,
+ const std::array<Hit::TermFrequency, kTotalNumSections>&
+ actual_term_frequencies);
+
// Used to match a DocHitInfo
MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id,
section_ids_to_term_frequencies_map, "") {
const DocHitInfoTermFrequencyPair& actual = arg;
- SectionIdMask section_mask = kSectionIdMaskNone;
-
- bool term_frequency_as_expected = true;
- std::vector<Hit::TermFrequency> expected_tfs;
- std::vector<Hit::TermFrequency> actual_tfs;
- for (auto itr = section_ids_to_term_frequencies_map.begin();
- itr != section_ids_to_term_frequencies_map.end(); itr++) {
- SectionId section_id = itr->first;
- section_mask |= UINT64_C(1) << section_id;
- expected_tfs.push_back(itr->second);
- actual_tfs.push_back(actual.hit_term_frequency(section_id));
- if (actual.hit_term_frequency(section_id) != itr->second) {
- term_frequency_as_expected = false;
- }
+ std::array<Hit::TermFrequency, kTotalNumSections> actual_tf_array;
+ for (SectionId section_id = 0; section_id < kTotalNumSections; ++section_id) {
+ actual_tf_array[section_id] = actual.hit_term_frequency(section_id);
}
- std::string actual_term_frequencies = absl_ports::StrCat(
- "[", absl_ports::StrJoin(actual_tfs, ",", absl_ports::NumberFormatter()),
- "]");
- std::string expected_term_frequencies = absl_ports::StrCat(
- "[",
- absl_ports::StrJoin(expected_tfs, ",", absl_ports::NumberFormatter()),
- "]");
+ ExtractTermFrequenciesResult expected =
+ ExtractTermFrequencies(section_ids_to_term_frequencies_map);
+ CheckTermFrequencyResult check_tf_result =
+ CheckTermFrequency(expected.term_frequencies, actual_tf_array);
+
*result_listener << IcingStringUtil::StringPrintf(
"(actual is {document_id=%d, section_mask=%" PRIu64
", term_frequencies=%s}, but expected was "
"{document_id=%d, section_mask=%" PRIu64 ", term_frequencies=%s}.)",
actual.doc_hit_info().document_id(),
actual.doc_hit_info().hit_section_ids_mask(),
- actual_term_frequencies.c_str(), document_id, section_mask,
- expected_term_frequencies.c_str());
+ check_tf_result.actual_term_frequencies_str.c_str(), document_id,
+ expected.section_mask,
+ check_tf_result.expected_term_frequencies_str.c_str());
return actual.doc_hit_info().document_id() == document_id &&
- actual.doc_hit_info().hit_section_ids_mask() == section_mask &&
- term_frequency_as_expected;
+ actual.doc_hit_info().hit_section_ids_mask() ==
+ expected.section_mask &&
+ check_tf_result.term_frequencies_match;
+}
+
+MATCHER_P2(EqualsTermMatchInfo, term, section_ids_to_term_frequencies_map, "") {
+ const TermMatchInfo& actual = arg;
+ std::string term_str(term);
+ ExtractTermFrequenciesResult expected =
+ ExtractTermFrequencies(section_ids_to_term_frequencies_map);
+ CheckTermFrequencyResult check_tf_result =
+ CheckTermFrequency(expected.term_frequencies, actual.term_frequencies);
+ *result_listener << IcingStringUtil::StringPrintf(
+ "(actual is {term=%s, section_mask=%" PRIu64
+ ", term_frequencies=%s}, but expected was "
+ "{term=%s, section_mask=%" PRIu64 ", term_frequencies=%s}.)",
+ actual.term.data(), actual.section_ids_mask,
+ check_tf_result.actual_term_frequencies_str.c_str(), term_str.data(),
+ expected.section_mask,
+ check_tf_result.expected_term_frequencies_str.c_str());
+ return actual.term == term &&
+ actual.section_ids_mask == expected.section_mask &&
+ check_tf_result.term_frequencies_match;
}
class ScoredDocumentHitFormatter {
@@ -337,73 +378,9 @@ MATCHER_P(EqualsSetSchemaResult, expected, "") {
return false;
}
-inline std::string StatusCodeToString(libtextclassifier3::StatusCode code) {
- switch (code) {
- case libtextclassifier3::StatusCode::OK:
- return "OK";
- case libtextclassifier3::StatusCode::CANCELLED:
- return "CANCELLED";
- case libtextclassifier3::StatusCode::UNKNOWN:
- return "UNKNOWN";
- case libtextclassifier3::StatusCode::INVALID_ARGUMENT:
- return "INVALID_ARGUMENT";
- case libtextclassifier3::StatusCode::DEADLINE_EXCEEDED:
- return "DEADLINE_EXCEEDED";
- case libtextclassifier3::StatusCode::NOT_FOUND:
- return "NOT_FOUND";
- case libtextclassifier3::StatusCode::ALREADY_EXISTS:
- return "ALREADY_EXISTS";
- case libtextclassifier3::StatusCode::PERMISSION_DENIED:
- return "PERMISSION_DENIED";
- case libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED:
- return "RESOURCE_EXHAUSTED";
- case libtextclassifier3::StatusCode::FAILED_PRECONDITION:
- return "FAILED_PRECONDITION";
- case libtextclassifier3::StatusCode::ABORTED:
- return "ABORTED";
- case libtextclassifier3::StatusCode::OUT_OF_RANGE:
- return "OUT_OF_RANGE";
- case libtextclassifier3::StatusCode::UNIMPLEMENTED:
- return "UNIMPLEMENTED";
- case libtextclassifier3::StatusCode::INTERNAL:
- return "INTERNAL";
- case libtextclassifier3::StatusCode::UNAVAILABLE:
- return "UNAVAILABLE";
- case libtextclassifier3::StatusCode::DATA_LOSS:
- return "DATA_LOSS";
- case libtextclassifier3::StatusCode::UNAUTHENTICATED:
- return "UNAUTHENTICATED";
- default:
- return "";
- }
-}
+std::string StatusCodeToString(libtextclassifier3::StatusCode code);
-inline std::string ProtoStatusCodeToString(StatusProto::Code code) {
- switch (code) {
- case StatusProto::OK:
- return "OK";
- case StatusProto::UNKNOWN:
- return "UNKNOWN";
- case StatusProto::INVALID_ARGUMENT:
- return "INVALID_ARGUMENT";
- case StatusProto::NOT_FOUND:
- return "NOT_FOUND";
- case StatusProto::ALREADY_EXISTS:
- return "ALREADY_EXISTS";
- case StatusProto::OUT_OF_SPACE:
- return "OUT_OF_SPACE";
- case StatusProto::FAILED_PRECONDITION:
- return "FAILED_PRECONDITION";
- case StatusProto::ABORTED:
- return "ABORTED";
- case StatusProto::INTERNAL:
- return "INTERNAL";
- case StatusProto::WARNING_DATA_LOSS:
- return "WARNING_DATA_LOSS";
- default:
- return "";
- }
-}
+std::string ProtoStatusCodeToString(StatusProto::Code code);
MATCHER(IsOk, "") {
libtextclassifier3::StatusAdapter adapter(arg);
diff --git a/icing/util/logging.h b/icing/util/logging.h
index 7742302..23280dc 100644
--- a/icing/util/logging.h
+++ b/icing/util/logging.h
@@ -131,13 +131,29 @@ class LogMessage {
inline constexpr char kIcingLoggingTag[] = "AppSearchIcing";
-#define ICING_VLOG(verbose_level) \
- ::icing::lib::LogMessage(::icing::lib::LogSeverity::VERBOSE, verbose_level, \
- __FILE__, __LINE__) \
- .stream()
-#define ICING_LOG(severity) \
- ::icing::lib::LogMessage(::icing::lib::LogSeverity::severity, \
- /*verbosity=*/0, __FILE__, __LINE__) \
+// Define consts to make it easier to refer to log severities in code.
+constexpr ::icing::lib::LogSeverity::Code VERBOSE =
+ ::icing::lib::LogSeverity::VERBOSE;
+
+constexpr ::icing::lib::LogSeverity::Code DBG = ::icing::lib::LogSeverity::DBG;
+
+constexpr ::icing::lib::LogSeverity::Code INFO =
+ ::icing::lib::LogSeverity::INFO;
+
+constexpr ::icing::lib::LogSeverity::Code WARNING =
+ ::icing::lib::LogSeverity::WARNING;
+
+constexpr ::icing::lib::LogSeverity::Code ERROR =
+ ::icing::lib::LogSeverity::ERROR;
+
+constexpr ::icing::lib::LogSeverity::Code FATAL =
+ ::icing::lib::LogSeverity::FATAL;
+
+#define ICING_VLOG(verbose_level) \
+ ::icing::lib::LogMessage(VERBOSE, verbose_level, __FILE__, __LINE__).stream()
+
+#define ICING_LOG(severity) \
+ ::icing::lib::LogMessage(severity, /*verbosity=*/0, __FILE__, __LINE__) \
.stream()
} // namespace lib
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index b2435d0..c9e2b1d 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -505,13 +505,16 @@ message JoinSpecProto {
// taken on it. If JOIN_AGGREGATE_SCORE is used in the base SearchSpecProto,
// the COUNT value will rank entity documents based on the number of child
// documents.
- enum AggregationScore {
- UNDEFINED = 0;
- COUNT = 1;
- MIN = 2;
- AVG = 3;
- MAX = 4;
- SUM = 5;
+ message AggregationScoringStrategy {
+ enum Code {
+ NONE = 0; // No aggregation strategy for child documents and use parent
+ // document score.
+ COUNT = 1;
+ MIN = 2;
+ AVG = 3;
+ MAX = 4;
+ SUM = 5;
+ }
}
- optional AggregationScore aggregation_score_strategy = 5 [default = COUNT];
+ optional AggregationScoringStrategy.Code aggregation_scoring_strategy = 5;
}
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index 00194ce..a4f3a30 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=496703735)
+set(synced_AOSP_CL_number=500254546)