diff options
Diffstat (limited to 'icing/join')
-rw-r--r-- | icing/join/aggregation-scorer.cc (renamed from icing/join/aggregate-scorer.cc) | 86 | ||||
-rw-r--r-- | icing/join/aggregation-scorer.h (renamed from icing/join/aggregate-scorer.h) | 12 | ||||
-rw-r--r-- | icing/join/aggregation-scorer_test.cc | 215 | ||||
-rw-r--r-- | icing/join/join-processor.cc | 23 | ||||
-rw-r--r-- | icing/join/join-processor_test.cc | 624 |
5 files changed, 912 insertions, 48 deletions
diff --git a/icing/join/aggregate-scorer.cc b/icing/join/aggregation-scorer.cc index 7b17482..3dee3dd 100644 --- a/icing/join/aggregate-scorer.cc +++ b/icing/join/aggregation-scorer.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/join/aggregate-scorer.h" +#include "icing/join/aggregation-scorer.h" #include <algorithm> #include <memory> @@ -25,24 +25,26 @@ namespace icing { namespace lib { -class MinAggregateScorer : public AggregateScorer { +class CountAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return std::min_element(children.begin(), children.end(), - [](const ScoredDocumentHit& lhs, - const ScoredDocumentHit& rhs) -> bool { - return lhs.score() < rhs.score(); - }) - ->score(); + return children.size(); } }; -class MaxAggregateScorer : public AggregateScorer { +class MinAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return std::max_element(children.begin(), children.end(), + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } + return std::min_element(children.begin(), children.end(), [](const ScoredDocumentHit& lhs, const ScoredDocumentHit& rhs) -> bool { return lhs.score() < rhs.score(); @@ -51,41 +53,59 @@ class MaxAggregateScorer : public AggregateScorer { } }; -class AverageAggregateScorer : public AggregateScorer { +class AverageAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - if (children.empty()) return 0.0; + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } return std::reduce( children.begin(), children.end(), 0.0, - [](const double& prev, const ScoredDocumentHit& item) -> double { + [](double prev, const ScoredDocumentHit& item) -> double { return prev + item.score(); }) / children.size(); } }; -class CountAggregateScorer : public AggregateScorer { +class MaxAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { - return children.size(); + if (children.empty()) { + // Return 0 if there is no child document. + // For non-empty children with negative scores, they are considered "worse + // than" 0, so it is correct to return 0 for empty children to assign it a + // rank higher than non-empty children with negative scores. + return 0.0; + } + return std::max_element(children.begin(), children.end(), + [](const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) -> bool { + return lhs.score() < rhs.score(); + }) + ->score(); } }; -class SumAggregateScorer : public AggregateScorer { +class SumAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { return std::reduce( children.begin(), children.end(), 0.0, - [](const double& prev, const ScoredDocumentHit& item) -> double { + [](double prev, const ScoredDocumentHit& item) -> double { return prev + item.score(); }); } }; -class DefaultAggregateScorer : public AggregateScorer { +class DefaultAggregationScorer : public AggregationScorer { public: double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) override { @@ -93,23 +113,25 @@ class DefaultAggregateScorer : public AggregateScorer { } }; -std::unique_ptr<AggregateScorer> AggregateScorer::Create( +std::unique_ptr<AggregationScorer> AggregationScorer::Create( const JoinSpecProto& join_spec) { - switch (join_spec.aggregation_score_strategy()) { - case JoinSpecProto_AggregationScore_MIN: - return std::make_unique<MinAggregateScorer>(); - case JoinSpecProto_AggregationScore_MAX: - return std::make_unique<MaxAggregateScorer>(); - case JoinSpecProto_AggregationScore_COUNT: - return std::make_unique<CountAggregateScorer>(); - case JoinSpecProto_AggregationScore_AVG: - return std::make_unique<AverageAggregateScorer>(); - case JoinSpecProto_AggregationScore_SUM: - return std::make_unique<SumAggregateScorer>(); - case JoinSpecProto_AggregationScore_UNDEFINED: + switch (join_spec.aggregation_scoring_strategy()) { + case JoinSpecProto::AggregationScoringStrategy::COUNT: + return std::make_unique<CountAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::MIN: + return std::make_unique<MinAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::AVG: + return std::make_unique<AverageAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::MAX: + return std::make_unique<MaxAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::SUM: + return std::make_unique<SumAggregationScorer>(); + case JoinSpecProto::AggregationScoringStrategy::NONE: + // No aggregation strategy means using parent document score, so fall + // through to return DefaultAggregationScorer. [[fallthrough]]; default: - return std::make_unique<DefaultAggregateScorer>(); + return std::make_unique<DefaultAggregationScorer>(); } } diff --git a/icing/join/aggregate-scorer.h b/icing/join/aggregation-scorer.h index 27731b9..3d38cf0 100644 --- a/icing/join/aggregate-scorer.h +++ b/icing/join/aggregation-scorer.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef ICING_JOIN_AGGREGATE_SCORER_H_ -#define ICING_JOIN_AGGREGATE_SCORER_H_ +#ifndef ICING_JOIN_AGGREGATION_SCORER_H_ +#define ICING_JOIN_AGGREGATION_SCORER_H_ #include <memory> #include <vector> @@ -24,12 +24,12 @@ namespace icing { namespace lib { -class AggregateScorer { +class AggregationScorer { public: - static std::unique_ptr<AggregateScorer> Create( + static std::unique_ptr<AggregationScorer> Create( const JoinSpecProto& join_spec); - virtual ~AggregateScorer() = default; + virtual ~AggregationScorer() = default; virtual double GetScore(const ScoredDocumentHit& parent, const std::vector<ScoredDocumentHit>& children) = 0; @@ -38,4 +38,4 @@ class AggregateScorer { } // namespace lib } // namespace icing -#endif // ICING_JOIN_AGGREGATE_SCORER_H_ +#endif // ICING_JOIN_AGGREGATION_SCORER_H_ diff --git a/icing/join/aggregation-scorer_test.cc b/icing/join/aggregation-scorer_test.cc new file mode 100644 index 0000000..19a7239 --- /dev/null +++ b/icing/join/aggregation-scorer_test.cc @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/join/aggregation-scorer.h" + +#include <algorithm> +#include <iterator> +#include <memory> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/search.pb.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::DoubleEq; + +struct AggregationScorerTestParam { + double ans; + JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy; + double parent_score; + std::vector<double> child_scores; + + explicit AggregationScorerTestParam( + double ans_in, + JoinSpecProto::AggregationScoringStrategy::Code scoring_strategy_in, + double parent_score_in, std::vector<double> child_scores_in) + : ans(ans_in), + scoring_strategy(scoring_strategy_in), + parent_score(std::move(parent_score_in)), + child_scores(std::move(child_scores_in)) {} +}; + +class AggregationScorerTest + : public ::testing::TestWithParam<AggregationScorerTestParam> {}; + +TEST_P(AggregationScorerTest, GetScore) { + static constexpr DocumentId kDefaultDocumentId = 0; + + const AggregationScorerTestParam& param = GetParam(); + // Test AggregationScorer by creating some ScoredDocumentHits for parent and + // child documents. DocumentId and SectionIdMask won't affect the aggregation + // score calculation, so just simply set default values. + // Parent document + ScoredDocumentHit parent_scored_document_hit( + kDefaultDocumentId, kSectionIdMaskNone, param.parent_score); + // Child documents + std::vector<ScoredDocumentHit> child_scored_document_hits; + child_scored_document_hits.reserve(param.child_scores.size()); + std::transform(param.child_scores.cbegin(), param.child_scores.cend(), + std::back_inserter(child_scored_document_hits), + [](double score) -> ScoredDocumentHit { + return ScoredDocumentHit(kDefaultDocumentId, + kSectionIdMaskNone, score); + }); + + JoinSpecProto join_spec; + join_spec.set_aggregation_scoring_strategy(param.scoring_strategy); + std::unique_ptr<AggregationScorer> scorer = + AggregationScorer::Create(join_spec); + EXPECT_THAT( + scorer->GetScore(parent_scored_document_hit, child_scored_document_hits), + DoubleEq(param.ans)); +} + +INSTANTIATE_TEST_SUITE_P( + CountAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/5, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::COUNT, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + MinAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/1, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MIN, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + AverageAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/4.6, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::AVG, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + MaxAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/8, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/50, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::MAX, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + SumAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/23, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child. + AggregationScorerTestParam( + /*ans_in=*/123, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/0, JoinSpecProto::AggregationScoringStrategy::SUM, + /*parent_score_in=*/0, + /*child_scores_in=*/{}))); + +INSTANTIATE_TEST_SUITE_P( + DefaultAggregationScorerTest, AggregationScorerTest, + testing::Values( + // General case. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{8, 3, 1, 4, 7}), + // Only one child, greater than parent. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{123}), + // Only one child, smaller than parent. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{50}), + // No child. + AggregationScorerTestParam( + /*ans_in=*/98, JoinSpecProto::AggregationScoringStrategy::NONE, + /*parent_score_in=*/98, + /*child_scores_in=*/{}))); + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc index 9b17396..7700397 100644 --- a/icing/join/join-processor.cc +++ b/icing/join/join-processor.cc @@ -23,6 +23,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/join/aggregation-scorer.h" #include "icing/join/qualified-id.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" @@ -45,8 +46,6 @@ JoinProcessor::Join( ScoringSpecProto::Order::DESC)); // TODO(b/256022027): - // - Aggregate scoring - // - Calculate the aggregated score if strategy is AGGREGATION_SCORING. // - Optimization // - Cache property to speed up property retrieval. // - If there is no cache, then we still have the flexibility to fetch it @@ -93,6 +92,9 @@ JoinProcessor::Join( } } + std::unique_ptr<AggregationScorer> aggregation_scorer = + AggregationScorer::Create(join_spec); + std::vector<JoinedScoredDocumentHit> joined_scored_document_hits; joined_scored_document_hits.reserve(parent_scored_document_hits.size()); @@ -110,14 +112,15 @@ JoinProcessor::Join( "Parent property expression must be ", kQualifiedIdExpr)); } - // TODO(b/256022027): Derive final score from - // parent_id_to_child_map[parent_doc_id] and - // join_spec.aggregation_score_strategy() - double final_score = parent.score(); - joined_scored_document_hits.emplace_back( - final_score, std::move(parent), - std::vector<ScoredDocumentHit>( - std::move(parent_id_to_child_map[parent_doc_id]))); + std::vector<ScoredDocumentHit> children; + if (auto iter = parent_id_to_child_map.find(parent_doc_id); + iter != parent_id_to_child_map.end()) { + children = std::move(iter->second); + } + + double final_score = aggregation_scorer->GetScore(parent, children); + joined_scored_document_hits.emplace_back(final_score, std::move(parent), + std::move(children)); } return joined_scored_document_hits; diff --git a/icing/join/join-processor_test.cc b/icing/join/join-processor_test.cc new file mode 100644 index 0000000..70eaf3f --- /dev/null +++ b/icing/join/join-processor_test.cc @@ -0,0 +1,624 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/join/join-processor.h" + +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; + +class JoinProcessorTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/icing_join_processor_test"; + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("Person").AddProperty( + PropertyConfigBuilder() + .SetName("Name") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType( + SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeJoinableString( + JOINABLE_VALUE_TYPE_QUALIFIED_ID) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + doc_store_ = std::move(create_result.document_store); + } + + void TearDown() override { + doc_store_.reset(); + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + Filesystem filesystem_; + std::string test_dir_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> doc_store_; + FakeClock fake_clock_; +}; + +TEST_F(JoinProcessorTest, JoinByQualifiedId) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/3.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/4.0); + ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone, + /*score=*/5.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit2, scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit5, scored_doc_hit4, scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit4})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/2.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit5, scored_doc_hit3})))); +} + +TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithoutJoiningProperty) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/6.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit2, + scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Since Email2 doesn't have "sender" property, it should be ignored. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, ShouldIgnoreChildDocumentsWithInvalidQualifiedId) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty( + "sender", + "pkg$db/namespace#person2") // qualified id is invalid since + // person2 doesn't exist. + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", + R"(pkg$db/namespace\#person1)") // invalid format + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/0.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit2, scored_doc_hit3, scored_doc_hit4}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Email 2 and email 3 (document id 3 and 4) contain invalid qualified ids. + // Join processor should ignore them. + EXPECT_THAT(joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, LeftJoinShouldReturnParentWithoutChildren) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/3.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit2, scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit3}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Person1 has no child documents, but left join should also include it. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit3})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/0.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{})))); +} + +TEST_F(JoinProcessorTest, ShouldSortChildDocumentsByRankingStrategy) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email3)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/2.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/3.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit2, scored_doc_hit3, scored_doc_hit4}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Child documents should be sorted according to the (nested) ranking + // strategy. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/3.0, /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit3, scored_doc_hit4, scored_doc_hit2})))); +} + +TEST_F(JoinProcessorTest, + ShouldTruncateByRankingStrategyIfExceedingMaxJoinedChildCount) { + DocumentProto person1 = DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("Name", "Alice") + .Build(); + DocumentProto person2 = DocumentBuilder() + .SetKey(R"(pkg$db/name#space\\)", "person2") + .SetSchema("Person") + .AddStringProperty("Name", "Bob") + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("sender", "pkg$db/namespace#person1") + .Build(); + DocumentProto email4 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email4") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 4") + .AddStringProperty("sender", + R"(pkg$db/name\#space\\\\#person2)") // escaped + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(person1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, doc_store_->Put(person2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, doc_store_->Put(email1)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, doc_store_->Put(email2)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, doc_store_->Put(email3)); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, doc_store_->Put(email4)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit2(document_id2, kSectionIdMaskNone, + /*score=*/0.0); + ScoredDocumentHit scored_doc_hit3(document_id3, kSectionIdMaskNone, + /*score=*/2.0); + ScoredDocumentHit scored_doc_hit4(document_id4, kSectionIdMaskNone, + /*score=*/5.0); + ScoredDocumentHit scored_doc_hit5(document_id5, kSectionIdMaskNone, + /*score=*/3.0); + ScoredDocumentHit scored_doc_hit6(document_id6, kSectionIdMaskNone, + /*score=*/1.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1, scored_doc_hit2}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = { + scored_doc_hit3, scored_doc_hit4, scored_doc_hit5, scored_doc_hit6}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(2); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + // Since we set max_joind_child_count as 2 and use DESC as the (nested) + // ranking strategy, parent document with # of child documents more than 2 + // should only keep 2 child documents with higher scores and the rest should + // be truncated. + EXPECT_THAT( + joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/2.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/ + {scored_doc_hit4, scored_doc_hit5})), + EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit2, + /*child_scored_document_hits=*/{scored_doc_hit6})))); +} + +TEST_F(JoinProcessorTest, ShouldAllowSelfJoining) { + DocumentProto email1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("sender", "pkg$db/namespace#email1") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, doc_store_->Put(email1)); + + ScoredDocumentHit scored_doc_hit1(document_id1, kSectionIdMaskNone, + /*score=*/0.0); + + // Parent ScoredDocumentHits: all Person documents + std::vector<ScoredDocumentHit> parent_scored_document_hits = { + scored_doc_hit1}; + + // Child ScoredDocumentHits: all Email documents + std::vector<ScoredDocumentHit> child_scored_document_hits = {scored_doc_hit1}; + + JoinSpecProto join_spec; + join_spec.set_max_joined_child_count(100); + join_spec.set_parent_property_expression( + std::string(JoinProcessor::kQualifiedIdExpr)); + join_spec.set_child_property_expression("sender"); + join_spec.set_aggregation_scoring_strategy( + JoinSpecProto::AggregationScoringStrategy::COUNT); + join_spec.mutable_nested_spec()->mutable_scoring_spec()->set_order_by( + ScoringSpecProto::Order::DESC); + + JoinProcessor join_processor(doc_store_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<JoinedScoredDocumentHit> joined_result_document_hits, + join_processor.Join(join_spec, std::move(parent_scored_document_hits), + std::move(child_scored_document_hits))); + EXPECT_THAT(joined_result_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(JoinedScoredDocumentHit( + /*final_score=*/1.0, + /*parent_scored_document_hit=*/scored_doc_hit1, + /*child_scored_document_hits=*/{scored_doc_hit1})))); +} + +// TODO(b/256022027): add unit tests for non-joinable property. If joinable +// value type is unset, then qualifed id join should not +// include the child document even if it contains a valid +// qualified id string. + +} // namespace + +} // namespace lib +} // namespace icing |