diff options
author | Tim Barron <tjbarron@google.com> | 2021-10-21 16:01:05 -0700 |
---|---|---|
committer | Tim Barron <tjbarron@google.com> | 2021-10-21 16:01:05 -0700 |
commit | da1b8986e7c873efa45529b8adc4a32490eb9c3c (patch) | |
tree | 1cc9dbe185e88e71c7c82ede8ba02578a36ef78f /icing/scoring | |
parent | 8555f998fccca3aea3f6f67d44fce04775ddea97 (diff) | |
download | icing-da1b8986e7c873efa45529b8adc4a32490eb9c3c.tar.gz |
Sync from upstream.
Descriptions:
================
Replace refs to c lib headers w/ c++ stdlib equivalents.
================
Update IDF component of BM25F Calculator in IcingLib
================
Expose QuerySuggestions API.
================
Change the tokenizer used in QuerySuggest.
================
Add SectionWeights API to Icing.
================
Apply SectionWeights to BM25F Scoring.
================
Replaces uses of u_strTo/FromUTF32 w/ u_strTo/FromUTF8.
Bug: 152934343
Bug: 202308641
Bug: 203700301
Change-Id: Ic884a84e5ff4c9c04b2cd6dd1fce90765aa4446e
Diffstat (limited to 'icing/scoring')
-rw-r--r-- | icing/scoring/bm25f-calculator.cc | 51 | ||||
-rw-r--r-- | icing/scoring/bm25f-calculator.h | 32 | ||||
-rw-r--r-- | icing/scoring/score-and-rank_benchmark.cc | 125 | ||||
-rw-r--r-- | icing/scoring/scorer.cc | 19 | ||||
-rw-r--r-- | icing/scoring/scorer.h | 4 | ||||
-rw-r--r-- | icing/scoring/scorer_test.cc | 191 | ||||
-rw-r--r-- | icing/scoring/scoring-processor.cc | 9 | ||||
-rw-r--r-- | icing/scoring/scoring-processor.h | 4 | ||||
-rw-r--r-- | icing/scoring/scoring-processor_test.cc | 377 | ||||
-rw-r--r-- | icing/scoring/section-weights.cc | 146 | ||||
-rw-r--r-- | icing/scoring/section-weights.h | 95 | ||||
-rw-r--r-- | icing/scoring/section-weights_test.cc | 386 |
12 files changed, 1308 insertions, 131 deletions
diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 4822d7f..28d385e 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -26,6 +26,7 @@ #include "icing/store/corpus-associated-scoring-data.h" #include "icing/store/corpus-id.h" #include "icing/store/document-associated-score-data.h" +#include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" namespace icing { @@ -42,8 +43,11 @@ constexpr float k1_ = 1.2f; constexpr float b_ = 0.7f; // TODO(b/158603900): add tests for Bm25fCalculator -Bm25fCalculator::Bm25fCalculator(const DocumentStore* document_store) - : document_store_(document_store) {} +Bm25fCalculator::Bm25fCalculator( + const DocumentStore* document_store, + std::unique_ptr<SectionWeights> section_weights) + : document_store_(document_store), + section_weights_(std::move(section_weights)) {} // During initialization, Bm25fCalculator iterates through // hit-iterators for each query term to pre-compute n(q_i) for each corpus under @@ -121,9 +125,9 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, // Compute inverse document frequency (IDF) weight for query term in the given // corpus, and cache it in the map. // -// N - n(q_i) + 0.5 -// IDF(q_i) = log(1 + ------------------) -// n(q_i) + 0.5 +// N - n(q_i) + 0.5 +// IDF(q_i) = ln(1 + ------------------) +// n(q_i) + 0.5 // // where N is the number of documents in the corpus, and n(q_i) is the number // of documents in the corpus containing the query term q_i. @@ -149,7 +153,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, uint32_t num_docs = csdata.num_docs(); uint32_t nqi = corpus_nqi_map_[corpus_term_info.value]; float idf = - nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi - 0.5f)) : 0.0f; + nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); ICING_VLOG(1) << IcingStringUtil::StringPrintf( "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, @@ -158,6 +162,11 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, } // Get per corpus average document length and cache the result in the map. +// The average doc length is calculated as: +// +// total_tokens_in_corpus +// Avg Doc Length = ------------------------- +// num_docs_in_corpus + 1 float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { auto iter = corpus_avgdl_map_.find(corpus_id); if (iter != corpus_avgdl_map_.end()) { @@ -191,8 +200,8 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( const DocumentAssociatedScoreData& data) { uint32_t dl = data.length_in_tokens(); float avgdl = GetCorpusAvgDocLength(data.corpus_id()); - float f_q = - ComputeTermFrequencyForMatchedSections(data.corpus_id(), term_match_info); + float f_q = ComputeTermFrequencyForMatchedSections( + data.corpus_id(), term_match_info, hit_info.document_id()); float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); @@ -202,23 +211,41 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( return normalized_tf; } -// Note: once we support section weights, we should update this function to -// compute the weighted term frequency. float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo& term_match_info) const { + CorpusId corpus_id, const TermMatchInfo& term_match_info, + DocumentId document_id) const { float sum = 0.0f; SectionIdMask sections = term_match_info.section_ids_mask; + SchemaTypeId schema_type_id = GetSchemaTypeId(document_id); + while (sections != 0) { SectionId section_id = __builtin_ctz(sections); sections &= ~(1u << section_id); Hit::TermFrequency tf = term_match_info.term_frequencies[section_id]; + double weighted_tf = tf * section_weights_->GetNormalizedSectionWeight( + schema_type_id, section_id); if (tf != Hit::kNoTermFrequency) { - sum += tf; + sum += weighted_tf; } } return sum; } +SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { + auto filter_data_or = document_store_->GetDocumentFilterData(document_id); + if (!filter_data_or.ok()) { + // This should never happen. The only failure case for + // GetDocumentFilterData is if the document_id is outside of the range of + // allocated document_ids, which shouldn't be possible since we're getting + // this document_id from the posting lists. + ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( + "No document filter data for document [%d]", document_id); + return kInvalidSchemaTypeId; + } + DocumentFilterData data = filter_data_or.ValueOrDie(); + return data.schema_type_id(); +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/bm25f-calculator.h b/icing/scoring/bm25f-calculator.h index 91b4f24..05009d8 100644 --- a/icing/scoring/bm25f-calculator.h +++ b/icing/scoring/bm25f-calculator.h @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/legacy/index/icing-bit-util.h" +#include "icing/scoring/section-weights.h" #include "icing/store/corpus-id.h" #include "icing/store/document-store.h" @@ -62,7 +63,8 @@ namespace lib { // see: glossary/bm25 class Bm25fCalculator { public: - explicit Bm25fCalculator(const DocumentStore *document_store_); + explicit Bm25fCalculator(const DocumentStore *document_store_, + std::unique_ptr<SectionWeights> section_weights_); // Precompute and cache statistics relevant to BM25F. // Populates term_id_map_ and corpus_nqi_map_ for use while scoring other @@ -108,18 +110,43 @@ class Bm25fCalculator { } }; + // Returns idf weight for the term and provided corpus. float GetCorpusIdfWeightForTerm(std::string_view term, CorpusId corpus_id); + + // Returns the average document length for the corpus. The average is + // calculated as the sum of tokens in the corpus' documents over the total + // number of documents plus one. float GetCorpusAvgDocLength(CorpusId corpus_id); + + // Returns the normalized term frequency for the term match and document hit. + // This normalizes the term frequency by applying smoothing parameters and + // factoring document length. float ComputedNormalizedTermFrequency( const TermMatchInfo &term_match_info, const DocHitInfo &hit_info, const DocumentAssociatedScoreData &data); + + // Returns the weighted term frequency for the term match and document. For + // each section the term is present, we scale the term frequency by its + // section weight. We return the sum of the weighted term frequencies over all + // sections. float ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo &term_match_info) const; + CorpusId corpus_id, const TermMatchInfo &term_match_info, + DocumentId document_id) const; + // Returns the schema type id for the document by retrieving it from the + // DocumentFilterData. + SchemaTypeId GetSchemaTypeId(DocumentId document_id) const; + + // Clears cached scoring data and prepares the calculator for a new scoring + // run. void Clear(); const DocumentStore *document_store_; // Does not own. + // Used for accessing normalized section weights when computing the weighted + // term frequency. + std::unique_ptr<SectionWeights> section_weights_; + // Map from query term to compact term ID. // Necessary as a key to the other maps. // The use of the string_view as key here means that the query_term_iterators @@ -130,7 +157,6 @@ class Bm25fCalculator { // Necessary to calculate the normalized term frequency. // This information is cached in the DocumentStore::CorpusScoreCache std::unordered_map<CorpusId, float> corpus_avgdl_map_; - // Map from <corpus ID, term ID> to number of documents containing term q_i, // called n(q_i). // Necessary to calculate IDF(q_i) (inverse document frequency). diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index e940e98..cc1d995 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -117,7 +117,8 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -220,7 +221,8 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) { ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -322,7 +324,8 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -390,6 +393,122 @@ BENCHMARK(BM_ScoreAndRankDocumentHitsNoScoring) ->ArgPair(10000, 18000) ->ArgPair(10000, 20000); +void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { + const std::string base_dir = GetTestTempDir() + "/score_and_rank_benchmark"; + const std::string document_store_dir = base_dir + "/document_store"; + const std::string schema_store_dir = base_dir + "/schema_store"; + + // Creates file directories + Filesystem filesystem; + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); + filesystem.CreateDirectoryRecursively(document_store_dir.c_str()); + filesystem.CreateDirectoryRecursively(schema_store_dir.c_str()); + + Clock clock; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SchemaStore> schema_store, + SchemaStore::Create(&filesystem, base_dir, &clock)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem, document_store_dir, &clock, + schema_store.get())); + std::unique_ptr<DocumentStore> document_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK(schema_store->SetSchema(CreateSchemaWithEmailType())); + + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); + + int num_to_score = state.range(0); + int num_of_documents = state.range(1); + + std::mt19937 random_generator; + std::uniform_int_distribution<int> distribution( + 1, std::numeric_limits<int>::max()); + + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + + // Puts documents into document store + std::vector<DocHitInfo> doc_hit_infos; + for (int i = 0; i < num_of_documents; i++) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store->Put(CreateEmailDocument( + /*id=*/i, /*document_score=*/1, + /*creation_timestamp_ms=*/1), + /*num_tokens=*/10)); + DocHitInfo doc_hit = DocHitInfo(document_id, section_id_mask); + // Set five matches for term "foo" for each document hit. + doc_hit.UpdateSection(section_id, /*hit_term_frequency=*/5); + doc_hit_infos.push_back(doc_hit); + } + + ScoredDocumentHitComparator scored_document_hit_comparator( + /*is_descending=*/true); + + for (auto _ : state) { + // Creates a dummy DocHitInfoIterator with results, we need to pause the + // timer here so that the cost of copying test data is not included. + state.PauseTiming(); + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Create a query term iterator that assigns the document hits to term + // "foo". + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + state.ResumeTiming(); + + std::vector<ScoredDocumentHit> scored_document_hits = + scoring_processor->Score(std::move(doc_hit_info_iterator), num_to_score, + &query_term_iterators); + + BuildHeapInPlace(&scored_document_hits, scored_document_hit_comparator); + // Ranks and gets the first page, 20 is a common page size + std::vector<ScoredDocumentHit> results = + PopTopResultsFromHeap(&scored_document_hits, /*num_results=*/20, + scored_document_hit_comparator); + } + + // Clean up + document_store.reset(); + schema_store.reset(); + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); +} +BENCHMARK(BM_ScoreAndRankDocumentHitsByRelevanceScoring) + // num_to_score, num_of_documents in document store + ->ArgPair(1000, 30000) + ->ArgPair(3000, 30000) + ->ArgPair(5000, 30000) + ->ArgPair(7000, 30000) + ->ArgPair(9000, 30000) + ->ArgPair(11000, 30000) + ->ArgPair(13000, 30000) + ->ArgPair(15000, 30000) + ->ArgPair(17000, 30000) + ->ArgPair(19000, 30000) + ->ArgPair(21000, 30000) + ->ArgPair(23000, 30000) + ->ArgPair(25000, 30000) + ->ArgPair(27000, 30000) + ->ArgPair(29000, 30000) + // Starting from this line, we're trying to see if num_of_documents affects + // performance + ->ArgPair(10000, 10000) + ->ArgPair(10000, 12000) + ->ArgPair(10000, 14000) + ->ArgPair(10000, 16000) + ->ArgPair(10000, 18000) + ->ArgPair(10000, 20000); + } // namespace } // namespace lib diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer.cc index a4734b4..5f33e66 100644 --- a/icing/scoring/scorer.cc +++ b/icing/scoring/scorer.cc @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -156,11 +157,12 @@ class NoScorer : public Scorer { }; libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store) { + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); - switch (rank_by) { + switch (scoring_spec.rank_by()) { case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE: return std::make_unique<DocumentScoreScorer>(document_store, default_score); @@ -168,7 +170,12 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( return std::make_unique<DocumentCreationTimestampScorer>(document_store, default_score); case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: { - auto bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store, scoring_spec)); + + auto bm25f_calculator = std::make_unique<Bm25fCalculator>( + document_store, std::move(section_weights)); return std::make_unique<RelevanceScoreScorer>(std::move(bm25f_calculator), default_score); } @@ -183,8 +190,8 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP: [[fallthrough]]; case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: - return std::make_unique<UsageScorer>(document_store, rank_by, - default_score); + return std::make_unique<UsageScorer>( + document_store, scoring_spec.rank_by(), default_score); case ScoringSpecProto::RankingStrategy::NONE: return std::make_unique<NoScorer>(default_score); } diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h index a22db0f..abdd5ca 100644 --- a/icing/scoring/scorer.h +++ b/icing/scoring/scorer.h @@ -43,8 +43,8 @@ class Scorer { // FAILED_PRECONDITION on any null pointer input // INVALID_ARGUMENT if fails to create an instance static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store); // Returns a non-negative score of a document. The score can be a // document-associated score which comes from the DocumentProto directly, an diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 8b89514..f22a31a 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -27,6 +27,7 @@ #include "icing/proto/scoring.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" @@ -91,6 +92,8 @@ class ScorerTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + const FakeClock& fake_clock1() { return fake_clock1_; } const FakeClock& fake_clock2() { return fake_clock2_; } @@ -121,17 +124,37 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScorerTest, CreationWithNullPointerShouldFail) { - EXPECT_THAT(Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, /*document_store=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +ScoringSpecProto CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::Code ranking_strategy) { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ranking_strategy); + return scoring_spec; +} + +TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), + /*schema_store=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -153,8 +176,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -185,8 +209,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -213,8 +238,9 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -235,8 +261,9 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -259,8 +286,9 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -290,8 +318,9 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { document_store()->Put(test_document2)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -316,16 +345,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -357,16 +389,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -398,16 +433,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -439,19 +477,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -499,19 +540,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -559,19 +603,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -607,8 +654,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/3, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/3, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -618,8 +666,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/111, document_store())); + scorer, + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/111, document_store(), schema_store())); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -643,9 +693,10 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index 24480ef..e36f3bb 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -39,19 +39,20 @@ constexpr double kDefaultScoreInAscendingOrder = libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store) { + const DocumentStore* document_store, + const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); bool is_descending_order = scoring_spec.order_by() == ScoringSpecProto::Order::DESC; ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - Scorer::Create(scoring_spec.rank_by(), + Scorer::Create(scoring_spec, is_descending_order ? kDefaultScoreInDescendingOrder : kDefaultScoreInAscendingOrder, - document_store)); - + document_store, schema_store)); // Using `new` to access a non-public constructor. return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 2289605..e7d09b1 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -40,8 +40,8 @@ class ScoringProcessor { // A ScoringProcessor on success // FAILED_PRECONDITION on any null pointer input static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create( - const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, const DocumentStore* document_store, + const SchemaStore* schema_store); // Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns // a vector of ScoredDocumentHits. The size of results is no more than diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index 125e2a7..7e5cb0f 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -69,11 +69,24 @@ class ScoringProcessorTest : public testing::Test { // Creates a simple email schema SchemaProto test_email_schema = SchemaBuilder() - .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty( - PropertyConfigBuilder() - .SetName("subject") - .SetDataType(TYPE_STRING) - .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema)); } @@ -86,6 +99,8 @@ class ScoringProcessorTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -139,16 +154,46 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScoringProcessorTest, CreationWithNullPointerShouldFail) { +TypePropertyWeights CreateTypePropertyWeights( + std::string schema_type, std::vector<PropertyWeight> property_weights) { + TypePropertyWeights type_property_weights; + type_property_weights.set_schema_type(std::move(schema_type)); + type_property_weights.mutable_property_weights()->Reserve( + property_weights.size()); + + for (PropertyWeight& property_weight : property_weights) { + *type_property_weights.add_property_weights() = std::move(property_weight); + } + + return type_property_weights; +} + +PropertyWeight CreatePropertyWeight(std::string path, double weight) { + PropertyWeight property_weight; + property_weight.set_path(std::move(path)); + property_weight.set_weight(weight); + return property_weight; +} + +TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) { + ScoringSpecProto spec_proto; + EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { ScoringSpecProto spec_proto; - EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr), + EXPECT_THAT(ScoringProcessor::Create(spec_proto, document_store(), + /*schema_store=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScoringProcessorTest, ShouldCreateInstance) { ScoringSpecProto spec_proto; spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); - ICING_EXPECT_OK(ScoringProcessor::Create(spec_proto, document_store())); + ICING_EXPECT_OK( + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); } TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { @@ -163,7 +208,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/5), @@ -189,7 +234,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/-1), @@ -219,7 +264,7 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/2), @@ -251,7 +296,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -306,7 +351,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -316,11 +361,11 @@ TEST_F(ScoringProcessorTest, // the document's length determines the final score. Document shorter than the // average corpus length are slightly boosted. ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.255482); + /*score=*/0.187114); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.115927); + /*score=*/0.084904); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.166435); + /*score=*/0.121896); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -375,7 +420,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -384,11 +429,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all contain the query term "foo" exactly once // and they have the same length, they will have the same BM25F scoret. ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -448,7 +493,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -457,11 +502,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all have the same length, the score is decided by // the frequency of the query term "foo". ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, - /*score=*/0.309497); + /*score=*/0.226674); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask2, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask3, - /*score=*/0.268599); + /*score=*/0.196720); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -470,6 +515,280 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_HitTermWithZeroFrequency) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + + // Document 1 contains the term "foo" 0 times in the "subject" property + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/0); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask section_id_mask1 = 0b00000001; + + // Since the document hit has zero frequency, expect a score of zero. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, + /*score=*/0.000000); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_SameHitFrequencyDifferentPropertyWeights) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = + CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); + PropertyWeight subject_property_weight = + CreatePropertyWeight(/*path=*/"subject", /*weight=*/2.0); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight, subject_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. Final scores are computed with smoothing applied. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.053624); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithImplicitPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = + CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. This is because the "subject" property is implictly given a + // a weight of 1.0, the default weight value. Final scores are computed with + // smoothing applied. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.094601); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithDefaultPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + *spec_proto.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", {}); + + // Creates a ScoringProcessor with no explicit weights set. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + ScoringSpecProto spec_proto_with_weights; + spec_proto_with_weights.set_rank_by( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", + /*weight=*/1.0); + *spec_proto_with_weights.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", + {body_property_weight}); + + // Creates a ScoringProcessor with default weight set for "body" property. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor_with_weights, + ScoringProcessor::Create(spec_proto_with_weights, document_store(), + schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + // Create a doc hit iterator + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators_scoring_with_weights; + query_term_iterators_scoring_with_weights["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + + // We expect document 1 to have the same score whether a weight is explicitly + // set to 1.0 or implictly scored with the default weight. Final scores are + // computed with smoothing applied. + ScoredDocumentHit expected_scored_doc_hit(document_id1, body_section_id_mask, + /*score=*/0.208191); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); + + // Restore ownership of doc hit iterator and query term iterator to test. + doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + EXPECT_THAT(scoring_processor_with_weights->Score( + std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); +} + TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -509,7 +828,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -569,7 +888,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -629,7 +948,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -665,7 +984,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/4), ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default), @@ -714,7 +1033,7 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), diff --git a/icing/scoring/section-weights.cc b/icing/scoring/section-weights.cc new file mode 100644 index 0000000..c4afe7f --- /dev/null +++ b/icing/scoring/section-weights.cc @@ -0,0 +1,146 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/section-weights.h" + +#include <cfloat> +#include <unordered_map> +#include <utility> + +#include "icing/proto/scoring.pb.h" +#include "icing/schema/section.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { + +namespace { + +// Normalizes all weights in the map to be in range (0.0, 1.0], where the max +// weight is normalized to 1.0. +inline void NormalizeSectionWeights( + double max_weight, std::unordered_map<SectionId, double>& section_weights) { + for (auto& raw_weight : section_weights) { + raw_weight.second = raw_weight.second / max_weight; + } +} +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> +SectionWeights::Create(const SchemaStore* schema_store, + const ScoringSpecProto& scoring_spec) { + ICING_RETURN_ERROR_IF_NULL(schema_store); + + std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_property_weight_map; + for (const TypePropertyWeights& type_property_weights : + scoring_spec.type_property_weights()) { + std::string_view schema_type = type_property_weights.schema_type(); + auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type); + if (!schema_type_id_or.ok()) { + ICING_LOG(WARNING) << "No schema type id found for schema type: " + << schema_type; + continue; + } + SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie(); + auto section_metadata_list_or = + schema_store->GetSectionMetadata(schema_type.data()); + if (!section_metadata_list_or.ok()) { + ICING_LOG(WARNING) << "No metadata found for schema type: " + << schema_type; + continue; + } + + const std::vector<SectionMetadata>* metadata_list = + section_metadata_list_or.ValueOrDie(); + + std::unordered_map<std::string, double> property_paths_weights; + for (const PropertyWeight& property_weight : + type_property_weights.property_weights()) { + double property_path_weight = property_weight.weight(); + + // Return error on negative and zero weights. + if (property_path_weight <= 0.0) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Property weight for property path \"%s\" is negative or zero. " + "Negative and zero weights are invalid.", + property_weight.path().c_str())); + } + property_paths_weights.insert( + {property_weight.path(), property_path_weight}); + } + NormalizedSectionWeights normalized_section_weights = + ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list); + + schema_property_weight_map.insert( + {schema_type_id, + {/*section_weights*/ std::move( + normalized_section_weights.section_weights), + /*default_weight*/ normalized_section_weights.default_weight}}); + } + // Using `new` to access a non-public constructor. + return std::unique_ptr<SectionWeights>( + new SectionWeights(std::move(schema_property_weight_map))); +} + +double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const { + auto schema_type_map = schema_section_weight_map_.find(schema_type_id); + if (schema_type_map == schema_section_weight_map_.end()) { + // Return default weight if the schema type has no weights specified. + return kDefaultSectionWeight; + } + + auto section_weight = + schema_type_map->second.section_weights.find(section_id); + if (section_weight == schema_type_map->second.section_weights.end()) { + // If there is no entry for SectionId, the weight is implicitly the + // normalized default weight. + return schema_type_map->second.default_weight; + } + return section_weight->second; +} + +inline SectionWeights::NormalizedSectionWeights +SectionWeights::ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list) { + double max_weight = 0.0; + std::unordered_map<SectionId, double> section_weights; + for (const SectionMetadata& section_metadata : metadata_list) { + std::string_view metadata_path = section_metadata.path; + double section_weight = kDefaultSectionWeight; + auto iter = raw_weights.find(metadata_path.data()); + if (iter != raw_weights.end()) { + section_weight = iter->second; + section_weights.insert({section_metadata.id, section_weight}); + } + // Replace max if we see new max weight. + max_weight = std::max(max_weight, section_weight); + } + + NormalizeSectionWeights(max_weight, section_weights); + // Set normalized default weight to 1.0 in case there is no section + // metadata and max_weight is 0.0 (we should not see this case). + double normalized_default_weight = max_weight == 0.0 + ? kDefaultSectionWeight + : kDefaultSectionWeight / max_weight; + SectionWeights::NormalizedSectionWeights normalized_section_weights = + SectionWeights::NormalizedSectionWeights(); + normalized_section_weights.section_weights = std::move(section_weights); + normalized_section_weights.default_weight = normalized_default_weight; + return normalized_section_weights; +} +} // namespace lib +} // namespace icing diff --git a/icing/scoring/section-weights.h b/icing/scoring/section-weights.h new file mode 100644 index 0000000..23a9188 --- /dev/null +++ b/icing/scoring/section-weights.h @@ -0,0 +1,95 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_SECTION_WEIGHTS_H_ +#define ICING_SCORING_SECTION_WEIGHTS_H_ + +#include <unordered_map> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/schema/schema-store.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +inline constexpr double kDefaultSectionWeight = 1.0; + +// Provides functions for setting and retrieving section weights for schema +// type properties. Section weights are used to promote and demote term matches +// in sections when scoring results. Section weights are provided by property +// path, and can range from (0, DBL_MAX]. The SectionId is matched to the +// property path by going over the schema type's section metadata. Weights that +// correspond to a valid property path are then normalized against the maxmium +// section weight, and put into map for quick access for scorers. By default, +// a section is given a raw, pre-normalized weight of 1.0. +class SectionWeights { + public: + // SectionWeights instances should not be copied. + SectionWeights(const SectionWeights&) = delete; + SectionWeights& operator=(const SectionWeights&) = delete; + + // Factory function to create a SectionWeights instance. Raw weights are + // provided through the ScoringSpecProto. Provided property paths for weights + // are validated against the schema type's section metadata. If the property + // path doesn't exist, the property weight is ignored. If a weight is 0 or + // negative, an invalid argument error is returned. Raw weights are then + // normalized against the maximum weight for that schema type. + // + // Returns: + // A SectionWeights instance on success + // FAILED_PRECONDITION on any null pointer input + // INVALID_ARGUMENT if a provided weight for a property path is less than or + // equal to 0. + static libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> Create( + const SchemaStore* schema_store, const ScoringSpecProto& scoring_spec); + + // Returns the normalized section weight by SchemaTypeId and SectionId. If + // the SchemaTypeId, or the SectionId for a SchemaTypeId, is not found in the + // normalized weights map, the default weight is returned instead. + double GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const; + + private: + // Holds the normalized section weights for a schema type, as well as the + // normalized default weight for sections that have no weight set. + struct NormalizedSectionWeights { + std::unordered_map<SectionId, double> section_weights; + double default_weight; + }; + + explicit SectionWeights( + const std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map) + : schema_section_weight_map_(std::move(schema_section_weight_map)) {} + + // Creates a map of section ids to normalized weights from the raw property + // path weight map and section metadata and calculates the normalized default + // section weight. + static inline SectionWeights::NormalizedSectionWeights + ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list); + + // A map of (SchemaTypeId -> SectionId -> Normalized Weight), allows for fast + // look up of normalized weights. This is precomputed when creating a + // SectionWeights instance. + std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SECTION_WEIGHTS_H_ diff --git a/icing/scoring/section-weights_test.cc b/icing/scoring/section-weights_test.cc new file mode 100644 index 0000000..b90c3d5 --- /dev/null +++ b/icing/scoring/section-weights_test.cc @@ -0,0 +1,386 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/section-weights.h" + +#include <cfloat> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/scoring.pb.h" +#include "icing/schema-builder.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { +using ::testing::Eq; + +class SectionWeightsTest : public testing::Test { + protected: + SectionWeightsTest() + : test_dir_(GetTestTempDir() + "/icing"), + schema_store_dir_(test_dir_ + "/schema_store") {} + + void SetUp() override { + // Creates file directories + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + SchemaTypeConfigProto sender_schema = + SchemaTypeConfigBuilder() + .SetType("sender") + .AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaTypeConfigProto email_schema = + SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "sender", /*index_nested_properties=*/true) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaProto schema = + SchemaBuilder().AddType(sender_schema).AddType(email_schema).Build(); + + ICING_ASSERT_OK(schema_store_->SetSchema(schema)); + } + + void TearDown() override { + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SchemaStore *schema_store() { return schema_store_.get(); } + + private: + const std::string test_dir_; + const std::string schema_store_dir_; + Filesystem filesystem_; + FakeClock fake_clock_; + std::unique_ptr<SchemaStore> schema_store_; +}; + +TEST_F(SectionWeightsTest, ShouldNormalizeSinglePropertyWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(5.0); + property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". + // We expect 1.0 as there is only one property in the "sender" schema type + // so it should take the max normalized weight of 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldAcceptMaxWeightValue) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(DBL_MAX); + property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithNegativeWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_propery_weight = + type_property_weights->add_property_weights(); + body_propery_weight->set_weight(-100.0); + body_propery_weight->set_path("body"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithZeroWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(0.0); + property_weight->set_path("name"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldReturnDefaultIfTypePropertyWeightsNotSet) { + ScoringSpecProto spec_proto; + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(kDefaultSectionWeight)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *nested_property_weight = + type_property_weights->add_property_weights(); + nested_property_weight->set_weight(50.0); + nested_property_weight->set_path("sender.name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.01)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "subject" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldNormalizeIfAllWeightsBelowOne) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + type_property_weights->add_property_weights(); + body_property_weight->set_weight(0.1); + body_property_weight->set_path("body"); + + PropertyWeight *sender_name_weight = + type_property_weights->add_property_weights(); + sender_name_weight->set_weight(0.2); + sender_name_weight->set_path("sender.name"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(0.4); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(1.0 / 4.0)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(2.0 / 4.0)); + // Normalized weight for "subject" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeightSeparatelyForTypes) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *email_type_property_weights = + spec_proto.add_type_property_weights(); + email_type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + email_type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + email_type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *sender_name_property_weight = + email_type_property_weights->add_property_weights(); + sender_name_property_weight->set_weight(50.0); + sender_name_property_weight->set_path("sender.name"); + + TypePropertyWeights *sender_type_property_weights = + spec_proto.add_type_property_weights(); + sender_type_property_weights->set_schema_type("sender"); + + PropertyWeight *sender_property_weight = + sender_type_property_weights->add_property_weights(); + sender_property_weight->set_weight(25.0); + sender_property_weight->set_path("sender"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // Normalized weight for "sender.name" property (the nested property) + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "name" property for "sender" schema type. As it is + // the only property of the type, it should take the max normalized weight of + // 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSkipNonExistentPathWhenSettingWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + // If this property weight isn't skipped, then the max property weight would + // be set to 100.0 and all weights would be normalized against the max. + PropertyWeight *non_valid_property_weight = + type_property_weights->add_property_weights(); + non_valid_property_weight->set_weight(100.0); + non_valid_property_weight->set_path("sender.organization"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(10.0); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. Because the weight is not explicitly + // set, it is set to the default of 1.0 before being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.1)); + // Normalized weight for "sender.name" property (the nested property). Because + // the weight is not explicitly set, it is set to the default of 1.0 before + // being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.1)); + // Normalized weight for "subject" property. Because the invalid property path + // is skipped when assigning weights, subject takes the max normalized weight + // of 1.0 instead. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +} // namespace + +} // namespace lib +} // namespace icing |