diff options
author | Tim Barron <tjbarron@google.com> | 2022-12-14 16:47:45 -0800 |
---|---|---|
committer | Tim Barron <tjbarron@google.com> | 2022-12-14 16:47:45 -0800 |
commit | b096c00e7fa192b00fd9a923e675fe5c512f1e0f (patch) | |
tree | 6576896481ba45dac735f75b4b2b3fa7274f2ecd | |
parent | ffe0a56986f0d400cb63a3d957ac010c3717020f (diff) | |
parent | 94b21a83007fc0ce1bda6a7518d52f4ce07fb48d (diff) | |
download | icing-b096c00e7fa192b00fd9a923e675fe5c512f1e0f.tar.gz |
Merge remote-tracking branch 'aosp/upstream-master' into androidx-main
* aosp/upstream-master:
Fix go/oag/2355951 to actually sync from upstream.
Sync from upstream.
Sync from upstream.
Descriptions:
======================================================================
Create class `QualifiedId`
======================================================================
Switch JoinProcessor to use new class FullyQualifiedId
======================================================================
Add `JoinableConfig` proto
======================================================================
Implement document-based functions for the Advanced Scoring Language
======================================================================
Support the RelevanceScore function for the Advanced Scoring Language
======================================================================
Enable the document-based member functions for Advanced Scoring Language
Bug: 256022027
Bug: 261474063
Change-Id: Ib2b2b4cfe71e1cfff2363cdfb8c19ed5b22c8983
22 files changed, 1235 insertions, 230 deletions
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 60e347e..68282c5 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -54,6 +54,7 @@ #include "icing/proto/storage.pb.h" #include "icing/proto/term.pb.h" #include "icing/proto/usage.pb.h" +#include "icing/query/advanced_query_parser/lexer.h" #include "icing/query/query-processor.h" #include "icing/query/query-results.h" #include "icing/query/suggestion-processor.h" @@ -64,6 +65,7 @@ #include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" #include "icing/schema/section.h" +#include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/priority-queue-scored-document-hits-ranker.h" #include "icing/scoring/scored-document-hit.h" #include "icing/scoring/scored-document-hits-ranker.h" @@ -435,6 +437,31 @@ bool ShouldRebuildIndex(const OptimizeStatsProto& optimize_stats) { return num_invalid_documents >= optimize_stats.num_original_documents() * 0.9; } +// Useful method to get RankingStrategy if advanced scoring is enabled. When the +// "RelevanceScore" function is used in the advanced scoring expression, +// RankingStrategy will be treated as RELEVANCE_SCORE in order to prepare the +// necessary information needed for calculating relevance score. +libtextclassifier3::StatusOr<ScoringSpecProto::RankingStrategy::Code> +GetRankingStrategyFromScoringSpec(const ScoringSpecProto& scoring_spec) { + if (scoring_spec.advanced_scoring_expression().empty()) { + return scoring_spec.rank_by(); + } + // TODO(b/261474063) The Lexer will be called again when creating the + // AdvancedScorer instance. Consider refactoring the code to allow the Lexer + // to be called only once. 
+ Lexer lexer(scoring_spec.advanced_scoring_expression(), + Lexer::Language::SCORING); + ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + for (const Lexer::LexerToken& token : lexer_tokens) { + if (token.type == Lexer::TokenType::FUNCTION_NAME && + token.text == RelevanceScoreFunctionScoreExpression::kFunctionName) { + return ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE; + } + } + return ScoringSpecProto::RankingStrategy::NONE; +} + } // namespace IcingSearchEngine::IcingSearchEngine(const IcingSearchEngineOptions& options, @@ -1842,8 +1869,14 @@ IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( std::unique_ptr<QueryProcessor> query_processor = std::move(query_processor_or).ValueOrDie(); - auto query_results_or = - query_processor->ParseSearch(search_spec, scoring_spec.rank_by()); + auto ranking_strategy_or = GetRankingStrategyFromScoringSpec(scoring_spec); + libtextclassifier3::StatusOr<QueryResults> query_results_or; + if (ranking_strategy_or.ok()) { + query_results_or = query_processor->ParseSearch( + search_spec, ranking_strategy_or.ValueOrDie()); + } else { + query_results_or = ranking_strategy_or.status(); + } if (!query_results_or.ok()) { return QueryScoringResults( std::move(query_results_or).status(), /*query_terms_in=*/{}, diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index 8cb7e7f..2816f70 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -27,6 +27,7 @@ #include "icing/file/filesystem.h" #include "icing/file/mock-filesystem.h" #include "icing/jni/jni-cache.h" +#include "icing/join/join-processor.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/portable/endian.h" #include "icing/portable/equals-proto.h" @@ -5103,6 +5104,70 @@ TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringOneNamespace) { "namespace1/uri6")); // 'food' 1 time } +TEST_F(IcingSearchEngineTest, 
Bm25fRelevanceScoringOneNamespaceAdvanced) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri7", /*score=*/4, + "starbucks coffee", + "habit. birthday rewards. 
good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("coffee OR food"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_advanced_scoring_expression("this.relevanceScore() * 2 + 1"); + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + SearchResultProto search_result_proto = icing.Search( + search_spec, scoring_spec, ResultSpecProto::default_instance()); + + // Result should be in descending score order + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + // Both doc5 and doc7 have "coffee" in name and text sections. + // However, doc5 has more matches in the text section. + // Documents with "food" are ranked lower as the term "food" is commonly + // present in this corpus, and thus, has a lower IDF. + EXPECT_THAT(GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace1/uri5", // 'coffee' 3 times + "namespace1/uri7", // 'coffee' 2 times + "namespace1/uri1", // 'food' 2 times + "namespace1/uri4", // 'food' 2 times + "namespace1/uri2", // 'food' 1 time + "namespace1/uri6")); // 'food' 1 time +} + TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringOneNamespaceNotOperator) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); @@ -9895,7 +9960,7 @@ TEST_F(IcingSearchEngineTest, IcingShouldWorkFor64Sections) { EqualsSearchResultIgnoreStatsAndScores(expected_no_documents)); } -TEST_F(IcingSearchEngineTest, SimpleJoin) { +TEST_F(IcingSearchEngineTest, JoinByQualifiedId) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder() @@ -9915,11 +9980,18 @@ TEST_F(IcingSearchEngineTest, SimpleJoin) { .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) .SetCardinality(CARDINALITY_OPTIONAL))) - .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty( - 
PropertyConfigBuilder() - .SetName("subjectId") - .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) - .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + .SetType("Email") + .AddProperty(PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("personQualifiedId") + .SetDataTypeJoinableString( + JOINABLE_VALUE_TYPE_QUALIFIED_ID) + .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); DocumentProto person1 = @@ -9942,7 +10014,7 @@ TEST_F(IcingSearchEngineTest, SimpleJoin) { .Build(); DocumentProto person3 = DocumentBuilder() - .SetKey("pkg$db/name#space\\\\", "person3") + .SetKey(R"(pkg$db/name#space\\)", "person3") .SetSchema("Person") .AddStringProperty("firstName", "first3") .AddStringProperty("lastName", "last3") @@ -9954,21 +10026,25 @@ TEST_F(IcingSearchEngineTest, SimpleJoin) { DocumentBuilder() .SetKey("namespace", "email1") .SetSchema("Email") - .AddStringProperty("subjectId", "pkg$db/namespace#person1") + .AddStringProperty("subject", "test subject 1") + .AddStringProperty("personQualifiedId", "pkg$db/namespace#person1") .SetCreationTimestampMs(kDefaultCreationTimestampMs) .Build(); DocumentProto email2 = DocumentBuilder() .SetKey("namespace", "email2") .SetSchema("Email") - .AddStringProperty("subjectId", "pkg$db/namespace#person2") + .AddStringProperty("subject", "test subject 2") + .AddStringProperty("personQualifiedId", "pkg$db/namespace#person2") .SetCreationTimestampMs(kDefaultCreationTimestampMs) .Build(); DocumentProto email3 = DocumentBuilder() .SetKey("namespace", "email3") .SetSchema("Email") - .AddStringProperty("subjectId", "pkg$db/name\\#space\\\\#person3") + .AddStringProperty("subject", "test subject 3") + .AddStringProperty("personQualifiedId", + R"(pkg$db/name\#space\\\\#person3)") // escaped .SetCreationTimestampMs(kDefaultCreationTimestampMs) .Build(); @@ -9985,20 
+10061,21 @@ TEST_F(IcingSearchEngineTest, SimpleJoin) { // Parent SearchSpec SearchSpecProto search_spec; search_spec.set_term_match_type(TermMatchType::PREFIX); - search_spec.set_query("first"); + search_spec.set_query("firstName:first"); // JoinSpec JoinSpecProto* join_spec = search_spec.mutable_join_spec(); // Set max_joined_child_count as 2, so only email 3, email2 will be included // in the nested result and email1 will be truncated. join_spec->set_max_joined_child_count(2); - join_spec->set_parent_property_expression("this.fullyQualifiedId()"); - join_spec->set_child_property_expression("subjectId"); + join_spec->set_parent_property_expression( + std::string(JoinProcessor::kFullyQualifiedIdExpr)); + join_spec->set_child_property_expression("personQualifiedId"); JoinSpecProto::NestedSpecProto* nested_spec = join_spec->mutable_nested_spec(); SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec(); nested_search_spec->set_term_match_type(TermMatchType::PREFIX); - nested_search_spec->set_query(""); + nested_search_spec->set_query("subject:test"); *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec(); *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance(); diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc index 7abd821..71fa75f 100644 --- a/icing/join/join-processor.cc +++ b/icing/join/join-processor.cc @@ -16,11 +16,14 @@ #include <algorithm> #include <functional> +#include <string> #include <string_view> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/join/qualified-id.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" #include "icing/scoring/scored-document-hit.h" @@ -49,56 +52,53 @@ JoinProcessor::Join( // - If there is no cache, then we still have the flexibility to fetch it // from actual docs via DocumentStore. 
- // Break children down into maps. The keys of this map are the DocumentIds of - // the parent docs the child ScoredDocumentHits refer to. The values in this - // map are vectors of child ScoredDocumentHits that refer to a parent - // DocumentId. + // Step 1: group child documents by parent documentId. Currently we only + // support QualifiedId joining, so fetch the qualified id content of + // child_property_expression, break it down into namespace + uri, and + // lookup the DocumentId. + // The keys of this map are the DocumentIds of the parent docs the child + // ScoredDocumentHits refer to. The values in this map are vectors of child + // ScoredDocumentHits that refer to a parent DocumentId. std::unordered_map<DocumentId, std::vector<ScoredDocumentHit>> - parent_to_child_map; + parent_id_to_child_map; for (const ScoredDocumentHit& child : child_scored_document_hits) { std::string property_content = FetchPropertyExpressionValue( child.document_id(), join_spec.child_property_expression()); - // Try to split the property content by separators. - std::vector<int> separators_in_property_content = - GetSeparatorLocations(property_content, "#"); - - if (separators_in_property_content.size() != 1) { - // Skip the document if the qualified id isn't made up of the namespace - // and uri. StrSplit will return just the original string if there are no - // spaces. + // Parse qualified id. 
+ libtextclassifier3::StatusOr<QualifiedId> qualified_id_or = + QualifiedId::Parse(property_content); + if (!qualified_id_or.ok()) { + ICING_VLOG(2) << "Skip content with invalid format of QualifiedId"; continue; } + QualifiedId qualified_id = std::move(qualified_id_or).ValueOrDie(); - std::string ns = - property_content.substr(0, separators_in_property_content[0]); - std::string uri = - property_content.substr(separators_in_property_content[0] + 1); - - UnescapeSeparator(ns, "#"); - UnescapeSeparator(uri, "#"); - - libtextclassifier3::StatusOr<DocumentId> doc_id_or = - doc_store_->GetDocumentId(ns, uri); - - if (!doc_id_or.ok()) { + // Lookup parent DocumentId. + libtextclassifier3::StatusOr<DocumentId> parent_doc_id_or = + doc_store_->GetDocumentId(qualified_id.name_space(), + qualified_id.uri()); + if (!parent_doc_id_or.ok()) { // Skip the document if getting errors. continue; } + DocumentId parent_doc_id = std::move(parent_doc_id_or).ValueOrDie(); - DocumentId parent_doc_id = std::move(doc_id_or).ValueOrDie(); - - // This assumes the child docs are already sorted. - if (parent_to_child_map[parent_doc_id].size() < + // Since we've already sorted child_scored_document_hits, just simply omit + // if the parent_id_to_child_map[parent_doc_id].size() has reached max + // joined child count. + if (parent_id_to_child_map[parent_doc_id].size() < join_spec.max_joined_child_count()) { - parent_to_child_map[parent_doc_id].push_back(std::move(child)); + parent_id_to_child_map[parent_doc_id].push_back(child); } } std::vector<JoinedScoredDocumentHit> joined_scored_document_hits; joined_scored_document_hits.reserve(parent_scored_document_hits.size()); - // Then add use child maps to add to parent ScoredDocumentHits. + // Step 2: iterate through all parent documentIds and construct + // JoinedScoredDocumentHit for each by looking up + // parent_id_to_child_map. 
for (ScoredDocumentHit& parent : parent_scored_document_hits) { DocumentId parent_doc_id = kInvalidDocumentId; if (join_spec.parent_property_expression() == kFullyQualifiedIdExpr) { @@ -106,34 +106,23 @@ JoinProcessor::Join( } else { // TODO(b/256022027): So far we only support kFullyQualifiedIdExpr for // parent_property_expression, we could support more. - return absl_ports::UnimplementedError( - join_spec.parent_property_expression() + - " must be \"fullyQualifiedId(this)\""); + return absl_ports::UnimplementedError(absl_ports::StrCat( + "Parent property expression must be ", kFullyQualifiedIdExpr)); } // TODO(b/256022027): Derive final score from - // parent_to_child_map[parent_doc_id] and + // parent_id_to_child_map[parent_doc_id] and // join_spec.aggregation_score_strategy() double final_score = parent.score(); joined_scored_document_hits.emplace_back( final_score, std::move(parent), std::vector<ScoredDocumentHit>( - std::move(parent_to_child_map[parent_doc_id]))); + std::move(parent_id_to_child_map[parent_doc_id]))); } return joined_scored_document_hits; } -// This loads a document and uses a property expression to fetch the value of -// the property from the document. The property expression may refer to nested -// document properties. We do not allow for repeated values in this property -// path, as that would allow for a single document to join to multiple -// documents. -// -// Returns: -// "" on document load error. -// "" if the property path is not found in the document. -// "" if part of the property path is a repeated value. 
std::string JoinProcessor::FetchPropertyExpressionValue( const DocumentId& document_id, const std::string& property_expression) const { @@ -150,31 +139,5 @@ std::string JoinProcessor::FetchPropertyExpressionValue( return std::string(GetString(&document, property_expression)); } -std::vector<int> JoinProcessor::GetSeparatorLocations( - const std::string& content, const std::string& separator) const { - std::vector<int> separators_in_property_content; - - for (int i = 0; i < content.length(); ++i) { - if (content[i] == '\\') { - // Skip the following character - i++; - } else if (content[i] == '#') { - // Unescaped separator - separators_in_property_content.push_back(i); - } - } - return separators_in_property_content; -} - -void JoinProcessor::UnescapeSeparator(std::string& property, - const std::string& separator) { - size_t start_pos = 0; - while ((start_pos = property.find("\\" + separator, start_pos)) != - std::string::npos) { - property.replace(start_pos, 2, "#"); - start_pos += 1; - } -} - } // namespace lib } // namespace icing diff --git a/icing/join/join-processor.h b/icing/join/join-processor.h index c919b22..dccea22 100644 --- a/icing/join/join-processor.h +++ b/icing/join/join-processor.h @@ -15,6 +15,8 @@ #ifndef ICING_JOIN_JOIN_PROCESSOR_H_ #define ICING_JOIN_JOIN_PROCESSOR_H_ +#include <string> +#include <string_view> #include <vector> #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -39,15 +41,22 @@ class JoinProcessor { std::vector<ScoredDocumentHit>&& child_scored_document_hits); private: + // Loads a document and uses a property expression to fetch the value of the + // property from the document. The property expression may refer to nested + // document properties. + // Note: currently we only support single joining, so we use the first element + // (index 0) for any repeated values. + // + // TODO(b/256022027): validate joinable property (and its upper-level) should + // not have REPEATED cardinality. 
+ // + // Returns: + // "" on document load error. + // "" if the property path is not found in the document. std::string FetchPropertyExpressionValue( const DocumentId& document_id, const std::string& property_expression) const; - void UnescapeSeparator(std::string& property, const std::string& separator); - - std::vector<int> GetSeparatorLocations(const std::string& content, - const std::string& separator) const; - const DocumentStore* doc_store_; // Does not own. }; diff --git a/icing/join/qualified-id.cc b/icing/join/qualified-id.cc new file mode 100644 index 0000000..2a30c44 --- /dev/null +++ b/icing/join/qualified-id.cc @@ -0,0 +1,105 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/join/qualified-id.h" + +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +// Since we use '#' as the separator and '\' to escape '\' and '#', only these 2 +// characters are considered special characters to parse qualified id. +bool IsSpecialCharacter(char c) { + return c == QualifiedId::kEscapeChar || + c == QualifiedId::kNamespaceUriSeparator; +} + +// Helper function to verify the format (check the escape format and make sure +// number of separator '#' is 1) and find the position of the unique separator. 
+// +// Returns: +// A valid index of the separator on success. +// std::string::npos if the escape format of content is incorrect. +// std::string::npos if the content contains 0 or more than 1 separators. +size_t VerifyFormatAndGetSeparatorPosition(std::string_view content) { + size_t separator_pos = std::string::npos; + for (size_t i = 0; i < content.length(); ++i) { + if (content[i] == QualifiedId::kEscapeChar) { + // Advance to the next character. + ++i; + if (i >= content.length() || !IsSpecialCharacter(content[i])) { + // Invalid escape format. + return std::string::npos; + } + } else if (content[i] == QualifiedId::kNamespaceUriSeparator) { + if (separator_pos != std::string::npos) { + // Found another separator, so return std::string::npos since only one + // separator is allowed. + return std::string::npos; + } + separator_pos = i; + } + } + return separator_pos; +} + +// Helper function to unescape the content. +libtextclassifier3::StatusOr<std::string> Unescape(std::string_view content) { + std::string unescaped_content; + for (size_t i = 0; i < content.length(); ++i) { + if (content[i] == QualifiedId::kEscapeChar) { + // Advance to the next character. + ++i; + if (i >= content.length() || !IsSpecialCharacter(content[i])) { + // Invalid escape format. 
+ return absl_ports::InvalidArgumentError("Invalid escape format"); + } + } + unescaped_content += content[i]; + } + return unescaped_content; +} + +} // namespace + +/* static */ libtextclassifier3::StatusOr<QualifiedId> QualifiedId::Parse( + std::string_view qualified_id_str) { + size_t separator_pos = VerifyFormatAndGetSeparatorPosition(qualified_id_str); + if (separator_pos == std::string::npos) { + return absl_ports::InvalidArgumentError( + "Failed to find the position of separator"); + } + + if (separator_pos == 0 || separator_pos + 1 >= qualified_id_str.length()) { + return absl_ports::InvalidArgumentError( + "Namespace or uri cannot be empty after parsing"); + } + + ICING_ASSIGN_OR_RETURN(std::string name_space, + Unescape(qualified_id_str.substr(0, separator_pos))); + ICING_ASSIGN_OR_RETURN(std::string uri, + Unescape(qualified_id_str.substr(separator_pos + 1))); + return QualifiedId(std::move(name_space), std::move(uri)); +} + +} // namespace lib +} // namespace icing diff --git a/icing/join/qualified-id.h b/icing/join/qualified-id.h new file mode 100644 index 0000000..eb6606a --- /dev/null +++ b/icing/join/qualified-id.h @@ -0,0 +1,65 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_JOIN_QUALIFIED_ID_H_ +#define ICING_JOIN_QUALIFIED_ID_H_ + +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" + +namespace icing { +namespace lib { + +// QualifiedId definition: namespace and uri. +// This is a wrapper class for parsing qualified id string. +// +// Qualified id string format: escape(namespace) + '#' + escape(uri). +// - Use '#' as the separator to concat namespace and uri +// - Use '\' to escape '\' and '#' in namespace and uri. +// - There should be 1 separator '#' in a qualified string, and the rest part +// should have correct escape format. +// - Raw namespace and uri cannot be empty. +class QualifiedId { + public: + static constexpr char kEscapeChar = '\\'; + static constexpr char kNamespaceUriSeparator = '#'; + + // Parses a qualified id string "<escaped(namespace)>#<escaped(uri)>" and + // creates an instance of QualifiedId. + // + // qualified_id_str: a qualified id string having the format mentioned above. + // + // Returns: + // - A QualifiedId instance with raw namespace and uri, on success. + // - INVALID_ARGUMENT_ERROR if the format of qualified_id_str is incorrect. 
+ static libtextclassifier3::StatusOr<QualifiedId> Parse( + std::string_view qualified_id_str); + + explicit QualifiedId(std::string name_space, std::string uri) + : name_space_(std::move(name_space)), uri_(std::move(uri)) {} + + const std::string& name_space() const { return name_space_; } + const std::string& uri() const { return uri_; } + + private: + std::string name_space_; + std::string uri_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_JOIN_QUALIFIED_ID_H_ diff --git a/icing/join/qualified-id_test.cc b/icing/join/qualified-id_test.cc new file mode 100644 index 0000000..0c3750a --- /dev/null +++ b/icing/join/qualified-id_test.cc @@ -0,0 +1,141 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/join/qualified-id.h" + +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; + +TEST(QualifiedIdTest, ValidQualifiedIdWithoutSpecialCharacters) { + // "namespace#uri" -> "namespace" + "uri" + ICING_ASSERT_OK_AND_ASSIGN(QualifiedId id, + QualifiedId::Parse(R"(namespace#uri)")); + EXPECT_THAT(id.name_space(), Eq(R"(namespace)")); + EXPECT_THAT(id.uri(), R"(uri)"); +} + +TEST(QualifiedIdTest, ValidQualifiedIdWithEscapedSpecialCharacters) { + // "namespace\\#uri" -> "namespace\" + "uri" + ICING_ASSERT_OK_AND_ASSIGN(QualifiedId id1, + QualifiedId::Parse(R"(namespace\\#uri)")); + EXPECT_THAT(id1.name_space(), Eq(R"(namespace\)")); + EXPECT_THAT(id1.uri(), R"(uri)"); + + // "namespace\\\##uri" -> "namespace\#" + "uri" + ICING_ASSERT_OK_AND_ASSIGN(QualifiedId id2, + QualifiedId::Parse(R"(namespace\\\##uri)")); + EXPECT_THAT(id2.name_space(), Eq(R"(namespace\#)")); + EXPECT_THAT(id2.uri(), R"(uri)"); + + // "namespace#\#\\uri" -> "namespace" + "#\uri" + ICING_ASSERT_OK_AND_ASSIGN(QualifiedId id3, + QualifiedId::Parse(R"(namespace#\#\\uri)")); + EXPECT_THAT(id3.name_space(), Eq(R"(namespace)")); + EXPECT_THAT(id3.uri(), R"(#\uri)"); + + // "namespace\\\##\#\\uri" -> "namespace\#" + "#\uri" + ICING_ASSERT_OK_AND_ASSIGN(QualifiedId id4, + QualifiedId::Parse(R"(namespace\\\##\#\\uri)")); + EXPECT_THAT(id4.name_space(), Eq(R"(namespace\#)")); + EXPECT_THAT(id4.uri(), R"(#\uri)"); +} + +TEST(QualifiedIdTest, InvalidQualifiedIdWithEmptyNamespaceOrUri) { + // "#uri" + EXPECT_THAT(QualifiedId::Parse(R"(#uri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace#" + EXPECT_THAT(QualifiedId::Parse(R"(namespace#)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "#" + 
EXPECT_THAT(QualifiedId::Parse(R"(#)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QualifiedIdTest, InvalidQualifiedIdWithInvalidEscape) { + // "namespace\" + // Add an additional '#' and use string_view trick to cover the index safe + // check when skipping the last '\'. + std::string str1 = R"(namespace\)" + R"(#)"; + EXPECT_THAT( + QualifiedId::Parse(std::string_view(str1.data(), str1.length() - 1)), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "names\pace#uri" + EXPECT_THAT(QualifiedId::Parse(R"(names\pace#uri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "names\\\pace#uri" + EXPECT_THAT(QualifiedId::Parse(R"(names\\\pace#uri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace#uri\" + // Add an additional '#' and use string_view trick to cover the index safe + // check when skipping the last '\'. + std::string str2 = R"(namespace#uri\)" + R"(#)"; + EXPECT_THAT( + QualifiedId::Parse(std::string_view(str2.data(), str2.length() - 1)), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QualifiedIdTest, InvalidQualifiedIdWithWrongNumberOfSeparators) { + // "" + EXPECT_THAT(QualifiedId::Parse(R"()"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespaceuri" + EXPECT_THAT(QualifiedId::Parse(R"(namespaceuri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace##uri" + EXPECT_THAT(QualifiedId::Parse(R"(namespace##uri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace#uri#others" + EXPECT_THAT(QualifiedId::Parse(R"(namespace#uri#others)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace\#uri" + EXPECT_THAT(QualifiedId::Parse(R"(namespace\#uri)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace\\##uri" + EXPECT_THAT(QualifiedId::Parse(R"(namespace\\##uri)"), + 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // "namespace#uri\\#others" + EXPECT_THAT(QualifiedId::Parse(R"(namespace#uri\\#)"), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/schema-builder.h b/icing/schema-builder.h index ea0a774..8d3aecb 100644 --- a/icing/schema-builder.h +++ b/icing/schema-builder.h @@ -69,6 +69,11 @@ constexpr PropertyConfigProto::DataType::Code TYPE_BYTES = constexpr PropertyConfigProto::DataType::Code TYPE_DOCUMENT = PropertyConfigProto::DataType::DOCUMENT; +constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_NONE = + JoinableConfig::ValueType::NONE; +constexpr JoinableConfig::ValueType::Code JOINABLE_VALUE_TYPE_QUALIFIED_ID = + JoinableConfig::ValueType::QUALIFIED_ID; + class PropertyConfigBuilder { public: PropertyConfigBuilder() = default; @@ -95,6 +100,17 @@ class PropertyConfigBuilder { return *this; } + PropertyConfigBuilder& SetDataTypeJoinableString( + JoinableConfig::ValueType::Code join_value_type, + TermMatchType::Code match_type = TERM_MATCH_UNKNOWN, + StringIndexingConfig::TokenizerType::Code tokenizer = TOKENIZER_NONE) { + property_.set_data_type(PropertyConfigProto::DataType::STRING); + property_.mutable_joinable_config()->set_value_type(join_value_type); + property_.mutable_string_indexing_config()->set_term_match_type(match_type); + property_.mutable_string_indexing_config()->set_tokenizer_type(tokenizer); + return *this; + } + PropertyConfigBuilder& SetDataTypeInt64( IntegerIndexingConfig::NumericMatchType::Code numeric_match_type) { property_.set_data_type(PropertyConfigProto::DataType::INT64); diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index 5f4baa8..ddd9e3b 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -368,6 +368,11 @@ SchemaStore::SetSchema(SchemaProto&& new_schema, bool ignore_errors_and_delete_documents) { 
ICING_ASSIGN_OR_RETURN(SchemaUtil::DependencyMap new_dependency_map, SchemaUtil::Validate(new_schema)); + // TODO(b/256022027): validate and extract joinable properties. + // - Joinable config in non-string properties should be ignored, since + // currently we only support string joining. + // - If set joinable, the property itself and all of its parent (nested doc) + // properties should not have REPEATED cardinality. SetSchemaResult result; diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc index 9d52fde..212a476 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer.cc @@ -20,6 +20,8 @@ #include "icing/query/advanced_query_parser/parser.h" #include "icing/scoring/advanced_scoring/score-expression.h" #include "icing/scoring/advanced_scoring/scoring-visitor.h" +#include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" namespace icing { namespace lib { @@ -40,7 +42,13 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root, parser.ConsumeScoring()); - ScoringVisitor visitor(default_score); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store, scoring_spec)); + std::unique_ptr<Bm25fCalculator> bm25f_calculator = + std::make_unique<Bm25fCalculator>(document_store, + std::move(section_weights)); + ScoringVisitor visitor(default_score, document_store, schema_store, + bm25f_calculator.get()); tree_root->Accept(&visitor); ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression, @@ -50,8 +58,8 @@ AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, "The root scoring expression will always be evaluated to a document, " "but a number is expected."); } - return std::unique_ptr<AdvancedScorer>( - new AdvancedScorer(std::move(expression), default_score)); + return 
std::unique_ptr<AdvancedScorer>(new AdvancedScorer( + std::move(expression), std::move(bm25f_calculator), default_score)); } } // namespace lib diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h index 6557ba6..077d734 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer.h +++ b/icing/scoring/advanced_scoring/advanced-scorer.h @@ -22,6 +22,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/schema/schema-store.h" #include "icing/scoring/advanced_scoring/score-expression.h" +#include "icing/scoring/bm25f-calculator.h" #include "icing/scoring/scorer.h" #include "icing/store/document-store.h" @@ -50,13 +51,25 @@ class AdvancedScorer : public Scorer { return std::move(result).ValueOrDie(); } + void PrepareToScore( + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* + query_term_iterators) override { + if (query_term_iterators == nullptr || query_term_iterators->empty()) { + return; + } + bm25f_calculator_->PrepareToScore(query_term_iterators); + } + private: explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, + std::unique_ptr<Bm25fCalculator> bm25f_calculator, double default_score) : score_expression_(std::move(score_expression)), + bm25f_calculator_(std::move(bm25f_calculator)), default_score_(default_score) {} std::unique_ptr<ScoreExpression> score_expression_; + std::unique_ptr<Bm25fCalculator> bm25f_calculator_; double default_score_; }; diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc index 0d3a05c..36d38a2 100644 --- a/icing/scoring/advanced_scoring/advanced-scorer_test.cc +++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc @@ -114,6 +114,17 @@ DocumentProto CreateDocument( .Build(); } +UsageReport CreateUsageReport(std::string name_space, std::string uri, + int64 timestamp_ms, + UsageReport::UsageType usage_type) { + UsageReport 
usage_report; + usage_report.set_document_namespace(name_space); + usage_report.set_document_uri(uri); + usage_report.set_usage_timestamp_ms(timestamp_ms); + usage_report.set_usage_type(usage_type); + return usage_report; +} + ScoringSpecProto CreateAdvancedScoringSpec( const std::string& advanced_scoring_expression) { ScoringSpecProto scoring_spec; @@ -285,6 +296,174 @@ TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) { EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); } +TEST_F(AdvancedScorerTest, DocumentScoreCreationTimestampFunctionExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument( + "namespace", "uri", /*score=*/123, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("this.documentScore()"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.creationTimestamp()"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(kDefaultCreationTimestampMs)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec( + "this.documentScore() + this.creationTimestamp()"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), + Eq(123 + kDefaultCreationTimestampMs)); +} + +TEST_F(AdvancedScorerTest, DocumentUsageFunctionExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri"))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create( + 
// A usage type that is not one of the integers 1, 2 or 3 is only detected when
// the expression is evaluated; in that case the scorer must hand back the
// default score.
TEST_F(AdvancedScorerTest, DocumentUsageFunctionOutOfRange) {
  ICING_ASSERT_OK_AND_ASSIGN(
      DocumentId doc_id,
      document_store_->Put(CreateDocument("namespace", "uri")));
  DocHitInfo hit(doc_id);

  const double default_score = 123;

  // Every expression below is accepted at creation time but causes a
  // "runtime" error, so the default score is expected.
  for (const char* expression :
       {"this.usageCount(4)", "this.usageCount(0)", "this.usageCount(1.5)"}) {
    ICING_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<Scorer> scorer,
        AdvancedScorer::Create(CreateAdvancedScoringSpec(expression),
                               default_score, document_store_.get(),
                               schema_store_.get()));
    EXPECT_THAT(scorer->GetScore(hit), Eq(default_score));
  }
}

// scoring-processor_test.cc will help to get better test coverage for relevance
// score.
// Here we only check that, without a query iterator, this.relevanceScore()
// evaluates to the default score.
TEST_F(AdvancedScorerTest, RelevanceScoreFunctionScoreExpression) {
  DocumentProto test_document =
      DocumentBuilder()
          .SetKey("namespace", "uri")
          .SetSchema("email")
          .SetScore(5)
          .SetCreationTimestampMs(kDefaultCreationTimestampMs)
          .AddStringProperty("subject", "subject foo")
          .Build();

  ICING_ASSERT_OK_AND_ASSIGN(DocumentId doc_id,
                             document_store_->Put(test_document));
  ICING_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<AdvancedScorer> scorer,
      AdvancedScorer::Create(CreateAdvancedScoringSpec("this.relevanceScore()"),
                             /*default_score=*/10, document_store_.get(),
                             schema_store_.get()));
  scorer->PrepareToScore(/*query_term_iterators=*/{});

  // No query iterator is supplied, so the default score comes back.
  DocHitInfo hit(doc_id);
  EXPECT_THAT(scorer->GetScore(hit, /*query_it=*/nullptr), Eq(10));
}
+ DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer->GetScore(docHitInfo, /*query_it=*/nullptr), Eq(10)); +} + +TEST_F(AdvancedScorerTest, ComplexExpression) { + const int64_t creation_timestamp_ms = 123; + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri", /*score=*/123, + creation_timestamp_ms))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec( + "pow(sin(2), 2)" + // This is this.usageCount(1) + "+ this.usageCount(this.documentScore() - 122)" + "/ 12.34" + "* (10 * pow(2 * 1, sin(2))" + "+ 10 * (2 + 10 + this.creationTimestamp()))" + // This should evaluate to default score. + "+ this.relevanceScore()"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + scorer->PrepareToScore(/*query_term_iterators=*/{}); + + ICING_ASSERT_OK(document_store_->ReportUsage( + CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1))); + ICING_ASSERT_OK(document_store_->ReportUsage( + CreateUsageReport("namespace", "uri", 0, UsageReport::USAGE_TYPE1))); + EXPECT_THAT(scorer->GetScore(docHitInfo), + DoubleNear(pow(sin(2), 2) + + 2 / 12.34 * + (10 * pow(2 * 1, sin(2)) + + 10 * (2 + 10 + creation_timestamp_ms)) + + 10, + kEps)); +} + // Should be a parsing Error TEST_F(AdvancedScorerTest, EmptyExpression) { EXPECT_THAT( @@ -398,6 +577,47 @@ TEST_F(AdvancedScorerTest, MathTypeError) { StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } +TEST_F(AdvancedScorerTest, DocumentFunctionTypeError) { + const double default_score = 0; + + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(1)"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.creationTimestamp(1)"), + 
default_score, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.usageCount()"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("usageLastUsedTimestamp(1, 1)"), + default_score, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("relevanceScore(1)"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("documentScore(this)"), + default_score, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("that.documentScore()"), + default_score, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create( + CreateAdvancedScoringSpec("this.this.creationTimestamp()"), + default_score, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("this.log(2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + } // namespace } // namespace lib diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc index cd77046..08da1c5 100644 --- a/icing/scoring/advanced_scoring/score-expression.cc +++ b/icing/scoring/advanced_scoring/score-expression.cc @@ -199,5 +199,132 @@ libtextclassifier3::StatusOr<double> 
MathFunctionScoreExpression::eval( return res; } +const std::unordered_map<std::string, + DocumentFunctionScoreExpression::FunctionType> + DocumentFunctionScoreExpression::kFunctionNames = { + {"documentScore", FunctionType::kDocumentScore}, + {"creationTimestamp", FunctionType::kCreationTimestamp}, + {"usageCount", FunctionType::kUsageCount}, + {"usageLastUsedTimestamp", FunctionType::kUsageLastUsedTimestamp}}; + +libtextclassifier3::StatusOr<std::unique_ptr<DocumentFunctionScoreExpression>> +DocumentFunctionScoreExpression::Create( + FunctionType function_type, + std::vector<std::unique_ptr<ScoreExpression>> children, + const DocumentStore* document_store, double default_score) { + if (children.empty()) { + return absl_ports::InvalidArgumentError( + "Document-based functions must have at least one argument."); + } + for (const auto& child : children) { + ICING_RETURN_ERROR_IF_NULL(child); + } + if (!children[0]->is_document_type()) { + return absl_ports::InvalidArgumentError( + "The first parameter of document-based functions must be \"this\"."); + } + switch (function_type) { + case FunctionType::kDocumentScore: + [[fallthrough]]; + case FunctionType::kCreationTimestamp: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError( + "DocumentScore/CreationTimestamp must have 1 argument."); + } + break; + case FunctionType::kUsageCount: + [[fallthrough]]; + case FunctionType::kUsageLastUsedTimestamp: + if (children.size() != 2 || children[1]->is_document_type()) { + return absl_ports::InvalidArgumentError( + "UsageCount/UsageLastUsedTimestamp must have 2 arguments. 
The " + "first argument should be \"this\", and the second argument " + "should be the usage type."); + } + break; + } + return std::unique_ptr<DocumentFunctionScoreExpression>( + new DocumentFunctionScoreExpression(function_type, std::move(children), + document_store, default_score)); +} + +libtextclassifier3::StatusOr<double> DocumentFunctionScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) { + switch (function_type_) { + case FunctionType::kDocumentScore: + [[fallthrough]]; + case FunctionType::kCreationTimestamp: { + ICING_ASSIGN_OR_RETURN(DocumentAssociatedScoreData score_data, + document_store_.GetDocumentAssociatedScoreData( + hit_info.document_id()), + default_score_); + if (function_type_ == FunctionType::kDocumentScore) { + return static_cast<double>(score_data.document_score()); + } + return static_cast<double>(score_data.creation_timestamp_ms()); + } + case FunctionType::kUsageCount: + [[fallthrough]]; + case FunctionType::kUsageLastUsedTimestamp: { + ICING_ASSIGN_OR_RETURN(double raw_usage_type, + children_[1]->eval(hit_info, query_it)); + int usage_type = (int)raw_usage_type; + if (usage_type < 1 || usage_type > 3 || raw_usage_type != usage_type) { + return absl_ports::InvalidArgumentError( + "Usage type must be an integer from 1 to 3"); + } + ICING_ASSIGN_OR_RETURN( + UsageStore::UsageScores usage_scores, + document_store_.GetUsageScores(hit_info.document_id()), + default_score_); + if (function_type_ == FunctionType::kUsageCount) { + if (usage_type == 1) { + return usage_scores.usage_type1_count; + } else if (usage_type == 2) { + return usage_scores.usage_type2_count; + } else { + return usage_scores.usage_type3_count; + } + } + if (usage_type == 1) { + return usage_scores.usage_type1_last_used_timestamp_s * 1000.0; + } else if (usage_type == 2) { + return usage_scores.usage_type2_last_used_timestamp_s * 1000.0; + } else { + return usage_scores.usage_type3_last_used_timestamp_s * 1000.0; + } + } + } +} + 
+libtextclassifier3::StatusOr< + std::unique_ptr<RelevanceScoreFunctionScoreExpression>> +RelevanceScoreFunctionScoreExpression::Create( + std::vector<std::unique_ptr<ScoreExpression>> children, + Bm25fCalculator* bm25f_calculator, double default_score) { + if (children.size() != 1) { + return absl_ports::InvalidArgumentError( + "relevanceScore must have 1 argument."); + } + ICING_RETURN_ERROR_IF_NULL(children[0]); + if (!children[0]->is_document_type()) { + return absl_ports::InvalidArgumentError( + "relevanceScore must take \"this\" as its argument."); + } + return std::unique_ptr<RelevanceScoreFunctionScoreExpression>( + new RelevanceScoreFunctionScoreExpression( + std::move(children), bm25f_calculator, default_score)); +} + +libtextclassifier3::StatusOr<double> +RelevanceScoreFunctionScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) { + if (query_it == nullptr) { + return default_score_; + } + return static_cast<double>( + bm25f_calculator_.ComputeScore(query_it, hit_info, default_score_)); +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h index 0e0c538..533ca52 100644 --- a/icing/scoring/advanced_scoring/score-expression.h +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -24,6 +24,8 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/scoring/bm25f-calculator.h" +#include "icing/store/document-store.h" #include "icing/util/status-macros.h" namespace icing { @@ -148,6 +150,75 @@ class MathFunctionScoreExpression : public ScoreExpression { std::vector<std::unique_ptr<ScoreExpression>> children_; }; +class DocumentFunctionScoreExpression : public ScoreExpression { + public: + enum class FunctionType { + kDocumentScore, + kCreationTimestamp, + kUsageCount, + kUsageLastUsedTimestamp, + 
// Score expression for the document-based member functions of the advanced
// scoring language: documentScore, creationTimestamp, usageCount and
// usageLastUsedTimestamp (see kFunctionNames).
//
// The first child must be the "this" expression. The two usage functions take
// the usage type as their second child.
class DocumentFunctionScoreExpression : public ScoreExpression {
 public:
  // The document-based function this expression evaluates.
  enum class FunctionType {
    kDocumentScore,
    kCreationTimestamp,
    kUsageCount,
    kUsageLastUsedTimestamp,
  };

  // Maps a function name, as written in a scoring expression, to its
  // FunctionType.
  static const std::unordered_map<std::string, FunctionType> kFunctionNames;

  // Creates an expression for function_type with the given arguments.
  // document_store is not owned and must not be null: it is dereferenced when
  // the expression is constructed and kept as a reference afterwards.
  //
  // RETURNS:
  //   - A DocumentFunctionScoreExpression instance on success.
  //   - FAILED_PRECONDITION on any null pointer in children.
  //   - INVALID_ARGUMENT on type errors.
  static libtextclassifier3::StatusOr<
      std::unique_ptr<DocumentFunctionScoreExpression>>
  Create(FunctionType function_type,
         std::vector<std::unique_ptr<ScoreExpression>> children,
         const DocumentStore* document_store, double default_score);

  // Evaluates the function for the document referred to by hit_info. query_it
  // is only forwarded to the evaluation of the usage type argument.
  libtextclassifier3::StatusOr<double> eval(
      const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;

 private:
  // Private: instances are obtained through Create(), which validates the
  // arguments first.
  explicit DocumentFunctionScoreExpression(
      FunctionType function_type,
      std::vector<std::unique_ptr<ScoreExpression>> children,
      const DocumentStore* document_store, double default_score)
      : children_(std::move(children)),
        document_store_(*document_store),
        default_score_(default_score),
        function_type_(function_type) {}

  // Function arguments; children_[0] is the "this" expression.
  std::vector<std::unique_ptr<ScoreExpression>> children_;
  // Not owned.
  const DocumentStore& document_store_;
  // Fallback value handed to the lookups performed in eval().
  double default_score_;
  FunctionType function_type_;
};

// Score expression for the relevanceScore member function of the advanced
// scoring language. Its only child must be the "this" expression.
class RelevanceScoreFunctionScoreExpression : public ScoreExpression {
 public:
  // The function name as written in a scoring expression.
  static constexpr std::string_view kFunctionName = "relevanceScore";

  // Creates the expression. bm25f_calculator is not owned and must not be
  // null: it is dereferenced when the expression is constructed and kept as a
  // reference afterwards.
  //
  // RETURNS:
  //   - A RelevanceScoreFunctionScoreExpression instance on success.
  //   - FAILED_PRECONDITION on any null pointer in children.
  //   - INVALID_ARGUMENT on type errors.
  static libtextclassifier3::StatusOr<
      std::unique_ptr<RelevanceScoreFunctionScoreExpression>>
  Create(std::vector<std::unique_ptr<ScoreExpression>> children,
         Bm25fCalculator* bm25f_calculator, double default_score);

  // Returns the score computed by the BM25F calculator for hit_info, or the
  // default score when query_it is null.
  libtextclassifier3::StatusOr<double> eval(
      const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;

 private:
  // Private: instances are obtained through Create(), which validates the
  // arguments first.
  explicit RelevanceScoreFunctionScoreExpression(
      std::vector<std::unique_ptr<ScoreExpression>> children,
      Bm25fCalculator* bm25f_calculator, double default_score)
      : children_(std::move(children)),
        bm25f_calculator_(*bm25f_calculator),
        default_score_(default_score) {}

  // Function arguments; children_[0] is the "this" expression.
  std::vector<std::unique_ptr<ScoreExpression>> children_;
  // Not owned.
  Bm25fCalculator& bm25f_calculator_;
  double default_score_;
};
value = node->children()[0]->value(); - if (value == "this") { - stack.push_back(ThisExpression::Create()); - return; - } } else if (node->children().size() == 2) { // If a member has two children, then it can only represent a floating point // number, so we need to join them by "." to build the numeric literal. @@ -68,8 +75,12 @@ void ScoringVisitor::VisitMember(const MemberNode* node) { stack.push_back(ConstantScoreExpression::Create(number)); } -void ScoringVisitor::VisitFunction(const FunctionNode* node) { +void ScoringVisitor::VisitFunctionHelper(const FunctionNode* node, + bool is_member_function) { std::vector<std::unique_ptr<ScoreExpression>> children; + if (is_member_function) { + children.push_back(ThisExpression::Create()); + } for (const auto& arg : node->args()) { arg->Accept(this); if (has_pending_error()) { @@ -82,9 +93,20 @@ void ScoringVisitor::VisitFunction(const FunctionNode* node) { absl_ports::InvalidArgumentError( absl_ports::StrCat("Unknown function: ", function_name)); - // Math functions - if (MathFunctionScoreExpression::kFunctionNames.find(function_name) != - MathFunctionScoreExpression::kFunctionNames.end()) { + if (DocumentFunctionScoreExpression::kFunctionNames.find(function_name) != + DocumentFunctionScoreExpression::kFunctionNames.end()) { + // Document-based function + expression = DocumentFunctionScoreExpression::Create( + DocumentFunctionScoreExpression::kFunctionNames.at(function_name), + std::move(children), &document_store_, default_score_); + } else if (function_name == + RelevanceScoreFunctionScoreExpression::kFunctionName) { + // relevanceScore function + expression = RelevanceScoreFunctionScoreExpression::Create( + std::move(children), &bm25f_calculator_, default_score_); + } else if (MathFunctionScoreExpression::kFunctionNames.find(function_name) != + MathFunctionScoreExpression::kFunctionNames.end()) { + // Math functions expression = MathFunctionScoreExpression::Create( 
MathFunctionScoreExpression::kFunctionNames.at(function_name), std::move(children)); diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h index 47a03fd..539af2d 100644 --- a/icing/scoring/advanced_scoring/scoring-visitor.h +++ b/icing/scoring/advanced_scoring/scoring-visitor.h @@ -21,20 +21,32 @@ #include "icing/proto/scoring.pb.h" #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" #include "icing/scoring/advanced_scoring/score-expression.h" +#include "icing/scoring/bm25f-calculator.h" +#include "icing/store/document-store.h" namespace icing { namespace lib { class ScoringVisitor : public AbstractSyntaxTreeVisitor { public: - explicit ScoringVisitor(double default_score) - : default_score_(default_score) {} + explicit ScoringVisitor(double default_score, + const DocumentStore* document_store, + const SchemaStore* schema_store, + Bm25fCalculator* bm25f_calculator) + : default_score_(default_score), + document_store_(*document_store), + schema_store_(*schema_store), + bm25f_calculator_(*bm25f_calculator) {} void VisitFunctionName(const FunctionNameNode* node) override; void VisitString(const StringNode* node) override; void VisitText(const TextNode* node) override; void VisitMember(const MemberNode* node) override; - void VisitFunction(const FunctionNode* node) override; + + void VisitFunction(const FunctionNode* node) override { + return VisitFunctionHelper(node, /*is_member_function=*/false); + } + void VisitUnaryOperator(const UnaryOperatorNode* node) override; void VisitNaryOperator(const NaryOperatorNode* node) override; @@ -58,6 +70,10 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor { } private: + // Visit function node. If is_member_function is true, a ThisExpression will + // be added as the first function argument. 
+ void VisitFunctionHelper(const FunctionNode* node, bool is_member_function); + bool has_pending_error() const { return !pending_error_.ok(); } std::unique_ptr<ScoreExpression> pop_stack() { @@ -67,6 +83,10 @@ class ScoringVisitor : public AbstractSyntaxTreeVisitor { } double default_score_; + const DocumentStore& document_store_; + const SchemaStore& schema_store_; + Bm25fCalculator& bm25f_calculator_; + libtextclassifier3::Status pending_error_; std::vector<std::unique_ptr<ScoreExpression>> stack; }; diff --git a/icing/scoring/scorer-test-utils.h b/icing/scoring/scorer-test-utils.h new file mode 100644 index 0000000..c848970 --- /dev/null +++ b/icing/scoring/scorer-test-utils.h @@ -0,0 +1,77 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_SCORING_SCORER_TEST_UTILS_H_ +#define ICING_SCORING_SCORER_TEST_UTILS_H_ + +#include "icing/proto/scoring.pb.h" + +namespace icing { +namespace lib { + +enum class ScorerTestingMode { kNormal, kAdvanced }; + +inline ScoringSpecProto CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::Code ranking_strategy, + ScorerTestingMode testing_mode = ScorerTestingMode::kNormal) { + ScoringSpecProto scoring_spec; + if (testing_mode != ScorerTestingMode::kAdvanced) { + scoring_spec.set_rank_by(ranking_strategy); + return scoring_spec; + } + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + switch (ranking_strategy) { + case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE: + scoring_spec.set_advanced_scoring_expression("this.documentScore()"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP: + scoring_spec.set_advanced_scoring_expression("this.creationTimestamp()"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT: + scoring_spec.set_advanced_scoring_expression("this.usageCount(1)"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT: + scoring_spec.set_advanced_scoring_expression("this.usageCount(2)"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT: + scoring_spec.set_advanced_scoring_expression("this.usageCount(3)"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP: + scoring_spec.set_advanced_scoring_expression( + "this.usageLastUsedTimestamp(1)"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP: + scoring_spec.set_advanced_scoring_expression( + "this.usageLastUsedTimestamp(2)"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: + scoring_spec.set_advanced_scoring_expression( + "this.usageLastUsedTimestamp(3)"); 
+ return scoring_spec; + case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: + scoring_spec.set_advanced_scoring_expression("this.relevanceScore()"); + return scoring_spec; + case ScoringSpecProto::RankingStrategy::NONE: + case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: + case ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION: + scoring_spec.set_rank_by(ranking_strategy); + return scoring_spec; + } +} + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SCORER_TEST_UTILS_H_ diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 7bbb8b7..f141738 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -29,6 +29,7 @@ #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" #include "icing/scoring/scorer-factory.h" +#include "icing/scoring/scorer-test-utils.h" #include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -42,7 +43,7 @@ namespace lib { namespace { using ::testing::Eq; -class ScorerTest : public testing::Test { +class ScorerTest : public ::testing::TestWithParam<ScorerTestingMode> { protected: ScorerTest() : test_dir_(GetTestTempDir() + "/icing"), @@ -120,37 +121,31 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -ScoringSpecProto CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::Code ranking_strategy) { - ScoringSpecProto scoring_spec; - scoring_spec.set_rank_by(ranking_strategy); - return scoring_spec; -} - -TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) { +TEST_P(ScorerTest, CreationWithNullDocumentStoreShouldFail) { EXPECT_THAT( scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/0, /*document_store=*/nullptr, schema_store()), 
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } -TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) { - EXPECT_THAT(scorer_factory::Create( - CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/0, document_store(), - /*schema_store=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +TEST_P(ScorerTest, CreationWithNullSchemaStoreShouldFail) { + EXPECT_THAT( + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), + /*default_score=*/0, document_store(), + /*schema_store=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } -TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { +TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/10, document_store(), schema_store())); // Non existent document id @@ -159,7 +154,7 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } -TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { +TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { // Creates a test document with a provided score DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -175,7 +170,7 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ 
-189,7 +184,7 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } -TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { +TEST_P(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { // Creates a test document with a provided score int64_t creation_time = fake_clock1().GetSystemTimeMilliseconds(); int64_t ttl = 100; @@ -209,7 +204,7 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -223,7 +218,7 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } -TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { +TEST_P(ScorerTest, ShouldGetDefaultDocumentScore) { // Creates a test document with the default document score 0 DocumentProto test_document = DocumentBuilder() @@ -239,14 +234,14 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); } -TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { +TEST_P(ScorerTest, ShouldGetCorrectDocumentScore) { // Creates a test document with document score 5 DocumentProto test_document = DocumentBuilder() @@ -263,7 +258,7 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - 
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()), /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -272,7 +267,7 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { // See scoring-processor_test.cc and icing-search-engine_test.cc for better // Bm25F scoring tests. -TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { +TEST_P(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { // Creates a test document with document score 5 DocumentProto test_document = DocumentBuilder() @@ -289,14 +284,14 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()), /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } -TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { +TEST_P(ScorerTest, ShouldGetCorrectCreationTimestampScore) { // Creates test_document1 with fake timestamp1 DocumentProto test_document1 = DocumentBuilder() @@ -322,7 +317,8 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { std::unique_ptr<Scorer> scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, + GetParam()), /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); @@ -333,7 +329,7 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { Eq(fake_clock2().GetSystemTimeMilliseconds())); } -TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { +TEST_P(ScorerTest, 
ShouldGetCorrectUsageCountScoreForType1) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -350,19 +346,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { std::unique_ptr<Scorer> scorer1, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); @@ -380,7 +376,7 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } -TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { +TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -397,19 +393,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { std::unique_ptr<Scorer> scorer1, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( 
CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); @@ -427,7 +423,7 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } -TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { +TEST_P(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -444,19 +440,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { std::unique_ptr<Scorer> scorer1, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, GetParam()), /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); 
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); @@ -474,7 +470,7 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(1)); } -TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { +TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -540,7 +536,7 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } -TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { +TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -606,7 +602,7 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } -TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { +TEST_P(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -672,13 +668,13 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(5000)); } -TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { +TEST_P(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - scorer_factory::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE), - /*default_score=*/3, document_store(), - schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE, GetParam()), + /*default_score=*/3, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -690,7 +686,7 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( 
scorer, scorer_factory::Create( CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE), + ScoringSpecProto::RankingStrategy::NONE, GetParam()), /*default_score=*/111, document_store(), schema_store())); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); @@ -701,7 +697,7 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(111)); } -TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { +TEST_P(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { DocumentProto test_document = DocumentBuilder() .SetKey("icing", "email/1") @@ -734,6 +730,10 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(max_int_usage_timestamp_score)); } +INSTANTIATE_TEST_SUITE_P(ScorerTest, ScorerTest, + testing::Values(ScorerTestingMode::kNormal, + ScorerTestingMode::kAdvanced)); + } // namespace } // namespace lib diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index 921fc7f..7e4ca1d 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -27,6 +27,7 @@ #include "icing/proto/term.pb.h" #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" +#include "icing/scoring/scorer-test-utils.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/tmp-directory.h" @@ -41,7 +42,8 @@ using ::testing::Gt; using ::testing::IsEmpty; using ::testing::SizeIs; -class ScoringProcessorTest : public testing::Test { +class ScoringProcessorTest + : public ::testing::TestWithParam<ScorerTestingMode> { protected: ScoringProcessorTest() : test_dir_(GetTestTempDir() + "/icing"), @@ -187,21 +189,21 @@ TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } -TEST_F(ScoringProcessorTest, ShouldCreateInstance) { - 
ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); +TEST_P(ScoringProcessorTest, ShouldCreateInstance) { + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); ICING_EXPECT_OK( ScoringProcessor::Create(spec_proto, document_store(), schema_store())); } -TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { +TEST_P(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates an empty DocHitInfoIterator std::vector<DocHitInfo> doc_hit_infos = {}; std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -213,7 +215,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { IsEmpty()); } -TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { +TEST_P(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Sets up documents ICING_ASSERT_OK_AND_ASSIGN( DocumentId document_id1, @@ -226,8 +228,8 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -245,7 +247,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { IsEmpty()); } -TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { 
+TEST_P(ScoringProcessorTest, ShouldRespectNumToScore) { // Sets up documents ICING_ASSERT_OK_AND_ASSIGN( auto doc_hit_result_pair, @@ -256,8 +258,8 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -275,7 +277,7 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { SizeIs(3)); } -TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { +TEST_P(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates input doc_hit_infos and expected output scored_document_hits ICING_ASSERT_OK_AND_ASSIGN( auto doc_hit_result_pair, @@ -288,8 +290,8 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -303,7 +305,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { EqualsScoredDocumentHit(scored_document_hits.at(2)))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_DocumentsWithDifferentLength) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -343,8 +345,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - 
spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -372,7 +374,7 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_DocumentsWithSameLength) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -412,8 +414,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -440,7 +442,7 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_DocumentsWithDifferentQueryFrequency) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -485,8 +487,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -513,7 +515,7 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, 
ShouldScoreByRelevanceScore_HitTermWithZeroFrequency) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -534,8 +536,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( @@ -558,7 +560,7 @@ TEST_F(ScoringProcessorTest, ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_SameHitFrequencyDifferentPropertyWeights) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -592,8 +594,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); @@ -630,7 +632,7 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit2))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_WithImplicitPropertyWeight) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -664,8 +666,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - 
spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); @@ -702,7 +704,7 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit2))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_WithDefaultPropertyWeight) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -727,8 +729,8 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); *spec_proto.add_type_property_weights() = CreateTypePropertyWeights(/*schema_type=*/"email", {}); @@ -738,9 +740,9 @@ TEST_F(ScoringProcessorTest, std::unique_ptr<ScoringProcessor> scoring_processor, ScoringProcessor::Create(spec_proto, document_store(), schema_store())); - ScoringSpecProto spec_proto_with_weights; - spec_proto_with_weights.set_rank_by( - ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto_with_weights = + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", /*weight=*/1.0); @@ -789,7 +791,7 @@ TEST_F(ScoringProcessorTest, ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); } -TEST_F(ScoringProcessorTest, +TEST_P(ScoringProcessorTest, ShouldScoreByRelevanceScore_WithZeroPropertyWeight) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -823,8 +825,8 
@@ TEST_F(ScoringProcessorTest, std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, GetParam()); // Sets property weight for "body" to 0.0. PropertyWeight body_property_weight = @@ -861,7 +863,7 @@ TEST_F(ScoringProcessorTest, EXPECT_THAT(scored_document_hits.at(1).score(), Gt(0.0)); } -TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { +TEST_P(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, /*creation_timestamp_ms=*/1571100001111); @@ -894,8 +896,8 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, GetParam()); // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( @@ -909,7 +911,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { EqualsScoredDocumentHit(scored_document_hit1))); } -TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { +TEST_P(ScoringProcessorTest, ShouldScoreByUsageCount) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); @@ -954,8 +956,8 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - 
spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, GetParam()); // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( @@ -969,7 +971,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { EqualsScoredDocumentHit(scored_document_hit3))); } -TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { +TEST_P(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); @@ -1013,9 +1015,9 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, + GetParam()); // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( @@ -1029,7 +1031,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { EqualsScoredDocumentHit(scored_document_hit3))); } -TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { +TEST_P(ScoringProcessorTest, ShouldHandleNoScores) { // Creates input doc_hit_infos and corresponding scored_document_hits ICING_ASSERT_OK_AND_ASSIGN( auto doc_hit_result_pair, @@ -1050,8 +1052,8 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { ScoredDocumentHit scored_document_hit_default = ScoredDocumentHit(4, kSectionIdMaskNone, /*score=*/0.0); - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + 
ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, GetParam()); // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( @@ -1065,7 +1067,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { EqualsScoredDocumentHit(scored_document_hits.at(2)))); } -TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { +TEST_P(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { DocumentProto document1 = CreateDocument("icing", "email/1", /*score=*/1, kDefaultCreationTimestampMs); DocumentProto document2 = CreateDocument("icing", "email/2", /*score=*/2, @@ -1099,8 +1101,8 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos); // A ScoringSpecProto with no scoring strategy - ScoringSpecProto spec_proto; - spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); + ScoringSpecProto spec_proto = CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE, GetParam()); // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( @@ -1114,6 +1116,10 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { EqualsScoredDocumentHit(scored_document_hit1))); } +INSTANTIATE_TEST_SUITE_P(ScoringProcessorTest, ScoringProcessorTest, + testing::Values(ScorerTestingMode::kNormal, + ScorerTestingMode::kAdvanced)); + } // namespace } // namespace lib diff --git a/java/src/com/google/android/icing/IcingSearchEngineInterface.java b/java/src/com/google/android/icing/IcingSearchEngineInterface.java index 5922716..0bc58f1 100644 --- a/java/src/com/google/android/icing/IcingSearchEngineInterface.java +++ b/java/src/com/google/android/icing/IcingSearchEngineInterface.java @@ -1,6 +1,5 @@ package com.google.android.icing; -import android.os.RemoteException; import com.google.android.icing.proto.DebugInfoResultProto; import com.google.android.icing.proto.DebugInfoVerbosity; import 
com.google.android.icing.proto.DeleteByNamespaceResultProto; @@ -33,12 +32,7 @@ import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.UsageReport; import java.io.Closeable; -/** - * A common user-facing interface to expose the funcationalities provided by Icing Library. - * - * <p>All the methods here throw {@link RemoteException} because the implementation for - * gmscore-appsearch-dynamite will throw it. - */ +/** A common user-facing interface to expose the functionalities provided by Icing Library. */ public interface IcingSearchEngineInterface extends Closeable { /** * Initializes the current IcingSearchEngine implementation. diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto index d9c43e2..5d1685c 100644 --- a/proto/icing/proto/schema.proto +++ b/proto/icing/proto/schema.proto @@ -164,10 +164,33 @@ message IntegerIndexingConfig { optional NumericMatchType.Code numeric_match_type = 1; } +// Describes how a property can be used to join this document with another +// document. See JoinSpecProto (in search.proto) for more details. +// Next tag: 2 +message JoinableConfig { + // OPTIONAL: Indicates what joinable type the content value of this property + // is. + // + // The default value is NONE. + message ValueType { + enum Code { + // Value in this property is not joinable. + NONE = 0; + + // Value in this property is a joinable (string) qualified id, which is + // composed of namespace and uri. + // See JoinSpecProto (in search.proto) and DocumentProto (in + // document.proto) for more details about qualified id, namespace and uri. + QUALIFIED_ID = 1; + } + } + optional ValueType.Code value_type = 1; +} + // Describes the schema of a single property of Documents that belong to a // specific SchemaTypeConfigProto. These can be considered as a rich, structured // type for each property of Documents accepted by IcingSearchEngine.
-// Next tag: 8 +// Next tag: 9 message PropertyConfigProto { // REQUIRED: Name that uniquely identifies a property within an Document of // a specific SchemaTypeConfigProto. @@ -248,6 +271,16 @@ message PropertyConfigProto { // OPTIONAL: Describes how int64 properties should be indexed. Int64 // properties that do not set the indexing config will not be indexed. optional IntegerIndexingConfig integer_indexing_config = 7; + + // OPTIONAL: Describes how string properties can be used as a document joining + // matcher. + // + // Note: currently we only support STRING single joining, so if a property is + // set as joinable (i.e. joinable_config.value_type is not NONE), then: + // - DataType should be STRING. Otherwise joinable_config will be ignored. + // - The property itself and any upper-level (nested doc) property should + // contain at most one element (i.e. Cardinality is OPTIONAL or REQUIRED). + optional JoinableConfig joinable_config = 8; } // List of all supported types constitutes the schema used by Icing. diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 654903b..d7938f1 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=494856295) +set(synced_AOSP_CL_number=495345294) |