70 files changed, 4404 insertions, 823 deletions
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index d915d65..791368a 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -43,6 +43,8 @@
 #include "icing/proto/search.pb.h"
 #include "icing/proto/status.pb.h"
 #include "icing/query/query-processor.h"
+#include "icing/result/projection-tree.h"
+#include "icing/result/projector.h"
 #include "icing/result/result-retriever.h"
 #include "icing/schema/schema-store.h"
 #include "icing/schema/schema-util.h"
@@ -60,6 +62,7 @@
 #include "icing/util/crc32.h"
 #include "icing/util/logging.h"
 #include "icing/util/status-macros.h"
+#include "icing/util/tokenized-document.h"
 #include "unicode/uloc.h"
 
 namespace icing {
@@ -693,7 +696,19 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) {
     return result_proto;
   }
 
-  auto document_id_or = document_store_->Put(document, put_document_stats);
+  auto tokenized_document_or = TokenizedDocument::Create(
+      schema_store_.get(), language_segmenter_.get(), std::move(document));
+  if (!tokenized_document_or.ok()) {
+    TransformStatus(tokenized_document_or.status(), result_status);
+    put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds());
+    return result_proto;
+  }
+  TokenizedDocument tokenized_document(
+      std::move(tokenized_document_or).ValueOrDie());
+
+  auto document_id_or =
+      document_store_->Put(tokenized_document.document(),
+                           tokenized_document.num_tokens(), put_document_stats);
   if (!document_id_or.ok()) {
     TransformStatus(document_id_or.status(), result_status);
     put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds());
@@ -702,8 +717,8 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) {
   DocumentId document_id = document_id_or.ValueOrDie();
 
   auto index_processor_or = IndexProcessor::Create(
-      schema_store_.get(), language_segmenter_.get(), normalizer_.get(),
-      index_.get(), CreateIndexProcessorOptions(options_), clock_.get());
+      normalizer_.get(), index_.get(), CreateIndexProcessorOptions(options_),
+      clock_.get());
   if (!index_processor_or.ok()) {
     TransformStatus(index_processor_or.status(), result_status);
     put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds());
@@ -712,8 +727,8 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) {
   std::unique_ptr<IndexProcessor> index_processor =
       std::move(index_processor_or).ValueOrDie();
 
-  auto status =
-      index_processor->IndexDocument(document, document_id, put_document_stats);
+  auto status = index_processor->IndexDocument(tokenized_document, document_id,
+                                               put_document_stats);
   TransformStatus(status, result_status);
   put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds());
 
@@ -721,7 +736,8 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) {
 }
 
 GetResultProto IcingSearchEngine::Get(const std::string_view name_space,
-                                      const std::string_view uri) {
+                                      const std::string_view uri,
+                                      const GetResultSpecProto& result_spec) {
   GetResultProto result_proto;
   StatusProto* result_status = result_proto.mutable_status();
 
@@ -738,8 +754,29 @@ GetResultProto IcingSearchEngine::Get(const std::string_view name_space,
     return result_proto;
  }
 
+  DocumentProto document = std::move(document_or).ValueOrDie();
+  std::unique_ptr<ProjectionTree> type_projection_tree;
+  std::unique_ptr<ProjectionTree> wildcard_projection_tree;
+  for (const TypePropertyMask& type_field_mask :
+       result_spec.type_property_masks()) {
+    if (type_field_mask.schema_type() == document.schema()) {
+      type_projection_tree = std::make_unique<ProjectionTree>(type_field_mask);
+    } else if (type_field_mask.schema_type() ==
+               ProjectionTree::kSchemaTypeWildcard) {
+      wildcard_projection_tree =
+          std::make_unique<ProjectionTree>(type_field_mask);
+    }
+  }
+
+  // Apply projection
+  if (type_projection_tree != nullptr) {
+    projector::Project(type_projection_tree->root().children, &document);
+  } else if (wildcard_projection_tree != nullptr) {
+    projector::Project(wildcard_projection_tree->root().children, &document);
+  }
+
   result_status->set_code(StatusProto::OK);
-  *result_proto.mutable_document() = std::move(document_or).ValueOrDie();
+  *result_proto.mutable_document() = std::move(document);
   return result_proto;
 }
 
@@ -1237,7 +1274,8 @@ SearchResultProto IcingSearchEngine::Search(
       std::move(scoring_processor_or).ValueOrDie();
   std::vector<ScoredDocumentHit> result_document_hits =
       scoring_processor->Score(std::move(query_results.root_iterator),
-                               performance_configuration_.num_to_score);
+                               performance_configuration_.num_to_score,
+                               &query_results.query_term_iterators);
   query_stats->set_scoring_latency_ms(
       component_timer->GetElapsedMilliseconds());
   query_stats->set_num_documents_scored(result_document_hits.size());
@@ -1416,7 +1454,8 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore() {
   }
 
   // Copies valid document data to tmp directory
-  auto optimize_status = document_store_->OptimizeInto(temporary_document_dir);
+  auto optimize_status = document_store_->OptimizeInto(
+      temporary_document_dir, language_segmenter_.get());
 
   // Handles error if any
   if (!optimize_status.ok()) {
@@ -1523,9 +1562,9 @@ libtextclassifier3::Status IcingSearchEngine::RestoreIndexIfNeeded() {
 
   ICING_ASSIGN_OR_RETURN(
       std::unique_ptr<IndexProcessor> index_processor,
-      IndexProcessor::Create(
-          schema_store_.get(), language_segmenter_.get(), normalizer_.get(),
-          index_.get(), CreateIndexProcessorOptions(options_), clock_.get()));
+      IndexProcessor::Create(normalizer_.get(), index_.get(),
+                             CreateIndexProcessorOptions(options_),
+                             clock_.get()));
 
   ICING_VLOG(1) << "Restoring index by replaying documents from document id "
                 << first_document_to_reindex << " to document id "
@@ -1546,9 +1585,20 @@
         return document_or.status();
       }
     }
+    DocumentProto document(std::move(document_or).ValueOrDie());
+
+    libtextclassifier3::StatusOr<TokenizedDocument> tokenized_document_or =
+        TokenizedDocument::Create(schema_store_.get(),
+                                  language_segmenter_.get(),
+                                  std::move(document));
+    if (!tokenized_document_or.ok()) {
+      return tokenized_document_or.status();
+    }
+    TokenizedDocument tokenized_document(
+        std::move(tokenized_document_or).ValueOrDie());
 
     libtextclassifier3::Status status =
-        index_processor->IndexDocument(document_or.ValueOrDie(), document_id);
+        index_processor->IndexDocument(tokenized_document, document_id);
     if (!status.ok()) {
       if (!absl_ports::IsDataLoss(status)) {
         // Real error. Stop recovering and pass it up.
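The result_spec parameter added to Get() above enables property projection on single-document lookups. A minimal usage sketch, not part of the commit itself: it assumes an initialized IcingSearchEngine named icing that already holds the "Email" document used in the tests later in this change.

  // Request only 'sender.name' and 'subject' for documents of type "Email".
  GetResultSpecProto result_spec;
  TypePropertyMask* mask = result_spec.add_type_property_masks();
  mask->set_schema_type("Email");  // "*" would act as the wildcard fallback
  mask->add_paths("sender.name");
  mask->add_paths("subject");
  GetResultProto result = icing.Get("namespace", "uri1", result_spec);
  // result.document() now carries only 'sender.name' and 'subject'; other
  // properties such as 'body' are stripped by projector::Project().

Passing GetResultSpecProto::default_instance(), as the updated tests do, leaves the document unmodified.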
diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index b2bb4f1..dfe56c4 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -36,6 +36,7 @@ #include "icing/proto/reset.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" +#include "icing/proto/search.proto.h" #include "icing/proto/search.pb.h" #include "icing/proto/usage.pb.h" #include "icing/result/result-state-manager.h" @@ -213,7 +214,8 @@ class IcingSearchEngine { // NOT_FOUND if the key doesn't exist or doc has been deleted // FAILED_PRECONDITION IcingSearchEngine has not been initialized yet // INTERNAL_ERROR on IO error - GetResultProto Get(std::string_view name_space, std::string_view uri); + GetResultProto Get(std::string_view name_space, std::string_view uri, + const GetResultSpecProto& result_spec); // Reports usage. The corresponding usage scores of the specified document in // the report will be updated. diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index f4249f3..8c64614 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -54,10 +54,13 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::_; +using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::Ge; using ::testing::Gt; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Le; using ::testing::Lt; using ::testing::Matcher; using ::testing::Ne; @@ -112,7 +115,6 @@ class IcingSearchEngineTest : public testing::Test { ICING_ASSERT_OK( icu_data_file_helper::SetUpICUDataFile(icu_data_file_path)); } - filesystem_.CreateDirectoryRecursively(GetTestBaseDir().c_str()); } @@ -156,6 +158,19 @@ DocumentProto CreateMessageDocument(std::string name_space, std::string uri) { .Build(); } +DocumentProto CreateEmailDocument(const std::string& name_space, + const std::string& uri, int score, + const std::string& subject_content, + const std::string& body_content) { + return DocumentBuilder() + .SetKey(name_space, uri) + .SetSchema("Email") + .SetScore(score) + .AddStringProperty("subject", subject_content) + .AddStringProperty("body", body_content) + .Build(); +} + SchemaProto CreateMessageSchema() { SchemaProto schema; auto type = schema.add_types(); @@ -265,6 +280,17 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } +std::vector<std::string> GetUrisFromSearchResults( + SearchResultProto& search_result_proto) { + std::vector<std::string> result_uris; + result_uris.reserve(search_result_proto.results_size()); + for (int i = 0; i < search_result_proto.results_size(); i++) { + result_uris.push_back( + search_result_proto.mutable_results(i)->document().uri()); + } + return result_uris; +} + TEST_F(IcingSearchEngineTest, SimpleInitialization) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); @@ -287,12 +313,14 @@ TEST_F(IcingSearchEngineTest, InitializingAgainSavesNonPersistedData) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document; - ASSERT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + ASSERT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + 
EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } TEST_F(IcingSearchEngineTest, MaxIndexMergeSizeReturnsInvalidArgument) { @@ -670,6 +698,194 @@ TEST_F(IcingSearchEngineTest, SetSchemaDelete) { } } +TEST_F(IcingSearchEngineTest, SetSchemaUnsetVersionIsZero) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 1. Create a schema with an Email type with version 1 + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(0)); +} + +TEST_F(IcingSearchEngineTest, SetSchemaCompatibleVersionUpdateSucceeds) { + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 1. Create a schema with an Email type with version 1 + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(1); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(1)); + } + + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 2. Create schema that adds a new optional property and updates version. + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(2); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + property->set_property_name("body"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + // 3. SetSchema should succeed and the version number should be updated. + EXPECT_THAT(icing.SetSchema(schema, true).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(2)); + } +} + +TEST_F(IcingSearchEngineTest, SetSchemaIncompatibleVersionUpdateFails) { + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 1. 
Create a schema with an Email type with version 1 + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(1); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(1)); + } + + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 2. Create schema that makes an incompatible change (OPTIONAL -> REQUIRED) + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(2); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::REQUIRED); + + // 3. SetSchema should fail and the version number should NOT be updated. + EXPECT_THAT(icing.SetSchema(schema).status(), + ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(1)); + } +} + +TEST_F(IcingSearchEngineTest, + SetSchemaIncompatibleVersionUpdateForceOverrideSucceeds) { + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 1. Create a schema with an Email type with version 1 + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(1); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(1)); + } + + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 2. Create schema that makes an incompatible change (OPTIONAL -> REQUIRED) + // with force override to true. + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(2); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::REQUIRED); + + // 3. SetSchema should succeed and the version number should be updated. + EXPECT_THAT(icing.SetSchema(schema, true).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(2)); + } +} + +TEST_F(IcingSearchEngineTest, SetSchemaNoChangeVersionUpdateSucceeds) { + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 1. 
Create a schema with an Email type with version 1 + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(1); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(1)); + } + + { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // 2. Create schema that only changes the version. + SchemaProto schema; + SchemaTypeConfigProto* type = schema.add_types(); + type->set_version(2); + type->set_schema_type("Email"); + PropertyConfigProto* property = type->add_properties(); + property->set_property_name("title"); + property->set_data_type(PropertyConfigProto::DataType::STRING); + property->set_cardinality(PropertyConfigProto::Cardinality::OPTIONAL); + + // 3. SetSchema should succeed and the version number should be updated. + EXPECT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + EXPECT_THAT(icing.GetSchema().schema().types(0).version(), Eq(2)); + } +} + TEST_F(IcingSearchEngineTest, SetSchemaDuplicateTypesReturnsAlreadyExists) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); @@ -894,7 +1110,8 @@ TEST_F(IcingSearchEngineTest, SetSchemaRevalidatesDocumentsAndReturnsOk) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = email_document_with_subject; - EXPECT_THAT(icing.Get("namespace", "with_subject"), + EXPECT_THAT(icing.Get("namespace", "with_subject", + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); // The document without a subject got deleted because it failed validation @@ -904,7 +1121,8 @@ TEST_F(IcingSearchEngineTest, SetSchemaRevalidatesDocumentsAndReturnsOk) { "Document (namespace, without_subject) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace", "without_subject"), + EXPECT_THAT(icing.Get("namespace", "without_subject", + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); } @@ -962,7 +1180,8 @@ TEST_F(IcingSearchEngineTest, SetSchemaDeletesDocumentsAndReturnsOk) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = email_document; - EXPECT_THAT(icing.Get("namespace", "email_uri"), + EXPECT_THAT(icing.Get("namespace", "email_uri", + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); // "message" document got deleted @@ -971,7 +1190,8 @@ TEST_F(IcingSearchEngineTest, SetSchemaDeletesDocumentsAndReturnsOk) { "Document (namespace, message_uri) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace", "message_uri"), + EXPECT_THAT(icing.Get("namespace", "message_uri", + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); } @@ -1035,8 +1255,9 @@ TEST_F(IcingSearchEngineTest, GetDocument) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = CreateMessageDocument("namespace", "uri"); - ASSERT_THAT(icing.Get("namespace", "uri"), - 
EqualsProto(expected_get_result_proto)); + ASSERT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Put an invalid document PutResultProto put_result_proto = icing.Put(DocumentProto()); @@ -1050,7 +1271,208 @@ TEST_F(IcingSearchEngineTest, GetDocument) { expected_get_result_proto.mutable_status()->set_message( "Document (wrong, uri) not found."); expected_get_result_proto.clear_document(); - ASSERT_THAT(icing.Get("wrong", "uri"), + ASSERT_THAT(icing.Get("wrong", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); +} + +TEST_F(IcingSearchEngineTest, GetDocumentProjectionEmpty) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + DocumentProto document = CreateMessageDocument("namespace", "uri"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + GetResultSpecProto result_spec; + TypePropertyMask* mask = result_spec.add_type_property_masks(); + mask->set_schema_type(document.schema()); + mask->add_paths(""); + + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_result_proto.mutable_document() = document; + expected_get_result_proto.mutable_document()->clear_properties(); + ASSERT_THAT(icing.Get("namespace", "uri", result_spec), + EqualsProto(expected_get_result_proto)); +} + +TEST_F(IcingSearchEngineTest, GetDocumentWildCardProjectionEmpty) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + DocumentProto document = CreateMessageDocument("namespace", "uri"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + GetResultSpecProto result_spec; + TypePropertyMask* mask = result_spec.add_type_property_masks(); + mask->set_schema_type("*"); + mask->add_paths(""); + + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_result_proto.mutable_document() = document; + expected_get_result_proto.mutable_document()->clear_properties(); + ASSERT_THAT(icing.Get("namespace", "uri", result_spec), + EqualsProto(expected_get_result_proto)); +} + +TEST_F(IcingSearchEngineTest, GetDocumentProjectionMultipleFieldPaths) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(), + ProtoIsOk()); + + // 1. Add an email document + DocumentProto document = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("subject", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! 
Oh what a beautiful day!") + .Build(); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + GetResultSpecProto result_spec; + TypePropertyMask* mask = result_spec.add_type_property_masks(); + mask->set_schema_type("Email"); + mask->add_paths("sender.name"); + mask->add_paths("subject"); + + // 2. Verify that the returned result only contains the 'sender.name' + // property and the 'subject' property. + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_result_proto.mutable_document() = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .Build()) + .AddStringProperty("subject", "Hello World!") + .Build(); + ASSERT_THAT(icing.Get("namespace", "uri1", result_spec), + EqualsProto(expected_get_result_proto)); +} + +TEST_F(IcingSearchEngineTest, GetDocumentWildcardProjectionMultipleFieldPaths) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(), + ProtoIsOk()); + + // 1. Add an email document + DocumentProto document = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("subject", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + GetResultSpecProto result_spec; + TypePropertyMask* mask = result_spec.add_type_property_masks(); + mask->set_schema_type("*"); + mask->add_paths("sender.name"); + mask->add_paths("subject"); + + // 2. Verify that the returned result only contains the 'sender.name' + // property and the 'subject' property. + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_result_proto.mutable_document() = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty("sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .Build()) + .AddStringProperty("subject", "Hello World!") + .Build(); + ASSERT_THAT(icing.Get("namespace", "uri1", result_spec), + EqualsProto(expected_get_result_proto)); +} + +TEST_F(IcingSearchEngineTest, + GetDocumentSpecificProjectionOverridesWildcardProjection) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(), + ProtoIsOk()); + + // 1. 
Add an email document + DocumentProto document = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddDocumentProperty( + "sender", + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Person") + .AddStringProperty("name", "Meg Ryan") + .AddStringProperty("emailAddress", "shopgirl@aol.com") + .Build()) + .AddStringProperty("subject", "Hello World!") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + // 2. Add type property masks for the wildcard and the specific type of the + // document 'Email'. The wildcard should be ignored and only the 'Email' + // projection should apply. + GetResultSpecProto result_spec; + TypePropertyMask* mask = result_spec.add_type_property_masks(); + mask->set_schema_type("*"); + mask->add_paths("subject"); + mask = result_spec.add_type_property_masks(); + mask->set_schema_type("Email"); + mask->add_paths("body"); + + // 3. Verify that the returned result only contains the 'body' property. + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_result_proto.mutable_document() = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetCreationTimestampMs(1000) + .SetSchema("Email") + .AddStringProperty( + "body", "Oh what a beautiful morning! Oh what a beautiful day!") + .Build(); + ASSERT_THAT(icing.Get("namespace", "uri1", result_spec), EqualsProto(expected_get_result_proto)); } @@ -1593,16 +2015,18 @@ TEST_F(IcingSearchEngineTest, OptimizationShouldRemoveDeletedDocs) { filesystem()->GetFileSize(document_log_path.c_str()); // Validates that document can't be found right after Optimize() - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Validates that document is actually removed from document log EXPECT_THAT(document_log_size_after, Lt(document_log_size_before)); } // Destroys IcingSearchEngine to make sure nothing is cached. IcingSearchEngine icing(icing_options, GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } TEST_F(IcingSearchEngineTest, OptimizationShouldDeleteTemporaryDirectory) { @@ -1712,19 +2136,22 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { ASSERT_THAT(icing.Optimize().status(), ProtoIsOk()); // Validates that Get() and Put() are good right after Optimize() - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); EXPECT_THAT(icing.Put(document2).status(), ProtoIsOk()); } // Destroys IcingSearchEngine to make sure nothing is cached. 
IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); EXPECT_THAT(icing.Put(document3).status(), ProtoIsOk()); } @@ -1748,14 +2175,16 @@ TEST_F(IcingSearchEngineTest, DeleteShouldWorkAfterOptimization) { StatusProto::NOT_FOUND); expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri1) not found."); - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } // Destroys IcingSearchEngine to make sure nothing is cached. IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); @@ -1766,13 +2195,15 @@ TEST_F(IcingSearchEngineTest, DeleteShouldWorkAfterOptimization) { expected_get_result_proto.mutable_status()->set_code(StatusProto::NOT_FOUND); expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri1) not found."); - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri2) not found."); - EXPECT_THAT(icing.Get("namespace", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } TEST_F(IcingSearchEngineTest, OptimizationFailureUninitializesIcing) { @@ -1832,7 +2263,10 @@ TEST_F(IcingSearchEngineTest, OptimizationFailureUninitializesIcing) { ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); EXPECT_THAT(icing.Put(simple_doc).status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); - EXPECT_THAT(icing.Get(simple_doc.namespace_(), simple_doc.uri()).status(), + EXPECT_THAT(icing + .Get(simple_doc.namespace_(), simple_doc.uri(), + GetResultSpecProto::default_instance()) + .status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); EXPECT_THAT(icing.Search(search_spec, scoring_spec, result_spec).status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); @@ -1841,7 +2275,10 @@ TEST_F(IcingSearchEngineTest, OptimizationFailureUninitializesIcing) { EXPECT_THAT(icing.Reset().status(), ProtoIsOk()); EXPECT_THAT(icing.SetSchema(simple_schema).status(), ProtoIsOk()); EXPECT_THAT(icing.Put(simple_doc).status(), ProtoIsOk()); - EXPECT_THAT(icing.Get(simple_doc.namespace_(), simple_doc.uri()).status(), + EXPECT_THAT(icing + 
.Get(simple_doc.namespace_(), simple_doc.uri(), + GetResultSpecProto::default_instance()) + .status(), ProtoIsOk()); EXPECT_THAT(icing.Search(search_spec, scoring_spec, result_spec).status(), ProtoIsOk()); @@ -1900,12 +2337,14 @@ TEST_F(IcingSearchEngineTest, DeleteBySchemaType) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete the first type. The first doc should be irretrievable. The // second should still be present. @@ -1922,14 +2361,16 @@ TEST_F(IcingSearchEngineTest, DeleteBySchemaType) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri1) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Search for "message", only document2 should show up. SearchResultProto expected_search_result_proto; @@ -1976,12 +2417,14 @@ TEST_F(IcingSearchEngineTest, DeleteSchemaTypeByQuery) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete the first type. The first doc should be irretrievable. The // second should still be present. 
@@ -1993,14 +2436,16 @@ TEST_F(IcingSearchEngineTest, DeleteSchemaTypeByQuery) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri1) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); search_spec = SearchSpecProto::default_instance(); search_spec.set_query("message"); @@ -2055,16 +2500,19 @@ TEST_F(IcingSearchEngineTest, DeleteByNamespace) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace1", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document3; - EXPECT_THAT(icing.Get("namespace3", "uri3"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace3", "uri3", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete namespace1. Document1 and document2 should be irretrievable. // Document3 should still be present. @@ -2081,21 +2529,24 @@ TEST_F(IcingSearchEngineTest, DeleteByNamespace) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri1) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::NOT_FOUND); expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri2) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document3; - EXPECT_THAT(icing.Get("namespace3", "uri3"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace3", "uri3", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Search for "message", only document3 should show up. 
SearchResultProto expected_search_result_proto; @@ -2137,12 +2588,14 @@ TEST_F(IcingSearchEngineTest, DeleteNamespaceByQuery) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete the first namespace. The first doc should be irretrievable. The // second should still be present. @@ -2154,14 +2607,16 @@ TEST_F(IcingSearchEngineTest, DeleteNamespaceByQuery) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri1) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); search_spec = SearchSpecProto::default_instance(); search_spec.set_query("message"); @@ -2208,12 +2663,14 @@ TEST_F(IcingSearchEngineTest, DeleteByQuery) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete all docs containing 'body1'. The first doc should be irretrievable. // The second should still be present. 
@@ -2232,14 +2689,16 @@ TEST_F(IcingSearchEngineTest, DeleteByQuery) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace1, uri1) not found."); expected_get_result_proto.clear_document(); - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); search_spec = SearchSpecProto::default_instance(); search_spec.set_query("message"); @@ -2281,12 +2740,14 @@ TEST_F(IcingSearchEngineTest, DeleteByQueryNotFound) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Delete all docs containing 'foo', which should be none of them. Both docs // should still be present. 
@@ -2299,14 +2760,16 @@ TEST_F(IcingSearchEngineTest, DeleteByQueryNotFound) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace1", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace1", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); expected_get_result_proto.mutable_status()->clear_message(); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace2", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace2", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); search_spec = SearchSpecProto::default_instance(); search_spec.set_query("message"); @@ -2428,8 +2891,9 @@ TEST_F(IcingSearchEngineTest, IcingShouldWorkFineIfOptimizationIsAborted) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document1; - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); DocumentProto document2 = CreateMessageDocument("namespace", "uri2"); @@ -2484,8 +2948,9 @@ TEST_F(IcingSearchEngineTest, expected_get_result_proto.mutable_status()->set_code(StatusProto::NOT_FOUND); expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri) not found."); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); DocumentProto new_document = DocumentBuilder() @@ -2556,8 +3021,9 @@ TEST_F(IcingSearchEngineTest, OptimizationShouldRecoverIfDataFilesAreMissing) { expected_get_result_proto.mutable_status()->set_code(StatusProto::NOT_FOUND); expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri) not found."); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); DocumentProto new_document = DocumentBuilder() @@ -2804,8 +3270,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromMissingHeaderFile) { EXPECT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); EXPECT_THAT(icing.Put(CreateMessageDocument("namespace", "uri")).status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); @@ -2820,8 +3287,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromMissingHeaderFile) { EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); // Checks that DocumentLog is still ok - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + 
EqualsProto(expected_get_result_proto)); // Checks that the index is still ok so we can search over it SearchResultProto search_result_proto = @@ -2857,8 +3325,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromInvalidHeaderMagic) { EXPECT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); EXPECT_THAT(icing.Put(CreateMessageDocument("namespace", "uri")).status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); @@ -2877,8 +3346,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromInvalidHeaderMagic) { EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); // Checks that DocumentLog is still ok - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Checks that the index is still ok so we can search over it SearchResultProto search_result_proto = @@ -2914,8 +3384,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromInvalidHeaderChecksum) { EXPECT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); EXPECT_THAT(icing.Put(CreateMessageDocument("namespace", "uri")).status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); @@ -2935,8 +3406,9 @@ TEST_F(IcingSearchEngineTest, RecoverFromInvalidHeaderChecksum) { EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); // Checks that DocumentLog is still ok - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Checks that the index is still ok so we can search over it SearchResultProto search_result_proto = @@ -2964,8 +3436,9 @@ TEST_F(IcingSearchEngineTest, UnableToRecoverFromCorruptSchema) { *expected_get_result_proto.mutable_document() = CreateMessageDocument("namespace", "uri"); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } // This should shut down IcingSearchEngine and persist anything it needs to const std::string schema_file = @@ -2993,8 +3466,9 @@ TEST_F(IcingSearchEngineTest, UnableToRecoverFromCorruptDocumentLog) { *expected_get_result_proto.mutable_document() = CreateMessageDocument("namespace", "uri"); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } // This should shut down IcingSearchEngine and persist anything it needs to const std::string document_log_file = @@ -3118,7 +3592,8 @@ TEST_F(IcingSearchEngineTest, RecoverFromInconsistentSchemaStore) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); 
*expected_get_result_proto.mutable_document() = email_document; - EXPECT_THAT(icing.Get("namespace", "email_uri"), + EXPECT_THAT(icing.Get("namespace", "email_uri", + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); SearchSpecProto search_spec; @@ -3184,12 +3659,14 @@ TEST_F(IcingSearchEngineTest, RecoverFromInconsistentDocumentStore) { *expected_get_result_proto.mutable_document() = document1; // DocumentStore kept the additional document - EXPECT_THAT(icing.Get("namespace", "uri1"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); *expected_get_result_proto.mutable_document() = document2; - EXPECT_THAT(icing.Get("namespace", "uri2"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // We indexed the additional document SearchSpecProto search_spec; @@ -3666,6 +4143,392 @@ TEST_F(IcingSearchEngineTest, SearchResultShouldBeRankedByUsageTimestamp) { EqualsSearchResultIgnoreStats(expected_search_result_proto)); } +TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringOneNamespace) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri7", /*score=*/4, + "starbucks coffee", + "habit. birthday rewards. 
good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("coffee OR food"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + SearchResultProto search_result_proto = icing.Search( + search_spec, scoring_spec, ResultSpecProto::default_instance()); + + // Result should be in descending score order + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + // Both doc5 and doc7 have "coffee" in name and text sections. + // However, doc5 has more matches in the text section. + // Documents with "food" are ranked lower as the term "food" is commonly + // present in this corpus, and thus, has a lower IDF. + EXPECT_THAT(GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace1/uri5", // 'coffee' 3 times + "namespace1/uri7", // 'coffee' 2 times + "namespace1/uri1", // 'food' 2 times + "namespace1/uri4", // 'food' 2 times + "namespace1/uri2", // 'food' 1 time + "namespace1/uri6")); // 'food' 1 time +} + +TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringOneNamespaceNotOperator) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri3", /*score=*/23, "speederia pizza", + "thin-crust pizza. good and fast. nice coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri7", /*score=*/4, + "starbucks coffee", + "habit. birthday rewards. 
good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("coffee -starbucks"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + SearchResultProto search_result_proto = icing.Search( + search_spec, scoring_spec, ResultSpecProto::default_instance()); + + // Result should be in descending score order + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + EXPECT_THAT( + GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace1/uri5", // 'coffee' 3 times, 'starbucks' 0 times + "namespace1/uri3")); // 'coffee' 1 times, 'starbucks' 0 times +} + +TEST_F(IcingSearchEngineTest, + Bm25fRelevanceScoringOneNamespaceSectionRestrict) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = + CreateEmailDocument("namespace1", "namespace1/uri5", /*score=*/18, + "peets coffee, best coffee", + "espresso. decaf. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri7", /*score=*/4, "starbucks", + "habit. birthday rewards. good coffee. brewed coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("body:coffee OR body:food"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + SearchResultProto search_result_proto = icing.Search( + search_spec, scoring_spec, ResultSpecProto::default_instance()); + + // Result should be in descending score order, section restrict doesn't impact + // the BM25F score. + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + // Both doc5 and doc7 have "coffee" in name and text sections. + // However, doc5 has more matches. 
+ // Documents with "food" are ranked lower as the term "food" is commonly + // present in this corpus, and thus, has a lower IDF. + EXPECT_THAT( + GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace1/uri5", // 'coffee' 2 times in section subject, + // 1 time in section body + "namespace1/uri7", // 'coffee' 2 times in section body + "namespace1/uri1", // 'food' 2 times in section body + "namespace1/uri4", // 'food' 2 times in section body + "namespace1/uri2", // 'food' 1 time in section body + "namespace1/uri6")); // 'food' 1 time in section body +} + +TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringTwoNamespaces) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri7", /*score=*/4, + "starbucks coffee", + "habit. birthday rewards. good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace2". + document = CreateEmailDocument("namespace2", "namespace2/uri0", /*score=*/10, + "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. 
good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri7", /*score=*/4, + "starbucks coffee", "good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("coffee OR food"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ResultSpecProto result_spec_proto; + result_spec_proto.set_num_per_page(16); + SearchResultProto search_result_proto = + icing.Search(search_spec, scoring_spec, result_spec_proto); + + // Result should be in descending score order + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + // The two corpora have the same documents except for document 7, which in + // "namespace2" is much shorter than the average dcoument length, so it is + // boosted. + EXPECT_THAT(GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace2/uri7", // 'coffee' 2 times, short doc + "namespace1/uri5", // 'coffee' 3 times + "namespace2/uri5", // 'coffee' 3 times + "namespace1/uri7", // 'coffee' 2 times + "namespace1/uri1", // 'food' 2 times + "namespace2/uri1", // 'food' 2 times + "namespace1/uri4", // 'food' 2 times + "namespace2/uri4", // 'food' 2 times + "namespace1/uri2", // 'food' 1 time + "namespace2/uri2", // 'food' 1 time + "namespace1/uri6", // 'food' 1 time + "namespace2/uri6")); // 'food' 1 time +} + +TEST_F(IcingSearchEngineTest, Bm25fRelevanceScoringWithNamespaceFilter) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + EXPECT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace1". + DocumentProto document = CreateEmailDocument( + "namespace1", "namespace1/uri0", /*score=*/10, "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. 
expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace1", "namespace1/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace1", "namespace1/uri7", /*score=*/4, + "starbucks coffee", + "habit. birthday rewards. good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + // Create and index documents in namespace "namespace2". + document = CreateEmailDocument("namespace2", "namespace2/uri0", /*score=*/10, + "sushi belmont", + "fresh fish. inexpensive. good sushi."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri1", /*score=*/13, "peacock koriander", + "indian food. buffet. spicy food. kadai chicken."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri2", /*score=*/4, + "panda express", + "chinese food. cheap. inexpensive. kung pao."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri3", /*score=*/23, + "speederia pizza", + "thin-crust pizza. good and fast."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri4", /*score=*/8, + "whole foods", + "salads. pizza. organic food. expensive."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri5", /*score=*/18, "peets coffee", + "espresso. decaf. brewed coffee. whole beans. excellent coffee."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument( + "namespace2", "namespace2/uri6", /*score=*/4, "costco", + "bulk. cheap whole beans. frozen fish. food samples."); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + document = CreateEmailDocument("namespace2", "namespace2/uri7", /*score=*/4, + "starbucks coffee", "good coffee"); + ASSERT_THAT(icing.Put(document).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); + search_spec.set_query("coffee OR food"); + // Now query only corpus 2 + search_spec.add_namespace_filters("namespace2"); + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + SearchResultProto search_result_proto = icing.Search( + search_spec, scoring_spec, ResultSpecProto::default_instance()); + search_result_proto = icing.Search(search_spec, scoring_spec, + ResultSpecProto::default_instance()); + + // Result from namespace "namespace2" should be in descending score order + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + // Both doc5 and doc7 have "coffee" in name and text sections. + // Even though doc5 has more matches in the text section, doc7's length is + // much shorter than the average corpus's length, so it's being boosted. + // Documents with "food" are ranked lower as the term "food" is commonly + // present in this corpus, and thus, has a lower IDF. 
+ EXPECT_THAT(GetUrisFromSearchResults(search_result_proto), + ElementsAre("namespace2/uri7", // 'coffee' 2 times, short doc + "namespace2/uri5", // 'coffee' 3 times + "namespace2/uri1", // 'food' 2 times + "namespace2/uri4", // 'food' 2 times + "namespace2/uri2", // 'food' 1 time + "namespace2/uri6")); // 'food' 1 time +} + TEST_F(IcingSearchEngineTest, SearchResultShouldHaveDefaultOrderWithoutUsageTimestamp) { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); @@ -3922,8 +4785,9 @@ TEST_F(IcingSearchEngineTest, SetSchemaCanDetectPreviousSchemaWasLost) { expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document; - ASSERT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + ASSERT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Can search for it SearchResultProto expected_search_result_proto; @@ -3950,8 +4814,9 @@ TEST_F(IcingSearchEngineTest, SetSchemaCanDetectPreviousSchemaWasLost) { expected_get_result_proto.mutable_status()->set_message( "Document (namespace, uri) not found."); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); // Can't search for it SearchResultProto empty_result; @@ -3978,14 +4843,16 @@ TEST_F(IcingSearchEngineTest, PersistToDisk) { // Persisting shouldn't affect anything EXPECT_THAT(icing.PersistToDisk().status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } // Destructing persists as well IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); - EXPECT_THAT(icing.Get("namespace", "uri"), - EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); } TEST_F(IcingSearchEngineTest, ResetOk) { @@ -4044,7 +4911,8 @@ TEST_F(IcingSearchEngineTest, ResetAbortedError) { GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); *expected_get_result_proto.mutable_document() = document; - EXPECT_THAT(icing.Get(document.namespace_(), document.uri()), + EXPECT_THAT(icing.Get(document.namespace_(), document.uri(), + GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); // Can add new data. 
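(Editor's sketch, not part of the diff: a caller-side view of the new three-argument Get(). Whether GetResultSpecProto accepts the same TypePropertyMask projections that ResultSpecProto does in the SearchWithProjection tests later in this diff is an assumption here; the masked Get below is illustrative only.)
  GetResultSpecProto get_spec;  // default instance returns the full document
  TypePropertyMask* mask = get_spec.add_type_property_masks();  // assumed field, mirroring ResultSpecProto
  mask->set_schema_type("Email");  // apply the projection to documents of this schema type
  mask->add_paths("subject");      // keep just the "subject" property
  GetResultProto result = icing.Get("namespace", "uri", get_spec);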
@@ -4275,7 +5143,10 @@ TEST_F(IcingSearchEngineTest, UninitializedInstanceFailsSafely) { DocumentProto doc = CreateMessageDocument("namespace", "uri"); EXPECT_THAT(icing.Put(doc).status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); - EXPECT_THAT(icing.Get(doc.namespace_(), doc.uri()).status(), + EXPECT_THAT(icing + .Get(doc.namespace_(), doc.uri(), + GetResultSpecProto::default_instance()) + .status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); EXPECT_THAT(icing.Delete(doc.namespace_(), doc.uri()).status(), ProtoStatusIs(StatusProto::FAILED_PRECONDITION)); @@ -5310,11 +6181,13 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogFunctionLatency) { } TEST_F(IcingSearchEngineTest, PutDocumentShouldLogDocumentStoreStats) { - DocumentProto document = DocumentBuilder() - .SetKey("icing", "fake_type/0") - .SetSchema("Message") - .AddStringProperty("body", "message body") - .Build(); + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/0") + .SetSchema("Message") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .AddStringProperty("body", "message body") + .Build(); auto fake_clock = std::make_unique<FakeClock>(); fake_clock->SetTimerElapsedMilliseconds(10); @@ -5330,8 +6203,11 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogDocumentStoreStats) { EXPECT_THAT( put_result_proto.native_put_document_stats().document_store_latency_ms(), Eq(10)); - EXPECT_THAT(put_result_proto.native_put_document_stats().document_size(), - Eq(document.ByteSizeLong())); + size_t document_size = + put_result_proto.native_put_document_stats().document_size(); + EXPECT_THAT(document_size, Ge(document.ByteSizeLong())); + EXPECT_THAT(document_size, Le(document.ByteSizeLong() + + sizeof(DocumentProto::InternalFields))); } TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexingStats) { @@ -5483,8 +6359,7 @@ TEST_F(IcingSearchEngineTest, SearchWithProjectionEmptyFieldPath) { // Retrieve only one result at a time to make sure that projection works when // retrieving all pages. result_spec.set_num_per_page(1); - ResultSpecProto::TypePropertyMask* email_field_mask = - result_spec.add_type_property_masks(); + TypePropertyMask* email_field_mask = result_spec.add_type_property_masks(); email_field_mask->set_schema_type("Email"); email_field_mask->add_paths(""); @@ -5568,8 +6443,7 @@ TEST_F(IcingSearchEngineTest, SearchWithProjectionMultipleFieldPaths) { // Retrieve only one result at a time to make sure that projection works when // retrieving all pages. result_spec.set_num_per_page(1); - ResultSpecProto::TypePropertyMask* email_field_mask = - result_spec.add_type_property_masks(); + TypePropertyMask* email_field_mask = result_spec.add_type_property_masks(); email_field_mask->set_schema_type("Email"); email_field_mask->add_paths("sender.name"); email_field_mask->add_paths("subject"); diff --git a/icing/index/hit/doc-hit-info.h b/icing/index/hit/doc-hit-info.h index 8171960..0be87d6 100644 --- a/icing/index/hit/doc-hit-info.h +++ b/icing/index/hit/doc-hit-info.h @@ -25,7 +25,7 @@ namespace icing { namespace lib { -// DocHitInfo provides a collapsed view of all hits for a specific term and doc. +// DocHitInfo provides a collapsed view of all hits for a specific doc. // Hits contain a document_id, section_id and a term frequency. 
The // information in multiple hits is collapse into a DocHitInfo by providing a // SectionIdMask of all sections that contained a hit for this term as well as @@ -36,7 +36,7 @@ class DocHitInfo { SectionIdMask hit_section_ids_mask = kSectionIdMaskNone) : document_id_(document_id_in), hit_section_ids_mask_(hit_section_ids_mask) { - memset(hit_term_frequency_, Hit::kDefaultTermFrequency, + memset(hit_term_frequency_, Hit::kNoTermFrequency, sizeof(hit_term_frequency_)); } diff --git a/icing/index/hit/doc-hit-info_test.cc b/icing/index/hit/doc-hit-info_test.cc index 15c0de9..36c1a06 100644 --- a/icing/index/hit/doc-hit-info_test.cc +++ b/icing/index/hit/doc-hit-info_test.cc @@ -34,13 +34,13 @@ constexpr DocumentId kSomeOtherDocumentId = 54; TEST(DocHitInfoTest, InitialMaxHitTermFrequencies) { DocHitInfo info(kSomeDocumentId); for (SectionId i = 0; i <= kMaxSectionId; ++i) { - EXPECT_THAT(info.hit_term_frequency(i), Eq(Hit::kDefaultTermFrequency)); + EXPECT_THAT(info.hit_term_frequency(i), Eq(Hit::kNoTermFrequency)); } } TEST(DocHitInfoTest, UpdateHitTermFrequenciesForTheFirstTime) { DocHitInfo info(kSomeDocumentId); - ASSERT_THAT(info.hit_term_frequency(3), Eq(Hit::kDefaultTermFrequency)); + ASSERT_THAT(info.hit_term_frequency(3), Eq(Hit::kNoTermFrequency)); // Updating a section for the first time, should change its hit // term_frequency diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index 525a5e5..ee1f64b 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -58,6 +58,7 @@ class Hit { static constexpr TermFrequency kMaxTermFrequency = std::numeric_limits<TermFrequency>::max(); static constexpr TermFrequency kDefaultTermFrequency = 1; + static constexpr TermFrequency kNoTermFrequency = 0; explicit Hit(Value value = kInvalidValue, TermFrequency term_frequency = kDefaultTermFrequency) diff --git a/icing/index/index-processor.cc b/icing/index/index-processor.cc index 892263b..d2f9d41 100644 --- a/icing/index/index-processor.cc +++ b/icing/index/index-processor.cc @@ -31,34 +31,30 @@ #include "icing/schema/section-manager.h" #include "icing/schema/section.h" #include "icing/store/document-id.h" -#include "icing/tokenization/language-segmenter.h" #include "icing/tokenization/token.h" #include "icing/tokenization/tokenizer-factory.h" #include "icing/tokenization/tokenizer.h" #include "icing/transform/normalizer.h" #include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" namespace icing { namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> -IndexProcessor::Create(const SchemaStore* schema_store, - const LanguageSegmenter* lang_segmenter, - const Normalizer* normalizer, Index* index, +IndexProcessor::Create(const Normalizer* normalizer, Index* index, const IndexProcessor::Options& options, const Clock* clock) { - ICING_RETURN_ERROR_IF_NULL(schema_store); - ICING_RETURN_ERROR_IF_NULL(lang_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(clock); - return std::unique_ptr<IndexProcessor>(new IndexProcessor( - schema_store, lang_segmenter, normalizer, index, options, clock)); + return std::unique_ptr<IndexProcessor>( + new IndexProcessor(normalizer, index, options, clock)); } libtextclassifier3::Status IndexProcessor::IndexDocument( - const DocumentProto& document, DocumentId document_id, + const TokenizedDocument& tokenized_document, DocumentId document_id, NativePutDocumentStats* put_document_stats) { std::unique_ptr<Timer> index_timer = clock_.GetNewTimer(); @@ 
-68,54 +64,45 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( "DocumentId %d must be greater than last added document_id %d", document_id, index_->last_added_document_id())); } - ICING_ASSIGN_OR_RETURN(std::vector<Section> sections, - schema_store_.ExtractSections(document)); uint32_t num_tokens = 0; libtextclassifier3::Status overall_status; - for (const Section& section : sections) { + for (const TokenizedSection& section : tokenized_document.sections()) { // TODO(b/152934343): pass real namespace ids in Index::Editor editor = index_->Edit(document_id, section.metadata.id, section.metadata.term_match_type, /*namespace_id=*/0); - for (std::string_view subcontent : section.content) { - ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> tokenizer, - tokenizer_factory::CreateIndexingTokenizer( - section.metadata.tokenizer, &lang_segmenter_)); - ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> itr, - tokenizer->Tokenize(subcontent)); - while (itr->Advance()) { - if (++num_tokens > options_.max_tokens_per_document) { - // Index all tokens buffered so far. - editor.IndexAllBufferedTerms(); - if (put_document_stats != nullptr) { - put_document_stats->mutable_tokenization_stats() - ->set_exceeded_max_token_num(true); - put_document_stats->mutable_tokenization_stats() - ->set_num_tokens_indexed(options_.max_tokens_per_document); - } - switch (options_.token_limit_behavior) { - case Options::TokenLimitBehavior::kReturnError: - return absl_ports::ResourceExhaustedError( - "Max number of tokens reached!"); - case Options::TokenLimitBehavior::kSuppressError: - return overall_status; - } + for (std::string_view token : section.token_sequence) { + if (++num_tokens > options_.max_tokens_per_document) { + // Index all tokens buffered so far. + editor.IndexAllBufferedTerms(); + if (put_document_stats != nullptr) { + put_document_stats->mutable_tokenization_stats() + ->set_exceeded_max_token_num(true); + put_document_stats->mutable_tokenization_stats() + ->set_num_tokens_indexed(options_.max_tokens_per_document); } - std::string term = normalizer_.NormalizeTerm(itr->GetToken().text); - // Add this term to Hit buffer. Even if adding this hit fails, we keep - // trying to add more hits because it's possible that future hits could - // still be added successfully. For instance if the lexicon is full, we - // might fail to add a hit for a new term, but should still be able to - // add hits for terms that are already in the index. - auto status = editor.BufferTerm(term.c_str()); - if (overall_status.ok() && !status.ok()) { - // If we've succeeded to add everything so far, set overall_status to - // represent this new failure. If we've already failed, no need to - // update the status - we're already going to return a resource - // exhausted error. - overall_status = status; + switch (options_.token_limit_behavior) { + case Options::TokenLimitBehavior::kReturnError: + return absl_ports::ResourceExhaustedError( + "Max number of tokens reached!"); + case Options::TokenLimitBehavior::kSuppressError: + return overall_status; } } + std::string term = normalizer_.NormalizeTerm(token); + // Add this term to Hit buffer. Even if adding this hit fails, we keep + // trying to add more hits because it's possible that future hits could + // still be added successfully. For instance if the lexicon is full, we + // might fail to add a hit for a new term, but should still be able to + // add hits for terms that are already in the index. 
+ auto status = editor.BufferTerm(term.c_str()); + if (overall_status.ok() && !status.ok()) { + // If we've succeeded to add everything so far, set overall_status to + // represent this new failure. If we've already failed, no need to + // update the status - we're already going to return a resource + // exhausted error. + overall_status = status; + } } // Add all the seen terms to the index with their term frequency. auto status = editor.IndexAllBufferedTerms(); diff --git a/icing/index/index-processor.h b/icing/index/index-processor.h index 2eb4ad8..9fc7c46 100644 --- a/icing/index/index-processor.h +++ b/icing/index/index-processor.h @@ -21,12 +21,11 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/index/index.h" #include "icing/proto/document.pb.h" -#include "icing/schema/schema-store.h" #include "icing/schema/section-manager.h" #include "icing/store/document-id.h" -#include "icing/tokenization/language-segmenter.h" #include "icing/tokenization/token.h" #include "icing/transform/normalizer.h" +#include "icing/util/tokenized-document.h" namespace icing { namespace lib { @@ -58,14 +57,13 @@ class IndexProcessor { // An IndexProcessor on success // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> Create( - const SchemaStore* schema_store, const LanguageSegmenter* lang_segmenter, const Normalizer* normalizer, Index* index, const Options& options, const Clock* clock); - // Add document to the index, associated with document_id. If the number of - // tokens in the document exceeds max_tokens_per_document, then only the first - // max_tokens_per_document will be added to the index. All tokens of length - // exceeding max_token_length will be shortened to max_token_length. + // Add tokenized document to the index, associated with document_id. If the + // number of tokens in the document exceeds max_tokens_per_document, then only + // the first max_tokens_per_document will be added to the index. All tokens of + // length exceeding max_token_length will be shortened to max_token_length. // // Indexing a document *may* trigger an index merge. If a merge fails, then // all content in the index will be lost. @@ -82,25 +80,19 @@ class IndexProcessor { // NOT_FOUND if there is no definition for the document's schema type. 
// INTERNAL_ERROR if any other errors occur libtextclassifier3::Status IndexDocument( - const DocumentProto& document, DocumentId document_id, + const TokenizedDocument& tokenized_document, DocumentId document_id, NativePutDocumentStats* put_document_stats = nullptr); private: - IndexProcessor(const SchemaStore* schema_store, - const LanguageSegmenter* lang_segmenter, - const Normalizer* normalizer, Index* index, + IndexProcessor(const Normalizer* normalizer, Index* index, const Options& options, const Clock* clock) - : schema_store_(*schema_store), - lang_segmenter_(*lang_segmenter), - normalizer_(*normalizer), + : normalizer_(*normalizer), index_(index), options_(options), clock_(*clock) {} std::string NormalizeToken(const Token& token); - const SchemaStore& schema_store_; - const LanguageSegmenter& lang_segmenter_; const Normalizer& normalizer_; Index* const index_; const Options options_; diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc index 96a390b..afeac4d 100644 --- a/icing/index/index-processor_benchmark.cc +++ b/icing/index/index-processor_benchmark.cc @@ -31,6 +31,7 @@ #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" #include "icing/util/logging.h" +#include "icing/util/tokenized-document.h" #include "unicode/uloc.h" // Run on a Linux workstation: @@ -168,16 +169,13 @@ void CleanUp(const Filesystem& filesystem, const std::string& index_dir) { } std::unique_ptr<IndexProcessor> CreateIndexProcessor( - const SchemaStore* schema_store, - const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, - Index* index, const Clock* clock) { + const Normalizer* normalizer, Index* index, const Clock* clock) { IndexProcessor::Options processor_options{}; processor_options.max_tokens_per_document = 1024 * 1024 * 10; processor_options.token_limit_behavior = IndexProcessor::Options::TokenLimitBehavior::kReturnError; - return IndexProcessor::Create(schema_store, language_segmenter, normalizer, - index, processor_options, clock) + return IndexProcessor::Create(normalizer, index, processor_options, clock) .ValueOrDie(); } @@ -203,15 +201,18 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) { Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(schema_store.get(), language_segmenter.get(), - normalizer.get(), index.get(), &clock); + CreateIndexProcessor(normalizer.get(), index.get(), &clock); DocumentProto input_document = CreateDocumentWithOneProperty(state.range(0)); + TokenizedDocument tokenized_document(std::move( + TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), + input_document) + .ValueOrDie())); DocumentId document_id = 0; for (auto _ : state) { ICING_ASSERT_OK( - index_processor->IndexDocument(input_document, document_id++)); + index_processor->IndexDocument(tokenized_document, document_id++)); } CleanUp(filesystem, index_dir); @@ -254,16 +255,19 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) { Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(schema_store.get(), language_segmenter.get(), - normalizer.get(), index.get(), &clock); + CreateIndexProcessor(normalizer.get(), index.get(), &clock); DocumentProto input_document = CreateDocumentWithTenProperties(state.range(0)); + TokenizedDocument tokenized_document(std::move( + 
TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), + input_document) + .ValueOrDie())); DocumentId document_id = 0; for (auto _ : state) { ICING_ASSERT_OK( - index_processor->IndexDocument(input_document, document_id++)); + index_processor->IndexDocument(tokenized_document, document_id++)); } CleanUp(filesystem, index_dir); @@ -306,16 +310,19 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) { Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(schema_store.get(), language_segmenter.get(), - normalizer.get(), index.get(), &clock); + CreateIndexProcessor(normalizer.get(), index.get(), &clock); DocumentProto input_document = CreateDocumentWithDiacriticLetters(state.range(0)); + TokenizedDocument tokenized_document(std::move( + TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), + input_document) + .ValueOrDie())); DocumentId document_id = 0; for (auto _ : state) { ICING_ASSERT_OK( - index_processor->IndexDocument(input_document, document_id++)); + index_processor->IndexDocument(tokenized_document, document_id++)); } CleanUp(filesystem, index_dir); @@ -358,15 +365,18 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) { Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(schema_store.get(), language_segmenter.get(), - normalizer.get(), index.get(), &clock); + CreateIndexProcessor(normalizer.get(), index.get(), &clock); DocumentProto input_document = CreateDocumentWithHiragana(state.range(0)); + TokenizedDocument tokenized_document(std::move( + TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), + input_document) + .ValueOrDie())); DocumentId document_id = 0; for (auto _ : state) { ICING_ASSERT_OK( - index_processor->IndexDocument(input_document, document_id++)); + index_processor->IndexDocument(tokenized_document, document_id++)); } CleanUp(filesystem, index_dir); diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc index bdd9575..e6bb615 100644 --- a/icing/index/index-processor_test.cc +++ b/icing/index/index-processor_test.cc @@ -53,6 +53,7 @@ #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/tokenized-document.h" #include "unicode/uloc.h" namespace icing { @@ -140,8 +141,7 @@ class IndexProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), + IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, &fake_clock_)); mock_icing_filesystem_ = std::make_unique<IcingMockFilesystem>(); } @@ -195,7 +195,7 @@ class IndexProcessorTest : public Test { type_config->set_schema_type(std::string(kFakeType)); AddStringProperty(std::string(kExactProperty), DataType::STRING, - Cardinality::REQUIRED, TermMatchType::EXACT_ONLY, + Cardinality::OPTIONAL, TermMatchType::EXACT_ONLY, type_config); AddStringProperty(std::string(kPrefixedProperty), DataType::STRING, @@ -244,25 +244,11 @@ TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) { processor_options.token_limit_behavior = IndexProcessor::Options::TokenLimitBehavior::kReturnError; - EXPECT_THAT( - IndexProcessor::Create(/*schema_store=*/nullptr, lang_segmenter_.get(), - 
normalizer_.get(), index_.get(), processor_options, - &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - - EXPECT_THAT( - IndexProcessor::Create(schema_store_.get(), /*lang_segmenter=*/nullptr, - normalizer_.get(), index_.get(), processor_options, - &fake_clock_), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - - EXPECT_THAT(IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - /*normalizer=*/nullptr, index_.get(), + EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(), processor_options, &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), /*index=*/nullptr, + EXPECT_THAT(IndexProcessor::Create(normalizer_.get(), /*index=*/nullptr, processor_options, &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -276,7 +262,12 @@ TEST_F(IndexProcessorTest, NoTermMatchTypeContent) { .AddBytesProperty(std::string(kUnindexedProperty2), "attachment bytes") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kInvalidDocumentId)); } @@ -287,7 +278,12 @@ TEST_F(IndexProcessorTest, OneDoc) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "hello world") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -313,7 +309,12 @@ TEST_F(IndexProcessorTest, MultipleDocs) { .AddStringProperty(std::string(kExactProperty), "hello world") .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); std::string coffeeRepeatedString = "coffee"; @@ -329,7 +330,12 @@ TEST_F(IndexProcessorTest, MultipleDocs) { .AddStringProperty(std::string(kPrefixedProperty), "mr. 
world world wide") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId1), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId1), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId1)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -372,11 +378,18 @@ TEST_F(IndexProcessorTest, DocWithNestedProperty) { .AddDocumentProperty( std::string(kSubProperty), DocumentBuilder() + .SetKey("icing", "nested_type/1") + .SetSchema(std::string(kNestedType)) .AddStringProperty(std::string(kNestedProperty), "rocky raccoon") .Build()) .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -396,7 +409,12 @@ TEST_F(IndexProcessorTest, DocWithRepeatedProperty) { .AddStringProperty(std::string(kRepeatedProperty), "rocky", "italian stallion") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -417,8 +435,7 @@ TEST_F(IndexProcessorTest, TooManyTokensReturnError) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), options, + IndexProcessor::Create(normalizer_.get(), index_.get(), options, &fake_clock_)); DocumentProto document = @@ -428,7 +445,11 @@ TEST_F(IndexProcessorTest, TooManyTokensReturnError) { .AddStringProperty(std::string(kExactProperty), "hello world") .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); @@ -457,8 +478,7 @@ TEST_F(IndexProcessorTest, TooManyTokensSuppressError) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), options, + IndexProcessor::Create(normalizer_.get(), index_.get(), options, &fake_clock_)); DocumentProto document = @@ -468,7 +488,12 @@ TEST_F(IndexProcessorTest, TooManyTokensSuppressError) { .AddStringProperty(std::string(kExactProperty), "hello world") .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument 
tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); // "night" should have been indexed. @@ -498,8 +523,7 @@ TEST_F(IndexProcessorTest, TooLongTokens) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer.get(), index_.get(), options, + IndexProcessor::Create(normalizer.get(), index_.get(), options, &fake_clock_)); DocumentProto document = @@ -509,7 +533,12 @@ TEST_F(IndexProcessorTest, TooLongTokens) { .AddStringProperty(std::string(kExactProperty), "hello world") .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); // "good" should have been indexed normally. @@ -542,7 +571,12 @@ TEST_F(IndexProcessorTest, NonPrefixedContentPrefixQuery) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "best rocky movies") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); document = @@ -551,7 +585,12 @@ TEST_F(IndexProcessorTest, NonPrefixedContentPrefixQuery) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kPrefixedProperty), "rocky raccoon") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId1), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId1), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId1)); // Only document_id 1 should surface in a prefix query for "Rock" @@ -570,7 +609,12 @@ TEST_F(IndexProcessorTest, TokenNormalization) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "ALL UPPER CASE") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); document = @@ -579,7 +623,12 @@ TEST_F(IndexProcessorTest, TokenNormalization) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "all lower case") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId1), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + 
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId1), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId1)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -600,7 +649,12 @@ TEST_F(IndexProcessorTest, OutOfOrderDocumentIds) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "ALL UPPER CASE") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId1), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId1), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId1)); // Indexing a document with document_id < last_added_document_id should cause @@ -611,11 +665,15 @@ TEST_F(IndexProcessorTest, OutOfOrderDocumentIds) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), "all lower case") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), + ICING_ASSERT_OK_AND_ASSIGN( + tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // As should indexing a document document_id == last_added_document_id. - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId1)); @@ -635,8 +693,7 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), + IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, &fake_clock_)); DocumentProto document = @@ -646,7 +703,12 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { .AddStringProperty(std::string(kExactProperty), "你好,世界!你好:世界。“你好”世界?") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, @@ -666,8 +728,7 @@ TEST_F(IndexProcessorTest, ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), processor_options, + IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, &fake_clock_)); // This is the maximum token length that an empty lexicon constructed for a @@ -684,7 +745,11 @@ TEST_F(IndexProcessorTest, absl_ports::StrCat(enormous_string, " foo")) .AddStringProperty(std::string(kPrefixedProperty), "bar baz") .Build(); - EXPECT_THAT(index_processor_->IndexDocument(document, kDocumentId0), + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, 
kDocumentId0), StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); @@ -715,6 +780,10 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kExactProperty), kIpsumText) .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); Index::Options options(index_dir_, /*index_merge_size=*/document.ByteSizeLong() * 100); ICING_ASSERT_OK_AND_ASSIGN( @@ -727,8 +796,7 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), processor_options, + IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, &fake_clock_)); DocumentId doc_id = 0; // Have determined experimentally that indexing 3373 documents with this text @@ -737,10 +805,12 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { // empties the LiteIndex. constexpr int kNumDocsLiteIndexExhaustion = 3373; for (; doc_id < kNumDocsLiteIndexExhaustion; ++doc_id) { - EXPECT_THAT(index_processor_->IndexDocument(document, doc_id), IsOk()); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, doc_id), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(doc_id)); } - EXPECT_THAT(index_processor_->IndexDocument(document, doc_id), IsOk()); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, doc_id), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(doc_id)); } @@ -768,6 +838,10 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) { .SetSchema(std::string(kFakeType)) .AddStringProperty(std::string(kPrefixedProperty), kIpsumText) .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); // 2. Recreate the index with the mock filesystem and a merge size that will // only allow one document to be added before requiring a merge. @@ -784,25 +858,26 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(schema_store_.get(), lang_segmenter_.get(), - normalizer_.get(), index_.get(), processor_options, + IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, &fake_clock_)); // 3. Index one document. This should fit in the LiteIndex without requiring a // merge. DocumentId doc_id = 0; - EXPECT_THAT(index_processor_->IndexDocument(document, doc_id), IsOk()); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, doc_id), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(doc_id)); // 4. Add one more document to trigger a merge, which should fail and result // in a Reset. ++doc_id; - EXPECT_THAT(index_processor_->IndexDocument(document, doc_id), + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, doc_id), StatusIs(libtextclassifier3::StatusCode::DATA_LOSS)); EXPECT_THAT(index_->last_added_document_id(), Eq(kInvalidDocumentId)); // 5. Indexing a new document should succeed. 
- EXPECT_THAT(index_processor_->IndexDocument(document, doc_id), IsOk()); + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, doc_id), + IsOk()); EXPECT_THAT(index_->last_added_document_id(), Eq(doc_id)); } diff --git a/icing/index/iterator/doc-hit-info-iterator-and.h b/icing/index/iterator/doc-hit-info-iterator-and.h index 4618fb9..faca785 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and.h +++ b/icing/index/iterator/doc-hit-info-iterator-and.h @@ -46,6 +46,16 @@ class DocHitInfoIteratorAnd : public DocHitInfoIterator { std::string ToString() const override; + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo> *matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + short_->PopulateMatchedTermsStats(matched_terms_stats); + long_->PopulateMatchedTermsStats(matched_terms_stats); + } + private: std::unique_ptr<DocHitInfoIterator> short_; std::unique_ptr<DocHitInfoIterator> long_; @@ -67,6 +77,17 @@ class DocHitInfoIteratorAndNary : public DocHitInfoIterator { std::string ToString() const override; + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo> *matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + for (size_t i = 0; i < iterators_.size(); ++i) { + iterators_.at(i)->PopulateMatchedTermsStats(matched_terms_stats); + } + } + private: std::vector<std::unique_ptr<DocHitInfoIterator>> iterators_; }; diff --git a/icing/index/iterator/doc-hit-info-iterator-and_test.cc b/icing/index/iterator/doc-hit-info-iterator-and_test.cc index 35574b7..783e937 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and_test.cc @@ -32,8 +32,10 @@ namespace lib { namespace { using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; +using ::testing::SizeIs; TEST(CreateAndIteratorTest, And) { // Basic test that we can create a working And iterator. Further testing of @@ -196,6 +198,125 @@ TEST(DocHitInfoIteratorAndTest, SectionIdMask) { EXPECT_THAT(and_iter.hit_intersect_section_ids_mask(), Eq(mask_anded_result)); } +TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { + { + // Arbitrary section ids for the documents in the DocHitInfoIterators. + // Created to test correct section_id_mask behavior. 
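+ // (Editor's note: in these fixtures, bit i of the SectionIdMask marks a hit
+ // in section id i and entry i of the term-frequency array holds that
+ // section's frequency, so (1 << 0) | (1 << 2) | (1 << 4) | (1 << 6), i.e.
+ // 0b01010101, lines up with the nonzero frequencies at indices 0, 2, 4 and 6
+ // declared below.)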
+ SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); + doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + + DocHitInfo doc_hit_info2 = DocHitInfo(4); + doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + first_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + second_iter->set_hit_intersect_section_ids_mask(section_id_mask2); + + DocHitInfoIteratorAnd and_iter(std::move(first_iter), + std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(and_iter.Advance()); + EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4)); + + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies2)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2); + + EXPECT_FALSE(and_iter.Advance().ok()); + } + { + // Arbitrary section ids for the documents in the DocHitInfoIterators. + // Created to test correct section_id_mask behavior. 
+ SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ + 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info1}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + first_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hi"); + second_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + DocHitInfoIteratorAnd and_iter(std::move(first_iter), + std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(and_iter.Advance()); + EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(4)); + + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + + EXPECT_FALSE(and_iter.Advance().ok()); + } +} + +TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats_NoMatchingDocument) { + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + + DocHitInfo doc_hit_info2 = DocHitInfo(5); + doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + + DocHitInfoIteratorAnd and_iter(std::move(first_iter), std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + EXPECT_FALSE(and_iter.Advance().ok()); +} + TEST(DocHitInfoIteratorAndNaryTest, Initialize) { std::vector<std::unique_ptr<DocHitInfoIterator>> iterators; iterators.push_back(std::make_unique<DocHitInfoIteratorDummy>()); @@ -345,6 +466,90 @@ TEST(DocHitInfoIteratorAndNaryTest, SectionIdMask) { EXPECT_THAT(and_iter.hit_intersect_section_ids_mask(), Eq(mask_anded_result)); } +TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) { + // Arbitrary section ids/term frequencies for the documents in the + // DocHitInfoIterators. 
+ // For term "hi", document 10 and 8 + SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hi{ + 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_hi = DocHitInfo(10); + doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + + DocHitInfo doc_hit_info2_hi = DocHitInfo(8); + doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + // For term "hello", document 10 and 9 + SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hello{ + 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_hello = DocHitInfo(10); + doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); + doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + + DocHitInfo doc_hit_info2_hello = DocHitInfo(9); + doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); + doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); + + // For term "ciao", document 10 and 9 + SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_ciao{ + 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_ciao = DocHitInfo(10); + doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); + doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); + + DocHitInfo doc_hit_info2_ciao = DocHitInfo(9); + doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1_hi, doc_hit_info2_hi}; + std::vector<DocHitInfo> second_vector = {doc_hit_info1_hello, + doc_hit_info2_hello}; + std::vector<DocHitInfo> third_vector = {doc_hit_info1_ciao, + doc_hit_info2_ciao}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + auto third_iter = + std::make_unique<DocHitInfoIteratorDummy>(third_vector, "ciao"); + + std::vector<std::unique_ptr<DocHitInfoIterator>> iterators; + iterators.push_back(std::move(first_iter)); + iterators.push_back(std::move(second_iter)); + iterators.push_back(std::move(third_iter)); + + DocHitInfoIteratorAndNary and_iter(std::move(iterators)); + std::vector<TermMatchInfo> matched_terms_stats; + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(and_iter.Advance()); + EXPECT_THAT(and_iter.doc_hit_info().document_id(), Eq(10)); + + and_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(3)); // 3 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1_hi)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi); + EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, 
+ ElementsAreArray(term_frequencies1_hello)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello); + EXPECT_EQ(matched_terms_stats.at(2).term, "ciao"); + EXPECT_THAT(matched_terms_stats.at(2).term_frequencies, + ElementsAreArray(term_frequencies1_ciao)); + EXPECT_EQ(matched_terms_stats.at(2).section_ids_mask, section_id_mask1_ciao); + + EXPECT_FALSE(and_iter.Advance().ok()); +} + } // namespace } // namespace lib diff --git a/icing/index/iterator/doc-hit-info-iterator-filter.h b/icing/index/iterator/doc-hit-info-iterator-filter.h index 9119610..fb60e38 100644 --- a/icing/index/iterator/doc-hit-info-iterator-filter.h +++ b/icing/index/iterator/doc-hit-info-iterator-filter.h @@ -67,6 +67,11 @@ class DocHitInfoIteratorFilter : public DocHitInfoIterator { std::string ToString() const override; + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const override { + delegate_->PopulateMatchedTermsStats(matched_terms_stats); + } + private: std::unique_ptr<DocHitInfoIterator> delegate_; const DocumentStore& document_store_; diff --git a/icing/index/iterator/doc-hit-info-iterator-or.cc b/icing/index/iterator/doc-hit-info-iterator-or.cc index 8f00f88..b4234e0 100644 --- a/icing/index/iterator/doc-hit-info-iterator-or.cc +++ b/icing/index/iterator/doc-hit-info-iterator-or.cc @@ -108,6 +108,7 @@ libtextclassifier3::Status DocHitInfoIteratorOr::Advance() { } else { chosen = left_.get(); } + current_ = chosen; doc_hit_info_ = chosen->doc_hit_info(); hit_intersect_section_ids_mask_ = chosen->hit_intersect_section_ids_mask(); @@ -139,6 +140,7 @@ DocHitInfoIteratorOrNary::DocHitInfoIteratorOrNary( : iterators_(std::move(iterators)) {} libtextclassifier3::Status DocHitInfoIteratorOrNary::Advance() { + current_iterators_.clear(); if (iterators_.size() < 2) { return absl_ports::InvalidArgumentError( "Not enough iterators to OR together"); @@ -187,6 +189,7 @@ libtextclassifier3::Status DocHitInfoIteratorOrNary::Advance() { hit_intersect_section_ids_mask_ = kSectionIdMaskNone; for (const auto& iterator : iterators_) { if (iterator->doc_hit_info().document_id() == next_document_id) { + current_iterators_.push_back(iterator.get()); if (doc_hit_info_.document_id() == kInvalidDocumentId) { doc_hit_info_ = iterator->doc_hit_info(); hit_intersect_section_ids_mask_ = diff --git a/icing/index/iterator/doc-hit-info-iterator-or.h b/icing/index/iterator/doc-hit-info-iterator-or.h index 4128e0f..2f49430 100644 --- a/icing/index/iterator/doc-hit-info-iterator-or.h +++ b/icing/index/iterator/doc-hit-info-iterator-or.h @@ -42,9 +42,26 @@ class DocHitInfoIteratorOr : public DocHitInfoIterator { std::string ToString() const override; + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo> *matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + current_->PopulateMatchedTermsStats(matched_terms_stats); + // If equal, then current_ == left_. Combine with results from right_. + if (left_document_id_ == right_document_id_) { + right_->PopulateMatchedTermsStats(matched_terms_stats); + } + } + private: std::unique_ptr<DocHitInfoIterator> left_; std::unique_ptr<DocHitInfoIterator> right_; + // Pointer to the chosen iterator that points to the current doc_hit_info_. If + // both left_ and right_ point to the same docid, then current_ == left_. + // current_ does not own the iterator it points to. 
+ DocHitInfoIterator *current_; DocumentId left_document_id_ = kMaxDocumentId; DocumentId right_document_id_ = kMaxDocumentId; }; @@ -65,8 +82,22 @@ class DocHitInfoIteratorOrNary : public DocHitInfoIterator { std::string ToString() const override; + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo> *matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + for (size_t i = 0; i < current_iterators_.size(); i++) { + current_iterators_.at(i)->PopulateMatchedTermsStats(matched_terms_stats); + } + } + private: std::vector<std::unique_ptr<DocHitInfoIterator>> iterators_; + // Pointers to the iterators that point to the current doc_hit_info_. + // current_iterators_ does not own the iterators it points to. + std::vector<DocHitInfoIterator *> current_iterators_; }; } // namespace lib diff --git a/icing/index/iterator/doc-hit-info-iterator-or_test.cc b/icing/index/iterator/doc-hit-info-iterator-or_test.cc index 3faa5ab..3f00a39 100644 --- a/icing/index/iterator/doc-hit-info-iterator-or_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-or_test.cc @@ -32,7 +32,10 @@ namespace lib { namespace { using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::Eq; +using ::testing::IsEmpty; +using ::testing::SizeIs; TEST(CreateAndIteratorTest, Or) { // Basic test that we can create a working Or iterator. Further testing of @@ -175,6 +178,159 @@ TEST(DocHitInfoIteratorOrTest, SectionIdMask) { EXPECT_THAT(or_iter.hit_intersect_section_ids_mask(), Eq(mask_anded_result)); } +TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { + { + // Arbitrary section ids for the documents in the DocHitInfoIterators. + // Created to test correct section_id_mask behavior. 
+ SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); + doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + + DocHitInfo doc_hit_info2 = DocHitInfo(4); + doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + first_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + second_iter->set_hit_intersect_section_ids_mask(section_id_mask2); + + DocHitInfoIteratorOr or_iter(std::move(first_iter), std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); + + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies2)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2); + + EXPECT_FALSE(or_iter.Advance().ok()); + } + { + // Arbitrary section ids for the documents in the DocHitInfoIterators. + // Created to test correct section_id_mask behavior. 
+ SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ + 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info1}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + first_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hi"); + second_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + DocHitInfoIteratorOr or_iter(std::move(first_iter), std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); + + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + + EXPECT_FALSE(or_iter.Advance().ok()); + } + { + // Arbitrary section ids for the documents in the DocHitInfoIterators. + // Created to test correct section_id_mask behavior. + SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + DocHitInfo doc_hit_info1 = DocHitInfo(4); + doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); + doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + + DocHitInfo doc_hit_info2 = DocHitInfo(5); + doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1}; + std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + first_iter->set_hit_intersect_section_ids_mask(section_id_mask1); + + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + second_iter->set_hit_intersect_section_ids_mask(section_id_mask2); + + DocHitInfoIteratorOr or_iter(std::move(first_iter), std::move(second_iter)); + std::vector<TermMatchInfo> matched_terms_stats; + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(5)); + + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + 
EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies2)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(4)); + + matched_terms_stats.clear(); + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1); + + EXPECT_FALSE(or_iter.Advance().ok()); + } +} + TEST(DocHitInfoIteratorOrNaryTest, Initialize) { std::vector<std::unique_ptr<DocHitInfoIterator>> iterators; iterators.push_back(std::make_unique<DocHitInfoIteratorDummy>()); @@ -316,6 +472,125 @@ TEST(DocHitInfoIteratorOrNaryTest, SectionIdMask) { EXPECT_THAT(or_iter.hit_intersect_section_ids_mask(), Eq(mask_anded_result)); } +TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) { + // Arbitrary section ids/term frequencies for the documents in the + // DocHitInfoIterators. + // For term "hi", document 10 and 8 + SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hi{ + 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_hi = DocHitInfo(10); + doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); + doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); + + SectionIdMask section_id_mask2_hi = 0b00000110; // hits in sections 1, 2 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2_hi{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info2_hi = DocHitInfo(8); + doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); + doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); + + // For term "hello", document 10 and 9 + SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hello{ + 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_hello = DocHitInfo(10); + doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); + doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + + SectionIdMask section_id_mask2_hello = 0b00001100; // hits in sections 2, 3 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2_hello{ + 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info2_hello = DocHitInfo(9); + doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); + doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); + + // For term "ciao", document 9 and 8 + SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_ciao{ + 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info1_ciao = DocHitInfo(9); + doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); + doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); + + SectionIdMask section_id_mask2_ciao = 0b00011000; // hits in sections 3, 4 + std::array<Hit::TermFrequency, kMaxSectionId> 
term_frequencies2_ciao{ + 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfo doc_hit_info2_ciao = DocHitInfo(8); + doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); + doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2); + + std::vector<DocHitInfo> first_vector = {doc_hit_info1_hi, doc_hit_info2_hi}; + std::vector<DocHitInfo> second_vector = {doc_hit_info1_hello, + doc_hit_info2_hello}; + std::vector<DocHitInfo> third_vector = {doc_hit_info1_ciao, + doc_hit_info2_ciao}; + + auto first_iter = + std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); + auto second_iter = + std::make_unique<DocHitInfoIteratorDummy>(second_vector, "hello"); + auto third_iter = + std::make_unique<DocHitInfoIteratorDummy>(third_vector, "ciao"); + + std::vector<std::unique_ptr<DocHitInfoIterator>> iterators; + iterators.push_back(std::move(first_iter)); + iterators.push_back(std::move(second_iter)); + iterators.push_back(std::move(third_iter)); + + DocHitInfoIteratorOrNary or_iter(std::move(iterators)); + std::vector<TermMatchInfo> matched_terms_stats; + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + EXPECT_THAT(matched_terms_stats, IsEmpty()); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(10)); + + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies1_hi)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask1_hi); + EXPECT_EQ(matched_terms_stats.at(1).term, "hello"); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies1_hello)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_hello); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(9)); + + matched_terms_stats.clear(); + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies2_hello)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hello); + EXPECT_EQ(matched_terms_stats.at(1).term, "ciao"); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies1_ciao)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask1_ciao); + + ICING_EXPECT_OK(or_iter.Advance()); + EXPECT_THAT(or_iter.doc_hit_info().document_id(), Eq(8)); + + matched_terms_stats.clear(); + or_iter.PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies2_hi)); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask2_hi); + EXPECT_EQ(matched_terms_stats.at(1).term, "ciao"); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies2_ciao)); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask2_ciao); + + EXPECT_FALSE(or_iter.Advance().ok()); +} + } // namespace } // namespace lib diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.h 
b/icing/index/iterator/doc-hit-info-iterator-section-restrict.h index ae5a896..ba74384 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.h +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.h @@ -52,6 +52,15 @@ class DocHitInfoIteratorSectionRestrict : public DocHitInfoIterator { std::string ToString() const override; + // NOTE: currently, a section restrict does decide which documents to + // return, but it doesn't impact the relevance score of a document. + // TODO(b/173156803): decide whether we want to filter the matched_terms_stats + // for the restricted sections. + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const override { + delegate_->PopulateMatchedTermsStats(matched_terms_stats); + } + private: std::unique_ptr<DocHitInfoIterator> delegate_; const DocumentStore& document_store_; diff --git a/icing/index/iterator/doc-hit-info-iterator-test-util.h b/icing/index/iterator/doc-hit-info-iterator-test-util.h index c4d7aa7..913696a 100644 --- a/icing/index/iterator/doc-hit-info-iterator-test-util.h +++ b/icing/index/iterator/doc-hit-info-iterator-test-util.h @@ -15,7 +15,6 @@ #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_ #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_ -#include <cstdint> #include <string> #include <utility> #include <vector> @@ -40,8 +39,9 @@ namespace lib { class DocHitInfoIteratorDummy : public DocHitInfoIterator { public: DocHitInfoIteratorDummy() = default; - explicit DocHitInfoIteratorDummy(std::vector<DocHitInfo> doc_hit_infos) - : doc_hit_infos_(std::move(doc_hit_infos)) {} + explicit DocHitInfoIteratorDummy(std::vector<DocHitInfo> doc_hit_infos, + std::string term = "") + : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {} libtextclassifier3::Status Advance() override { if (index_ < doc_hit_infos_.size()) { @@ -54,6 +54,36 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator { "No more DocHitInfos in iterator"); } + // Imitates behavior of DocHitInfoIteratorTermMain/DocHitInfoIteratorTermLite + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. 
+ return; + } + SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask(); + std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { + Hit::kNoTermFrequency}; + + while (section_mask) { + SectionId section_id = __builtin_ctz(section_mask); + section_term_frequencies.at(section_id) = + doc_hit_info_.hit_term_frequency(section_id); + section_mask &= ~(1u << section_id); + } + TermMatchInfo term_stats(term_, doc_hit_info_.hit_section_ids_mask(), + section_term_frequencies); + + for (auto& cur_term_stats : *matched_terms_stats) { + if (cur_term_stats.term == term_stats.term) { + // Same docId and same term, we don't need to add the term and the term + // frequency should always be the same + return; + } + } + matched_terms_stats->push_back(term_stats); + } + void set_hit_intersect_section_ids_mask( SectionIdMask hit_intersect_section_ids_mask) { hit_intersect_section_ids_mask_ = hit_intersect_section_ids_mask; @@ -91,6 +121,7 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator { int32_t num_blocks_inspected_ = 0; int32_t num_leaf_advance_calls_ = 0; std::vector<DocHitInfo> doc_hit_infos_; + std::string term_; }; inline std::vector<DocumentId> GetDocumentIds(DocHitInfoIterator* iterator) { diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h index bcc2b6e..c4d9901 100644 --- a/icing/index/iterator/doc-hit-info-iterator.h +++ b/icing/index/iterator/doc-hit-info-iterator.h @@ -17,6 +17,7 @@ #include <cstdint> #include <string> +#include <string_view> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -28,6 +29,26 @@ namespace icing { namespace lib { +// Data structure that maps a single matched query term to its section mask +// and the list of term frequencies. +// TODO(b/158603837): add stat on whether the matched terms are prefix matched +// or not. This information will be used to boost exact match. +struct TermMatchInfo { + std::string_view term; + // SectionIdMask associated with the term. + SectionIdMask section_ids_mask; + // Array with fixed size kMaxSectionId. For every section id, i.e. + // array index, it stores the term frequency of the term. + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies; + + explicit TermMatchInfo( + std::string_view term, SectionIdMask section_ids_mask, + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies) + : term(term), + section_ids_mask(section_ids_mask), + term_frequencies(std::move(term_frequencies)) {} +}; + // Iterator over DocHitInfos (collapsed Hits) in REVERSE document_id order. // // NOTE: You must call Advance() before calling hit_info() or @@ -70,6 +91,14 @@ class DocHitInfoIterator { // A string representing the iterator. virtual std::string ToString() const = 0; + // For the last hit docid, retrieves all the matched query terms and other + // stats; see TermMatchInfo. + // If Advance() wasn't called after construction, if Advance() returned a + // non-OK status, or if the concrete HitIterator didn't override this + // method, the vector isn't populated. 
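  // [Editor's illustrative sketch; not part of this change.] A caller would
  // typically advance the iterator first and then collect the stats, e.g.:
  //
  //   std::vector<TermMatchInfo> matched_terms_stats;
  //   if (iterator->Advance().ok()) {
  //     iterator->PopulateMatchedTermsStats(&matched_terms_stats);
  //     for (const TermMatchInfo& info : matched_terms_stats) {
  //       // info.term, info.section_ids_mask and info.term_frequencies say
  //       // where and how often the term hit the current document.
  //     }
  //   }
  //
  // Here `iterator` is a placeholder for any DocHitInfoIterator that
  // overrides this method.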
+ virtual void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const {} + protected: DocHitInfo doc_hit_info_; SectionIdMask hit_intersect_section_ids_mask_ = kSectionIdMaskNone; diff --git a/icing/index/lite/doc-hit-info-iterator-term-lite.h b/icing/index/lite/doc-hit-info-iterator-term-lite.h index bd2de6d..ac5e97f 100644 --- a/icing/index/lite/doc-hit-info-iterator-term-lite.h +++ b/icing/index/lite/doc-hit-info-iterator-term-lite.h @@ -49,6 +49,34 @@ class DocHitInfoIteratorTermLite : public DocHitInfoIterator { } int32_t GetNumLeafAdvanceCalls() const override { return num_advance_calls_; } + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask(); + std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { + Hit::kNoTermFrequency}; + while (section_mask) { + SectionId section_id = __builtin_ctz(section_mask); + section_term_frequencies.at(section_id) = + doc_hit_info_.hit_term_frequency(section_id); + section_mask &= ~(1u << section_id); + } + TermMatchInfo term_stats(term_, doc_hit_info_.hit_section_ids_mask(), + std::move(section_term_frequencies)); + + for (const TermMatchInfo& cur_term_stats : *matched_terms_stats) { + if (cur_term_stats.term == term_stats.term) { + // Same docId and same term, we don't need to add the term and the term + // frequency should always be the same + return; + } + } + matched_terms_stats->push_back(std::move(term_stats)); + } + protected: // Add DocHitInfos corresponding to term_ to cached_hits_. virtual libtextclassifier3::Status RetrieveMoreHits() = 0; diff --git a/icing/index/main/doc-hit-info-iterator-term-main.h b/icing/index/main/doc-hit-info-iterator-term-main.h index 1f77226..d626d7a 100644 --- a/icing/index/main/doc-hit-info-iterator-term-main.h +++ b/icing/index/main/doc-hit-info-iterator-term-main.h @@ -49,6 +49,34 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { } int32_t GetNumLeafAdvanceCalls() const override { return num_advance_calls_; } + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats) const override { + if (doc_hit_info_.document_id() == kInvalidDocumentId) { + // Current hit isn't valid, return. + return; + } + SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask(); + std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { + Hit::kNoTermFrequency}; + while (section_mask) { + SectionId section_id = __builtin_ctz(section_mask); + section_term_frequencies.at(section_id) = + doc_hit_info_.hit_term_frequency(section_id); + section_mask &= ~(1u << section_id); + } + TermMatchInfo term_stats(term_, doc_hit_info_.hit_section_ids_mask(), + std::move(section_term_frequencies)); + + for (const TermMatchInfo& cur_term_stats : *matched_terms_stats) { + if (cur_term_stats.term == term_stats.term) { + // Same docId and same term, we don't need to add the term and the term + // frequency should always be the same + return; + } + } + matched_terms_stats->push_back(std::move(term_stats)); + } + protected: // Add DocHitInfos corresponding to term_ to cached_doc_hit_infos_. 
virtual libtextclassifier3::Status RetrieveMoreHits() = 0; diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index a18a183..bf709cd 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -189,16 +189,21 @@ Java_com_google_android_icing_IcingSearchEngine_nativePut( JNIEXPORT jbyteArray JNICALL Java_com_google_android_icing_IcingSearchEngine_nativeGet( - JNIEnv* env, jclass clazz, jobject object, jstring name_space, - jstring uri) { + JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri, + jbyteArray result_spec_bytes) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); const char* native_name_space = env->GetStringUTFChars(name_space, /*isCopy=*/nullptr); const char* native_uri = env->GetStringUTFChars(uri, /*isCopy=*/nullptr); + icing::lib::GetResultSpecProto get_result_spec; + if (!ParseProtoFromJniByteArray(env, result_spec_bytes, &get_result_spec)) { + ICING_LOG(ERROR) << "Failed to parse GetResultSpecProto in nativeGet"; + return nullptr; + } icing::lib::GetResultProto get_result_proto = - icing->Get(native_name_space, native_uri); + icing->Get(native_name_space, native_uri, get_result_spec); return SerializeProtoToJniByteArray(env, get_result_proto); } diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 4d714f8..0732ed0 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -131,10 +131,8 @@ QueryProcessor::QueryProcessor(Index* index, schema_store_(*schema_store), clock_(*clock) {} -libtextclassifier3::StatusOr<QueryProcessor::QueryResults> -QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { - ICING_ASSIGN_OR_RETURN(QueryResults results, ParseRawQuery(search_spec)); - +DocHitInfoIteratorFilter::Options QueryProcessor::getFilterOptions( + const SearchSpecProto& search_spec) { DocHitInfoIteratorFilter::Options options; if (search_spec.namespace_filters_size() > 0) { @@ -148,7 +146,14 @@ QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { std::vector<std::string_view>(search_spec.schema_type_filters().begin(), search_spec.schema_type_filters().end()); } + return options; +} + +libtextclassifier3::StatusOr<QueryProcessor::QueryResults> +QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { + ICING_ASSIGN_OR_RETURN(QueryResults results, ParseRawQuery(search_spec)); + DocHitInfoIteratorFilter::Options options = getFilterOptions(search_spec); results.root_iterator = std::make_unique<DocHitInfoIteratorFilter>( std::move(results.root_iterator), &document_store_, &schema_store_, &clock_, options); @@ -158,6 +163,8 @@ QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { // TODO(cassiewang): Collect query stats to populate the SearchResultsProto libtextclassifier3::StatusOr<QueryProcessor::QueryResults> QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { + DocHitInfoIteratorFilter::Options options = getFilterOptions(search_spec); + // Tokenize the incoming raw query // // TODO(cassiewang): Consider caching/creating a tokenizer factory that will @@ -258,12 +265,22 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { index_.GetIterator(normalized_text, kSectionIdMaskAll, search_spec.term_match_type())); - // Add terms to match if this is not a negation term. + // Add term iterator and terms to match if this is not a negation term. 
// WARNING: setting query terms at this point is not compatible with // group-level excludes, group-level sections restricts or excluded // section restricts. Those are not currently supported. If they became // supported, this handling for query terms would need to be altered. if (!frames.top().saw_exclude) { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<DocHitInfoIterator> term_iterator, + index_.GetIterator(normalized_text, kSectionIdMaskAll, + search_spec.term_match_type())); + + results.query_term_iterators[normalized_text] = + std::make_unique<DocHitInfoIteratorFilter>( + std::move(term_iterator), &document_store_, &schema_store_, + &clock_, options); + results.query_terms[frames.top().section_restrict].insert( std::move(normalized_text)); } diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h index fa98627..0932ec5 100644 --- a/icing/query/query-processor.h +++ b/icing/query/query-processor.h @@ -19,6 +19,7 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/index.h" +#include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/search.pb.h" #include "icing/query/query-terms.h" @@ -53,6 +54,11 @@ class QueryProcessor { // A map from section names to sets of terms restricted to those sections. // Query terms that are not restricted are found at the entry with key "". SectionRestrictQueryTermsMap query_terms; + // Hit iterators for the text terms in the query. These query_term_iterators + // are completely separate from the iterators that make the iterator tree + // beginning with root_iterator. + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; }; // Parse the search configurations (including the query, any additional // filters, etc.) in the SearchSpecProto into one DocHitInfoIterator. @@ -85,6 +91,11 @@ class QueryProcessor { libtextclassifier3::StatusOr<QueryResults> ParseRawQuery( const SearchSpecProto& search_spec); + // Return the options for the DocHitInfoIteratorFilter based on the + // search_spec. + DocHitInfoIteratorFilter::Options getFilterOptions( + const SearchSpecProto& search_spec); + // Not const because we could modify/sort the hit buffer in the lite index at // query time. 
Index& index_; diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index 7546ae4..6ec0a2a 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -54,6 +54,7 @@ namespace lib { namespace { using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::IsEmpty; using ::testing::SizeIs; using ::testing::Test; @@ -232,6 +233,7 @@ TEST_F(QueryProcessorTest, EmptyGroupMatchAllDocuments) { EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), ElementsAre(document_id2, document_id1)); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } TEST_F(QueryProcessorTest, EmptyQueryMatchAllDocuments) { @@ -281,6 +283,7 @@ TEST_F(QueryProcessorTest, EmptyQueryMatchAllDocuments) { EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), ElementsAre(document_id2, document_id1)); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } TEST_F(QueryProcessorTest, QueryTermNormalized) { @@ -312,6 +315,8 @@ TEST_F(QueryProcessorTest, QueryTermNormalized) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -334,11 +339,26 @@ TEST_F(QueryProcessorTest, QueryTermNormalized) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); - // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "world"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); + + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, OneTermPrefixMatch) { @@ -370,6 +390,8 @@ TEST_F(QueryProcessorTest, OneTermPrefixMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::PREFIX; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -389,11 +411,21 @@ TEST_F(QueryProcessorTest, OneTermPrefixMatch) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); - // Descending order of valid DocumentIds - 
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "he"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, OneTermExactMatch) { @@ -425,6 +457,8 @@ TEST_F(QueryProcessorTest, OneTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -444,11 +478,90 @@ TEST_F(QueryProcessorTest, OneTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); - // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + EXPECT_THAT(results.query_terms, SizeIs(1)); + EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); +} + +TEST_F(QueryProcessorTest, AndSameTermExactMatch) { + // Create the schema and document store + SchemaProto schema; + AddSchemaType(&schema, "email"); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + + // These documents don't actually match to the tokens in the index. We're + // just inserting the documents so that the DocHitInfoIterators will see + // that the document exists and not filter out the DocumentId as deleted. 
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("email") + .Build())); + + // Populate the index + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + EXPECT_THAT( + AddTokenToIndex(document_id, section_id, term_match_type, "hello"), + IsOk()); + + // Perform query + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QueryProcessor> query_processor, + QueryProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); + + SearchSpecProto search_spec; + search_spec.set_query("hello hello"); + search_spec.set_term_match_type(term_match_type); + + ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec)); + + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + ASSERT_FALSE(results.root_iterator->Advance().ok()); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, AndTwoTermExactMatch) { @@ -480,6 +593,8 @@ TEST_F(QueryProcessorTest, AndTwoTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -502,11 +617,94 @@ TEST_F(QueryProcessorTest, AndTwoTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); - // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "world"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + 
ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); +} + +TEST_F(QueryProcessorTest, AndSameTermPrefixMatch) { + // Create the schema and document store + SchemaProto schema; + AddSchemaType(&schema, "email"); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + + // These documents don't actually match to the tokens in the index. We're + // just inserting the documents so that the DocHitInfoIterators will see + // that the document exists and not filter out the DocumentId as deleted. + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("email") + .Build())); + + // Populate the index + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + + EXPECT_THAT( + AddTokenToIndex(document_id, section_id, term_match_type, "hello"), + IsOk()); + + // Perform query + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QueryProcessor> query_processor, + QueryProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get(), &fake_clock_)); + + SearchSpecProto search_spec; + search_spec.set_query("he he"); + search_spec.set_term_match_type(term_match_type); + + ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec)); + + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "he"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + ASSERT_FALSE(results.root_iterator->Advance().ok()); + + EXPECT_THAT(results.query_terms, SizeIs(1)); + EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, AndTwoTermPrefixMatch) { @@ -537,6 +735,8 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -561,10 +761,25 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixMatch) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - 
ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "he"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "wo"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { @@ -595,6 +810,8 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT(AddTokenToIndex(document_id, section_id, @@ -619,10 +836,25 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "wo"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, OrTwoTermExactMatch) { @@ -658,6 +890,8 @@ TEST_F(QueryProcessorTest, OrTwoTermExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; EXPECT_THAT( @@ -682,11 +916,33 @@ TEST_F(QueryProcessorTest, OrTwoTermExactMatch) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, 
section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "world"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + matched_terms_stats.clear(); + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "world")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, OrTwoTermPrefixMatch) { @@ -722,6 +978,8 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -746,11 +1004,32 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixMatch) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "wo"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + matched_terms_stats.clear(); + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "he"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); EXPECT_THAT(results.query_terms, SizeIs(1)); 
EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he", "wo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { @@ -786,6 +1065,8 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; EXPECT_THAT(AddTokenToIndex(document_id1, section_id, TermMatchType::EXACT_ONLY, "hello"), @@ -809,11 +1090,32 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "wo"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + matched_terms_stats.clear(); + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "hello"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("hello", "wo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, CombinedAndOrTerms) { @@ -848,6 +1150,8 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; + std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal puppy dog" @@ -888,11 +1192,27 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { query_processor->ParseSearch(search_spec)); // Only Document 1 matches since it has puppy AND dog - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id1, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), + document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "puppy"); + 
EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "dog"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("puppy", "kitten", "dog")); + EXPECT_THAT(results.query_term_iterators, SizeIs(3)); } { @@ -905,15 +1225,47 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); - // Both Document 1 and 2 match since Document 1 has puppy AND dog, and - // Document 2 has kitten + // Both Document 1 and 2 match since Document 1 has animal AND puppy, and + // Document 2 has animal AND kitten // Descending order of valid DocumentIds - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), + document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "kitten"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + + matched_terms_stats.clear(); + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), + document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "puppy"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal", "puppy", "kitten")); + EXPECT_THAT(results.query_term_iterators, SizeIs(3)); } { @@ -927,11 +1279,27 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { query_processor->ParseSearch(search_spec)); // Only Document 2 matches since it has both kitten and cat - EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask))); + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + 
EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), + document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "kitten"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "cat"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(term_frequencies)); + EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("kitten", "foo", "bar", "cat")); + EXPECT_THAT(results.query_term_iterators, SizeIs(4)); } } @@ -967,7 +1335,6 @@ TEST_F(QueryProcessorTest, OneGroup) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "puppy dog" @@ -1001,11 +1368,14 @@ TEST_F(QueryProcessorTest, OneGroup) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("puppy", "kitten", "foo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(3)); } TEST_F(QueryProcessorTest, TwoGroups) { @@ -1040,7 +1410,6 @@ TEST_F(QueryProcessorTest, TwoGroups) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "puppy dog" @@ -1074,12 +1443,16 @@ TEST_F(QueryProcessorTest, TwoGroups) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo1(document_id1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + DocHitInfo expectedDocHitInfo2(document_id2); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("puppy", "dog", "kitten", "cat")); + EXPECT_THAT(results.query_term_iterators, SizeIs(4)); } TEST_F(QueryProcessorTest, ManyLevelNestedGrouping) { @@ -1114,7 +1487,6 @@ TEST_F(QueryProcessorTest, ManyLevelNestedGrouping) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "puppy dog" @@ -1148,11 +1520,14 @@ TEST_F(QueryProcessorTest, ManyLevelNestedGrouping) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, 
/*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("puppy", "kitten", "foo")); + EXPECT_THAT(results.query_term_iterators, SizeIs(3)); } TEST_F(QueryProcessorTest, OneLevelNestedGrouping) { @@ -1187,7 +1562,6 @@ TEST_F(QueryProcessorTest, OneLevelNestedGrouping) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "puppy dog" @@ -1220,12 +1594,16 @@ TEST_F(QueryProcessorTest, OneLevelNestedGrouping) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo1(document_id1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + DocHitInfo expectedDocHitInfo2(document_id2); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("puppy", "kitten", "cat")); + EXPECT_THAT(results.query_term_iterators, SizeIs(3)); } TEST_F(QueryProcessorTest, ExcludeTerm) { @@ -1289,6 +1667,7 @@ TEST_F(QueryProcessorTest, ExcludeTerm) { EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(DocHitInfo(document_id2, kSectionIdMaskNone))); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } TEST_F(QueryProcessorTest, ExcludeNonexistentTerm) { @@ -1350,6 +1729,7 @@ TEST_F(QueryProcessorTest, ExcludeNonexistentTerm) { ElementsAre(DocHitInfo(document_id2, kSectionIdMaskNone), DocHitInfo(document_id1, kSectionIdMaskNone))); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } TEST_F(QueryProcessorTest, ExcludeAnd) { @@ -1420,6 +1800,7 @@ TEST_F(QueryProcessorTest, ExcludeAnd) { // animal, there are no results. 
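[Editor's note] On the expectedDocHitInfo pattern introduced in the surrounding hunks: rather than constructing DocHitInfo from a precomputed SectionIdMask, the tests now record each hit section individually. A minimal sketch of the assumed semantics (inferred from how the tests use it, not from the DocHitInfo implementation):

    DocHitInfo expected(document_id1);
    // Assumed behavior: UpdateSection() sets section 0's bit in
    // hit_section_ids_mask() and stores the hit's term frequency for that
    // section, which is why the old section_id_mask locals are no longer
    // needed in these tests.
    expected.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1);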
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } { @@ -1436,6 +1817,7 @@ TEST_F(QueryProcessorTest, ExcludeAnd) { EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("cat")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } } @@ -1471,7 +1853,6 @@ TEST_F(QueryProcessorTest, ExcludeOr) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal dog" @@ -1509,6 +1890,7 @@ TEST_F(QueryProcessorTest, ExcludeOr) { EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(DocHitInfo(document_id1, kSectionIdMaskNone))); EXPECT_THAT(results.query_terms, IsEmpty()); + EXPECT_THAT(results.query_term_iterators, IsEmpty()); } { @@ -1520,11 +1902,17 @@ TEST_F(QueryProcessorTest, ExcludeOr) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo1(document_id1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0, + /*hit_term_frequency=*/1); + DocHitInfo expectedDocHitInfo2(document_id2); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0, + /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask), - DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } } @@ -1561,7 +1949,6 @@ TEST_F(QueryProcessorTest, DeletedFilter) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal dog" @@ -1593,10 +1980,13 @@ TEST_F(QueryProcessorTest, DeletedFilter) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id2); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id2, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, NamespaceFilter) { @@ -1631,7 +2021,6 @@ TEST_F(QueryProcessorTest, NamespaceFilter) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal dog" @@ -1664,10 +2053,13 @@ TEST_F(QueryProcessorTest, NamespaceFilter) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], 
UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SchemaTypeFilter) { @@ -1703,7 +2095,6 @@ TEST_F(QueryProcessorTest, SchemaTypeFilter) { // Populate the index SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal dog" @@ -1732,10 +2123,13 @@ TEST_F(QueryProcessorTest, SchemaTypeFilter) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id1, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SectionFilterForOneDocument) { @@ -1768,7 +2162,6 @@ TEST_F(QueryProcessorTest, SectionFilterForOneDocument) { .Build())); // Populate the index - SectionIdMask section_id_mask = 1U << subject_section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document has content "animal" @@ -1792,10 +2185,13 @@ TEST_F(QueryProcessorTest, SectionFilterForOneDocument) { query_processor->ParseSearch(search_spec)); // Descending order of valid DocumentIds + DocHitInfo expectedDocHitInfo(document_id); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["subject"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SectionFilterAcrossSchemaTypes) { @@ -1840,8 +2236,6 @@ TEST_F(QueryProcessorTest, SectionFilterAcrossSchemaTypes) { .Build())); // Populate the index - SectionIdMask email_section_id_mask = 1U << email_foo_section_id; - SectionIdMask message_section_id_mask = 1U << message_foo_section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Email document has content "animal" @@ -1871,12 +2265,15 @@ TEST_F(QueryProcessorTest, SectionFilterAcrossSchemaTypes) { // Ordered by descending DocumentId, so message comes first since it was // inserted last - EXPECT_THAT( - GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(message_document_id, message_section_id_mask), - DocHitInfo(email_document_id, email_section_id_mask))); + DocHitInfo expectedDocHitInfo1(message_document_id); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + DocHitInfo expectedDocHitInfo2(email_document_id); + expectedDocHitInfo2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/1); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(expectedDocHitInfo1, expectedDocHitInfo2)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["foo"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SectionFilterWithinSchemaType) { @@ -1920,7 +2317,6 @@ TEST_F(QueryProcessorTest, SectionFilterWithinSchemaType) { .Build())); // Populate the index - SectionIdMask email_section_id_mask = 1U << 
email_foo_section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Email document has content "animal" @@ -1952,11 +2348,13 @@ TEST_F(QueryProcessorTest, SectionFilterWithinSchemaType) { // Shouldn't include the message document since we're only looking at email // types - EXPECT_THAT( - GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(email_document_id, email_section_id_mask))); + DocHitInfo expectedDocHitInfo(email_document_id); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["foo"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SectionFilterRespectsDifferentSectionIds) { @@ -2000,7 +2398,6 @@ TEST_F(QueryProcessorTest, SectionFilterRespectsDifferentSectionIds) { .Build())); // Populate the index - SectionIdMask email_section_id_mask = 1U << email_foo_section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Email document has content "animal" @@ -2033,11 +2430,13 @@ TEST_F(QueryProcessorTest, SectionFilterRespectsDifferentSectionIds) { // Even though the section id is the same, we should be able to tell that it // doesn't match to the name of the section filter - EXPECT_THAT( - GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(email_document_id, email_section_id_mask))); + DocHitInfo expectedDocHitInfo(email_document_id); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["foo"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, NonexistentSectionFilterReturnsEmptyResults) { @@ -2095,6 +2494,7 @@ TEST_F(QueryProcessorTest, NonexistentSectionFilterReturnsEmptyResults) { EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["nonexistent"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, UnindexedSectionFilterReturnsEmptyResults) { @@ -2152,6 +2552,7 @@ TEST_F(QueryProcessorTest, UnindexedSectionFilterReturnsEmptyResults) { EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms["foo"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } TEST_F(QueryProcessorTest, SectionFilterTermAndUnrestrictedTerm) { @@ -2195,8 +2596,6 @@ TEST_F(QueryProcessorTest, SectionFilterTermAndUnrestrictedTerm) { .Build())); // Poplate the index - SectionIdMask email_section_id_mask = 1U << email_foo_section_id; - SectionIdMask message_section_id_mask = 1U << message_foo_section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Email document has content "animal" @@ -2228,13 +2627,16 @@ TEST_F(QueryProcessorTest, SectionFilterTermAndUnrestrictedTerm) { // Ordered by descending DocumentId, so message comes first since it was // inserted last - EXPECT_THAT( - GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(message_document_id, message_section_id_mask), - DocHitInfo(email_document_id, email_section_id_mask))); + DocHitInfo 
expectedDocHitInfo1(message_document_id); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + DocHitInfo expectedDocHitInfo2(email_document_id); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(expectedDocHitInfo1, expectedDocHitInfo2)); EXPECT_THAT(results.query_terms, SizeIs(2)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("cat")); EXPECT_THAT(results.query_terms["foo"], UnorderedElementsAre("animal")); + EXPECT_THAT(results.query_term_iterators, SizeIs(2)); } TEST_F(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { @@ -2263,7 +2665,6 @@ TEST_F(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { // Populate the index int section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; EXPECT_THAT( @@ -2289,8 +2690,10 @@ TEST_F(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, query_processor->ParseSearch(search_spec)); + DocHitInfo expectedDocHitInfo(document_id); + expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), - ElementsAre(DocHitInfo(document_id, section_id_mask))); + ElementsAre(expectedDocHitInfo)); } TEST_F(QueryProcessorTest, DocumentPastTtlFilteredOut) { diff --git a/icing/result/projection-tree.cc b/icing/result/projection-tree.cc index 382fcb4..67617a3 100644 --- a/icing/result/projection-tree.cc +++ b/icing/result/projection-tree.cc @@ -22,8 +22,7 @@ namespace icing { namespace lib { -ProjectionTree::ProjectionTree( - const ResultSpecProto::TypePropertyMask& type_field_mask) { +ProjectionTree::ProjectionTree(const TypePropertyMask& type_field_mask) { for (const std::string& field_mask : type_field_mask.paths()) { Node* current_node = &root_; for (std::string_view sub_field_mask : diff --git a/icing/result/projection-tree.h b/icing/result/projection-tree.h index a87a8fc..b2e5ffc 100644 --- a/icing/result/projection-tree.h +++ b/icing/result/projection-tree.h @@ -35,8 +35,7 @@ class ProjectionTree { std::vector<Node> children; }; - explicit ProjectionTree( - const ResultSpecProto::TypePropertyMask& type_field_mask); + explicit ProjectionTree(const TypePropertyMask& type_field_mask); const Node& root() const { return root_; } diff --git a/icing/result/projection-tree_test.cc b/icing/result/projection-tree_test.cc index 77d1d21..2b0f966 100644 --- a/icing/result/projection-tree_test.cc +++ b/icing/result/projection-tree_test.cc @@ -28,14 +28,14 @@ using ::testing::IsEmpty; using ::testing::SizeIs; TEST(ProjectionTreeTest, CreateEmptyFieldMasks) { - ResultSpecProto::TypePropertyMask type_field_mask; + TypePropertyMask type_field_mask; ProjectionTree tree(type_field_mask); EXPECT_THAT(tree.root().name, IsEmpty()); EXPECT_THAT(tree.root().children, IsEmpty()); } TEST(ProjectionTreeTest, CreateTreeTopLevel) { - ResultSpecProto::TypePropertyMask type_field_mask; + TypePropertyMask type_field_mask; type_field_mask.add_paths("subject"); ProjectionTree tree(type_field_mask); @@ -46,7 +46,7 @@ TEST(ProjectionTreeTest, CreateTreeTopLevel) { } TEST(ProjectionTreeTest, CreateTreeMultipleTopLevel) { - ResultSpecProto::TypePropertyMask type_field_mask; + TypePropertyMask type_field_mask; type_field_mask.add_paths("subject"); type_field_mask.add_paths("body"); @@ -60,7 +60,7 @@ TEST(ProjectionTreeTest, 
CreateTreeMultipleTopLevel) { } TEST(ProjectionTreeTest, CreateTreeNested) { - ResultSpecProto::TypePropertyMask type_field_mask; + TypePropertyMask type_field_mask; type_field_mask.add_paths("subject.body"); type_field_mask.add_paths("body"); @@ -76,7 +76,7 @@ TEST(ProjectionTreeTest, CreateTreeNested) { } TEST(ProjectionTreeTest, CreateTreeNestedSharedNode) { - ResultSpecProto::TypePropertyMask type_field_mask; + TypePropertyMask type_field_mask; type_field_mask.add_paths("sender.name.first"); type_field_mask.add_paths("sender.emailAddress"); diff --git a/icing/result/projector.cc b/icing/result/projector.cc new file mode 100644 index 0000000..8acdc8a --- /dev/null +++ b/icing/result/projector.cc @@ -0,0 +1,60 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/result/projector.h" + +#include <algorithm> + +namespace icing { +namespace lib { + +namespace projector { + +void Project(const std::vector<ProjectionTree::Node>& projection_tree, + DocumentProto* document) { + int num_kept = 0; + for (int cur_pos = 0; cur_pos < document->properties_size(); ++cur_pos) { + PropertyProto* prop = document->mutable_properties(cur_pos); + auto itr = std::find_if(projection_tree.begin(), projection_tree.end(), + [&prop](const ProjectionTree::Node& node) { + return node.name == prop->name(); + }); + if (itr == projection_tree.end()) { + // Property is not present in the projection tree. Just skip it. + continue; + } + // This property should be kept. + document->mutable_properties()->SwapElements(num_kept, cur_pos); + ++num_kept; + if (itr->children.empty()) { + // A field mask does refer to this property, but it has no children. So + // we should take the entire property, with all of its + // subproperties/values + continue; + } + // The field mask refers to children of this property. Recurse through the + // document values that this property holds and project the children + // requested by this field mask. + for (DocumentProto& subproperty : *(prop->mutable_document_values())) { + Project(itr->children, &subproperty); + } + } + document->mutable_properties()->DeleteSubrange( + num_kept, document->properties_size() - num_kept); +} + +} // namespace projector + +} // namespace lib +} // namespace icing diff --git a/icing/store/enable-bm25f.h b/icing/result/projector.h index cee94d1..43d9052 100644 --- a/icing/store/enable-bm25f.h +++ b/icing/result/projector.h @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Google LLC +// Copyright (C) 2019 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,20 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. 
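[Editor's note] projector::Project() above prunes a DocumentProto in place, keeping only the properties reachable through the supplied projection tree. A minimal usage sketch (variable names are illustrative; the tree-selection order mirrors the ResultRetriever change later in this patch, which prefers a type-specific mask and falls back to the wildcard mask):

    // Keep only "sender.name" for Email documents.
    TypePropertyMask email_mask;
    email_mask.set_schema_type("Email");
    email_mask.add_paths("sender.name");
    ProjectionTree email_tree(email_mask);

    // A wildcard mask applies to any schema type without its own mask.
    TypePropertyMask wildcard_mask;
    wildcard_mask.set_schema_type(
        std::string(ProjectionTree::kSchemaTypeWildcard));
    wildcard_mask.add_paths("name");
    ProjectionTree wildcard_tree(wildcard_mask);

    // Pick the tree for this document's schema type, then project in place.
    const ProjectionTree& tree =
        (document.schema() == "Email") ? email_tree : wildcard_tree;
    projector::Project(tree.root().children, &document);
    // `document` now retains only the projected properties; unlisted top-level
    // properties and unlisted subproperties have been removed.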
-#ifndef ICING_STORE_ENABLE_BM25F_H_ -#define ICING_STORE_ENABLE_BM25F_H_ +#ifndef ICING_RESULT_PROJECTOR_H_ +#define ICING_RESULT_PROJECTOR_H_ + +#include <vector> + +#include "icing/proto/document.pb.h" +#include "icing/result/projection-tree.h" namespace icing { namespace lib { -inline bool enable_bm25f_ = false; +namespace projector { -inline bool enableBm25f() { return enable_bm25f_; } +void Project(const std::vector<ProjectionTree::Node>& projection_tree, + DocumentProto* document); -// Setter for testing purposes. It should never be called in production code. -inline void setEnableBm25f(bool enable_bm25f) { enable_bm25f_ = enable_bm25f; } +} // namespace projector } // namespace lib } // namespace icing -#endif // ICING_STORE_ENABLE_BM25F_H_ +#endif // ICING_RESULT_PROJECTOR_H_ diff --git a/icing/result/result-retriever.cc b/icing/result/result-retriever.cc index 0b8ad4a..85e78a8 100644 --- a/icing/result/result-retriever.cc +++ b/icing/result/result-retriever.cc @@ -22,48 +22,13 @@ #include "icing/proto/term.pb.h" #include "icing/result/page-result-state.h" #include "icing/result/projection-tree.h" +#include "icing/result/projector.h" #include "icing/result/snippet-context.h" #include "icing/util/status-macros.h" namespace icing { namespace lib { -namespace { - -void Project(const std::vector<ProjectionTree::Node>& projection_tree, - google::protobuf::RepeatedPtrField<PropertyProto>* properties) { - int num_kept = 0; - for (int cur_pos = 0; cur_pos < properties->size(); ++cur_pos) { - PropertyProto* prop = properties->Mutable(cur_pos); - auto itr = std::find_if(projection_tree.begin(), projection_tree.end(), - [&prop](const ProjectionTree::Node& node) { - return node.name == prop->name(); - }); - if (itr == projection_tree.end()) { - // Property is not present in the projection tree. Just skip it. - continue; - } - // This property should be kept. - properties->SwapElements(num_kept, cur_pos); - ++num_kept; - if (itr->children.empty()) { - // A field mask does refer to this property, but it has no children. So - // we should take the entire property, with all of its - // subproperties/values - continue; - } - // The field mask refers to children of this property. Recurse through the - // document values that this property holds and project the children - // requested by this field mask. 
- for (DocumentProto& subproperty : *(prop->mutable_document_values())) { - Project(itr->children, subproperty.mutable_properties()); - } - } - properties->DeleteSubrange(num_kept, properties->size() - num_kept); -} - -} // namespace - libtextclassifier3::StatusOr<std::unique_ptr<ResultRetriever>> ResultRetriever::Create(const DocumentStore* doc_store, const SchemaStore* schema_store, @@ -118,17 +83,15 @@ ResultRetriever::RetrieveResults( } } + DocumentProto document = std::move(document_or).ValueOrDie(); // Apply projection - auto itr = page_result_state.projection_tree_map.find( - document_or.ValueOrDie().schema()); - + auto itr = page_result_state.projection_tree_map.find(document.schema()); if (itr != page_result_state.projection_tree_map.end()) { - Project(itr->second.root().children, - document_or.ValueOrDie().mutable_properties()); + projector::Project(itr->second.root().children, &document); } else if (wildcard_projection_tree_itr != page_result_state.projection_tree_map.end()) { - Project(wildcard_projection_tree_itr->second.root().children, - document_or.ValueOrDie().mutable_properties()); + projector::Project(wildcard_projection_tree_itr->second.root().children, + &document); } SearchResultProto::ResultProto result; @@ -137,13 +100,13 @@ ResultRetriever::RetrieveResults( remaining_num_to_snippet > search_results.size()) { SnippetProto snippet_proto = snippet_retriever_->RetrieveSnippet( snippet_context.query_terms, snippet_context.match_type, - snippet_context.snippet_spec, document_or.ValueOrDie(), + snippet_context.snippet_spec, document, scored_document_hit.hit_section_id_mask()); *result.mutable_snippet() = std::move(snippet_proto); } // Add the document, itself. - *result.mutable_document() = std::move(document_or).ValueOrDie(); + *result.mutable_document() = std::move(document); search_results.push_back(std::move(result)); } return search_results; diff --git a/icing/result/result-retriever_test.cc b/icing/result/result-retriever_test.cc index 98cc75a..7cb2d62 100644 --- a/icing/result/result-retriever_test.cc +++ b/icing/result/result-retriever_test.cc @@ -779,7 +779,7 @@ TEST_F(ResultRetrieverTest, ProjectionTopLevelLeadNodeFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("name"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; @@ -881,7 +881,7 @@ TEST_F(ResultRetrieverTest, ProjectionNestedLeafNodeFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("sender.name"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; @@ -994,7 +994,7 @@ TEST_F(ResultRetrieverTest, ProjectionIntermediateNodeFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("sender"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; @@ -1111,7 +1111,7 @@ TEST_F(ResultRetrieverTest, ProjectionMultipleNestedFieldPaths) { {document_id1, 
hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("sender.name"); type_property_mask.add_paths("sender.emailAddress"); @@ -1214,7 +1214,7 @@ TEST_F(ResultRetrieverTest, ProjectionEmptyFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; type_projection_tree_map.insert( @@ -1297,7 +1297,7 @@ TEST_F(ResultRetrieverTest, ProjectionInvalidFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("nonExistentProperty"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; @@ -1381,7 +1381,7 @@ TEST_F(ResultRetrieverTest, ProjectionValidAndInvalidFieldPath) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("name"); type_property_mask.add_paths("nonExistentProperty"); @@ -1469,7 +1469,7 @@ TEST_F(ResultRetrieverTest, ProjectionMultipleTypesNoWildcards) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask type_property_mask; + TypePropertyMask type_property_mask; type_property_mask.set_schema_type("Email"); type_property_mask.add_paths("name"); std::unordered_map<std::string, ProjectionTree> type_projection_tree_map; @@ -1558,7 +1558,7 @@ TEST_F(ResultRetrieverTest, ProjectionMultipleTypesWildcard) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask wildcard_type_property_mask; + TypePropertyMask wildcard_type_property_mask; wildcard_type_property_mask.set_schema_type( std::string(ProjectionTree::kSchemaTypeWildcard)); wildcard_type_property_mask.add_paths("name"); @@ -1648,10 +1648,10 @@ TEST_F(ResultRetrieverTest, ProjectionMultipleTypesWildcardWithOneOverride) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask email_type_property_mask; + TypePropertyMask email_type_property_mask; email_type_property_mask.set_schema_type("Email"); email_type_property_mask.add_paths("body"); - ResultSpecProto::TypePropertyMask wildcard_type_property_mask; + TypePropertyMask wildcard_type_property_mask; wildcard_type_property_mask.set_schema_type( std::string(ProjectionTree::kSchemaTypeWildcard)); wildcard_type_property_mask.add_paths("name"); @@ -1752,10 +1752,10 @@ TEST_F(ResultRetrieverTest, ProjectionSingleTypesWildcardAndOverride) { {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask email_type_property_mask; + TypePropertyMask email_type_property_mask; email_type_property_mask.set_schema_type("Email"); 
email_type_property_mask.add_paths("sender.name"); - ResultSpecProto::TypePropertyMask wildcard_type_property_mask; + TypePropertyMask wildcard_type_property_mask; wildcard_type_property_mask.set_schema_type( std::string(ProjectionTree::kSchemaTypeWildcard)); wildcard_type_property_mask.add_paths("name"); @@ -1861,10 +1861,10 @@ TEST_F(ResultRetrieverTest, {document_id1, hit_section_id_mask, /*score=*/0}, {document_id2, hit_section_id_mask, /*score=*/0}}; - ResultSpecProto::TypePropertyMask email_type_property_mask; + TypePropertyMask email_type_property_mask; email_type_property_mask.set_schema_type("Email"); email_type_property_mask.add_paths("sender.name"); - ResultSpecProto::TypePropertyMask wildcard_type_property_mask; + TypePropertyMask wildcard_type_property_mask; wildcard_type_property_mask.set_schema_type( std::string(ProjectionTree::kSchemaTypeWildcard)); wildcard_type_property_mask.add_paths("sender"); diff --git a/icing/result/result-state.cc b/icing/result/result-state.cc index f1479b9..82738a9 100644 --- a/icing/result/result-state.cc +++ b/icing/result/result-state.cc @@ -47,7 +47,7 @@ ResultState::ResultState(std::vector<ScoredDocumentHit> scored_document_hits, num_returned_(0), scored_document_hit_comparator_(scoring_spec.order_by() == ScoringSpecProto::Order::DESC) { - for (const ResultSpecProto::TypePropertyMask& type_field_mask : + for (const TypePropertyMask& type_field_mask : result_spec.type_property_masks()) { projection_tree_map_.insert( {type_field_mask.schema_type(), ProjectionTree(type_field_mask)}); diff --git a/icing/result/snippet-retriever.cc b/icing/result/snippet-retriever.cc index 09d0f7a..d4a5f79 100644 --- a/icing/result/snippet-retriever.cc +++ b/icing/result/snippet-retriever.cc @@ -334,7 +334,8 @@ SnippetProto SnippetRetriever::RetrieveSnippet( snippet_spec.num_matches_per_property(); // Retrieve values and snippet them. - auto values_or = schema_store_.GetSectionContent(document, metadata->path); + auto values_or = + schema_store_.GetStringSectionContent(document, metadata->path); if (!values_or.ok()) { continue; } @@ -344,7 +345,7 @@ SnippetProto SnippetRetriever::RetrieveSnippet( // If we couldn't create the tokenizer properly, just skip this section. 
continue; } - std::vector<std::string> values = values_or.ValueOrDie(); + std::vector<std::string_view> values = values_or.ValueOrDie(); for (int value_index = 0; value_index < values.size(); ++value_index) { if (match_options.max_matches_remaining <= 0) { break; diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index e54cc0c..b43d2a4 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -421,16 +421,16 @@ libtextclassifier3::StatusOr<SchemaTypeId> SchemaStore::GetSchemaTypeId( return schema_type_mapper_->Get(schema_type); } -libtextclassifier3::StatusOr<std::vector<std::string>> -SchemaStore::GetSectionContent(const DocumentProto& document, - std::string_view section_path) const { - return section_manager_->GetSectionContent(document, section_path); +libtextclassifier3::StatusOr<std::vector<std::string_view>> +SchemaStore::GetStringSectionContent(const DocumentProto& document, + std::string_view section_path) const { + return section_manager_->GetStringSectionContent(document, section_path); } -libtextclassifier3::StatusOr<std::vector<std::string>> -SchemaStore::GetSectionContent(const DocumentProto& document, - SectionId section_id) const { - return section_manager_->GetSectionContent(document, section_id); +libtextclassifier3::StatusOr<std::vector<std::string_view>> +SchemaStore::GetStringSectionContent(const DocumentProto& document, + SectionId section_id) const { + return section_manager_->GetStringSectionContent(document, section_id); } libtextclassifier3::StatusOr<const SectionMetadata*> diff --git a/icing/schema/schema-store.h b/icing/schema/schema-store.h index cff7abd..3854704 100644 --- a/icing/schema/schema-store.h +++ b/icing/schema/schema-store.h @@ -180,8 +180,9 @@ class SchemaStore { // 1. Property is optional and not found in the document // 2. section_path is invalid // 3. Content is empty - libtextclassifier3::StatusOr<std::vector<std::string>> GetSectionContent( - const DocumentProto& document, std::string_view section_path) const; + libtextclassifier3::StatusOr<std::vector<std::string_view>> + GetStringSectionContent(const DocumentProto& document, + std::string_view section_path) const; // Finds content of a section by id // @@ -189,8 +190,9 @@ class SchemaStore { // A string of content on success // INVALID_ARGUMENT if section id is invalid // NOT_FOUND if type config name of document not found - libtextclassifier3::StatusOr<std::vector<std::string>> GetSectionContent( - const DocumentProto& document, SectionId section_id) const; + libtextclassifier3::StatusOr<std::vector<std::string_view>> + GetStringSectionContent(const DocumentProto& document, + SectionId section_id) const; // Returns the SectionMetadata associated with the SectionId that's in the // SchemaTypeId. diff --git a/icing/schema/section-manager.cc b/icing/schema/section-manager.cc index 0285cef..a10e9b9 100644 --- a/icing/schema/section-manager.cc +++ b/icing/schema/section-manager.cc @@ -155,8 +155,9 @@ BuildSectionMetadataCache(const SchemaUtil::TypeConfigMap& type_config_map, // Helper function to get string content from a property. Repeated values are // joined into one string. We only care about the STRING data type. 
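[Editor's note] One consequence of renaming GetSectionContent() to GetStringSectionContent() here is that callers now receive std::string_view values instead of owned strings. A short usage sketch (the lifetime contract is an assumption inferred from the API shape: the views alias string data owned by the DocumentProto, so the document must outlive them):

    auto values_or = schema_store->GetStringSectionContent(
        document, /*section_path*/ "subject");
    if (values_or.ok()) {
      std::vector<std::string_view> values = std::move(values_or).ValueOrDie();
      for (std::string_view value : values) {
        // Use `value` only while `document` is alive and unmodified; the view
        // does not own its characters.
      }
    }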
-std::vector<std::string> GetPropertyContent(const PropertyProto& property) { - std::vector<std::string> values; +std::vector<std::string_view> GetStringPropertyContent( + const PropertyProto& property) { + std::vector<std::string_view> values; if (!property.string_values().empty()) { std::copy(property.string_values().begin(), property.string_values().end(), std::back_inserter(values)); @@ -194,9 +195,9 @@ SectionManager::Create(const SchemaUtil::TypeConfigMap& type_config_map, schema_type_mapper, std::move(section_metadata_cache))); } -libtextclassifier3::StatusOr<std::vector<std::string>> -SectionManager::GetSectionContent(const DocumentProto& document, - std::string_view section_path) const { +libtextclassifier3::StatusOr<std::vector<std::string_view>> +SectionManager::GetStringSectionContent(const DocumentProto& document, + std::string_view section_path) const { // Finds the first property name in section_path size_t separator_position = section_path.find(kPropertySeparator); std::string_view current_property_name = @@ -221,7 +222,8 @@ SectionManager::GetSectionContent(const DocumentProto& document, if (separator_position == std::string::npos) { // Current property name is the last one in section path - std::vector<std::string> content = GetPropertyContent(*property_iterator); + std::vector<std::string_view> content = + GetStringPropertyContent(*property_iterator); if (content.empty()) { // The content of property is explicitly set to empty, we'll treat it as // NOT_FOUND because the index doesn't care about empty strings. @@ -234,11 +236,13 @@ SectionManager::GetSectionContent(const DocumentProto& document, // Gets section content recursively std::string_view sub_section_path = section_path.substr(separator_position + 1); - std::vector<std::string> nested_document_content; + std::vector<std::string_view> nested_document_content; for (const auto& nested_document : property_iterator->document_values()) { - auto content_or = GetSectionContent(nested_document, sub_section_path); + auto content_or = + GetStringSectionContent(nested_document, sub_section_path); if (content_or.ok()) { - std::vector<std::string> content = std::move(content_or).ValueOrDie(); + std::vector<std::string_view> content = + std::move(content_or).ValueOrDie(); std::move(content.begin(), content.end(), std::back_inserter(nested_document_content)); } @@ -251,9 +255,9 @@ SectionManager::GetSectionContent(const DocumentProto& document, return nested_document_content; } -libtextclassifier3::StatusOr<std::vector<std::string>> -SectionManager::GetSectionContent(const DocumentProto& document, - SectionId section_id) const { +libtextclassifier3::StatusOr<std::vector<std::string_view>> +SectionManager::GetStringSectionContent(const DocumentProto& document, + SectionId section_id) const { if (!IsSectionIdValid(section_id)) { return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( "Section id %d is greater than the max value %d", section_id, @@ -270,7 +274,7 @@ SectionManager::GetSectionContent(const DocumentProto& document, } // The index of metadata list is the same as the section id, so we can use // section id as the index. 
- return GetSectionContent(document, metadata_list[section_id].path); + return GetStringSectionContent(document, metadata_list[section_id].path); } libtextclassifier3::StatusOr<const SectionMetadata*> @@ -303,7 +307,7 @@ SectionManager::ExtractSections(const DocumentProto& document) const { std::vector<Section> sections; for (const auto& section_metadata : metadata_list) { auto section_content_or = - GetSectionContent(document, section_metadata.path); + GetStringSectionContent(document, section_metadata.path); // Adds to result vector if section is found in document if (section_content_or.ok()) { sections.emplace_back(SectionMetadata(section_metadata), diff --git a/icing/schema/section-manager.h b/icing/schema/section-manager.h index 475fa6a..191a169 100644 --- a/icing/schema/section-manager.h +++ b/icing/schema/section-manager.h @@ -61,8 +61,9 @@ class SectionManager { // 1. Property is optional and not found in the document // 2. section_path is invalid // 3. Content is empty - libtextclassifier3::StatusOr<std::vector<std::string>> GetSectionContent( - const DocumentProto& document, std::string_view section_path) const; + libtextclassifier3::StatusOr<std::vector<std::string_view>> + GetStringSectionContent(const DocumentProto& document, + std::string_view section_path) const; // Finds content of a section by id // @@ -70,8 +71,9 @@ class SectionManager { // A string of content on success // INVALID_ARGUMENT if section id is invalid // NOT_FOUND if type config name of document not found - libtextclassifier3::StatusOr<std::vector<std::string>> GetSectionContent( - const DocumentProto& document, SectionId section_id) const; + libtextclassifier3::StatusOr<std::vector<std::string_view>> + GetStringSectionContent(const DocumentProto& document, + SectionId section_id) const; // Returns the SectionMetadata associated with the SectionId that's in the // SchemaTypeId. 
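[Editor's note] GetStringSectionContent() resolves dotted section paths recursively: the segment before the first separator names a property of the current document, and the remainder is resolved against every nested document in that property's document_values, with the results concatenated. A small illustration grounded in the fixtures used by the test below (conversation_document_ nests two Email documents under "emails", each with subject "the subject"):

    auto content_or = section_manager->GetStringSectionContent(
        conversation_document_, /*section_path*/ "emails.subject");
    // On success, the content is {"the subject", "the subject"}: one entry per
    // nested email, in document order.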
diff --git a/icing/schema/section-manager_test.cc b/icing/schema/section-manager_test.cc index 2d995df..15d9a19 100644 --- a/icing/schema/section-manager_test.cc +++ b/icing/schema/section-manager_test.cc @@ -186,60 +186,64 @@ TEST_F(SectionManagerTest, CreationWithTooManyPropertiesShouldFail) { HasSubstr("Too many properties"))); } -TEST_F(SectionManagerTest, GetSectionContent) { +TEST_F(SectionManagerTest, GetStringSectionContent) { ICING_ASSERT_OK_AND_ASSIGN( auto section_manager, SectionManager::Create(type_config_map_, schema_type_mapper_.get())); // Test simple section paths - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - /*section_path*/ "subject"), - IsOkAndHolds(ElementsAre("the subject"))); - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - /*section_path*/ "text"), + EXPECT_THAT( + section_manager->GetStringSectionContent(email_document_, + /*section_path*/ "subject"), + IsOkAndHolds(ElementsAre("the subject"))); + EXPECT_THAT(section_manager->GetStringSectionContent(email_document_, + /*section_path*/ "text"), IsOkAndHolds(ElementsAre("the text"))); // Test repeated values, they are joined into one string - ICING_ASSERT_OK_AND_ASSIGN(auto content, section_manager->GetSectionContent( - email_document_, + ICING_ASSERT_OK_AND_ASSIGN( + auto content, + section_manager->GetStringSectionContent(email_document_, /*section_path*/ "recipients")); EXPECT_THAT(content, ElementsAre("recipient1", "recipient2", "recipient3")); // Test concatenated section paths: "property1.property2" - ICING_ASSERT_OK_AND_ASSIGN(content, section_manager->GetSectionContent( + ICING_ASSERT_OK_AND_ASSIGN(content, section_manager->GetStringSectionContent( conversation_document_, /*section_path*/ "emails.subject")); EXPECT_THAT(content, ElementsAre("the subject", "the subject")); - ICING_ASSERT_OK_AND_ASSIGN(content, section_manager->GetSectionContent( + ICING_ASSERT_OK_AND_ASSIGN(content, section_manager->GetStringSectionContent( conversation_document_, /*section_path*/ "emails.text")); EXPECT_THAT(content, ElementsAre("the text", "the text")); - ICING_ASSERT_OK_AND_ASSIGN( - content, - section_manager->GetSectionContent(conversation_document_, - /*section_path*/ "emails.recipients")); + ICING_ASSERT_OK_AND_ASSIGN(content, + section_manager->GetStringSectionContent( + conversation_document_, + /*section_path*/ "emails.recipients")); EXPECT_THAT(content, ElementsAre("recipient1", "recipient2", "recipient3", "recipient1", "recipient2", "recipient3")); // Test non-existing paths - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - /*section_path*/ "name"), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - /*section_path*/ "invalid"), + EXPECT_THAT(section_manager->GetStringSectionContent(email_document_, + /*section_path*/ "name"), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); EXPECT_THAT( - section_manager->GetSectionContent(conversation_document_, - /*section_path*/ "emails.invalid"), + section_manager->GetStringSectionContent(email_document_, + /*section_path*/ "invalid"), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(section_manager->GetStringSectionContent( + conversation_document_, + /*section_path*/ "emails.invalid"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); // Test other data types // BYTES type can't be indexed, so content won't be returned - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - /*section_path*/ "attachment"), 
- StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT( + section_manager->GetStringSectionContent(email_document_, + /*section_path*/ "attachment"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); // The following tests are similar to the ones above but use section ids // instead of section paths @@ -249,16 +253,16 @@ TEST_F(SectionManagerTest, GetSectionContent) { SectionId subject_section_id = 1; SectionId invalid_email_section_id = 2; ICING_ASSERT_OK_AND_ASSIGN( - content, section_manager->GetSectionContent(email_document_, - recipients_section_id)); + content, section_manager->GetStringSectionContent(email_document_, + recipients_section_id)); EXPECT_THAT(content, ElementsAre("recipient1", "recipient2", "recipient3")); - EXPECT_THAT( - section_manager->GetSectionContent(email_document_, subject_section_id), - IsOkAndHolds(ElementsAre("the subject"))); + EXPECT_THAT(section_manager->GetStringSectionContent(email_document_, + subject_section_id), + IsOkAndHolds(ElementsAre("the subject"))); - EXPECT_THAT(section_manager->GetSectionContent(email_document_, - invalid_email_section_id), + EXPECT_THAT(section_manager->GetStringSectionContent( + email_document_, invalid_email_section_id), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); // Conversation (section id -> section path): @@ -270,21 +274,21 @@ TEST_F(SectionManagerTest, GetSectionContent) { SectionId name_section_id = 2; SectionId invalid_conversation_section_id = 3; ICING_ASSERT_OK_AND_ASSIGN( - content, section_manager->GetSectionContent( + content, section_manager->GetStringSectionContent( conversation_document_, emails_recipients_section_id)); EXPECT_THAT(content, ElementsAre("recipient1", "recipient2", "recipient3", "recipient1", "recipient2", "recipient3")); ICING_ASSERT_OK_AND_ASSIGN( - content, section_manager->GetSectionContent(conversation_document_, - emails_subject_section_id)); + content, section_manager->GetStringSectionContent( + conversation_document_, emails_subject_section_id)); EXPECT_THAT(content, ElementsAre("the subject", "the subject")); - EXPECT_THAT(section_manager->GetSectionContent(conversation_document_, - name_section_id), + EXPECT_THAT(section_manager->GetStringSectionContent(conversation_document_, + name_section_id), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - EXPECT_THAT(section_manager->GetSectionContent( + EXPECT_THAT(section_manager->GetStringSectionContent( conversation_document_, invalid_conversation_section_id), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } diff --git a/icing/schema/section.h b/icing/schema/section.h index 058f261..40e623a 100644 --- a/icing/schema/section.h +++ b/icing/schema/section.h @@ -17,6 +17,7 @@ #include <cstdint> #include <string> +#include <string_view> #include <utility> #include <vector> @@ -83,9 +84,10 @@ struct SectionMetadata { // values of a property. 
struct Section { SectionMetadata metadata; - std::vector<std::string> content; + std::vector<std::string_view> content; - Section(SectionMetadata&& metadata_in, std::vector<std::string>&& content_in) + Section(SectionMetadata&& metadata_in, + std::vector<std::string_view>&& content_in) : metadata(std::move(metadata_in)), content(std::move(content_in)) {} }; diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc new file mode 100644 index 0000000..7495e98 --- /dev/null +++ b/icing/scoring/bm25f-calculator.cc @@ -0,0 +1,223 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/bm25f-calculator.h" + +#include <cstdint> +#include <cstdlib> +#include <string> +#include <unordered_set> +#include <vector> + +#include "icing/absl_ports/str_cat.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/store/corpus-associated-scoring-data.h" +#include "icing/store/corpus-id.h" +#include "icing/store/document-associated-score-data.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +// Smoothing parameter, determines the relevance of higher term frequency +// documents. The higher k1, the higher their relevance. 1.2 is the default +// value in the BM25F literature and works well in most corpora. +constexpr float k1_ = 1.2f; +// Smoothing parameter, determines the weight of the document length on the +// final score. The higher b, the higher the influence of the document length. +// 0.7 is the default value in the BM25F literature and works well in most +// corpora. +constexpr float b_ = 0.7f; + +// TODO(b/158603900): add tests for Bm25fCalculator +Bm25fCalculator::Bm25fCalculator(const DocumentStore *document_store) + : document_store_(document_store) {} + +// During initialization, Bm25fCalculator iterates through +// hit-iterators for each query term to pre-compute n(q_i) for each corpus under +// consideration. 
+void Bm25fCalculator::PrepareToScore( + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + *query_term_iterators) { + Clear(); + TermId term_id = 0; + for (auto &iter : *query_term_iterators) { + const std::string &term = iter.first; + if (term_id_map_.find(term) != term_id_map_.end()) { + continue; + } + term_id_map_[term] = ++term_id; + DocHitInfoIterator *term_it = iter.second.get(); + while (term_it->Advance().ok()) { + auto status_or = document_store_->GetDocumentAssociatedScoreData( + term_it->doc_hit_info().document_id()); + if (!status_or.ok()) { + ICING_LOG(ERROR) << "No document score data"; + continue; + } + DocumentAssociatedScoreData data = status_or.ValueOrDie(); + CorpusId corpus_id = data.corpus_id(); + CorpusTermInfo corpus_term_info(corpus_id, term_id); + corpus_nqi_map_[corpus_term_info.value]++; + } + } +} + +void Bm25fCalculator::Clear() { + term_id_map_.clear(); + corpus_avgdl_map_.clear(); + corpus_nqi_map_.clear(); + corpus_idf_map_.clear(); +} + +// Computes BM25F relevance score for query terms matched in document D. +// +// BM25F = \sum_i IDF(q_i) * tf(q_i, D) +// +// where IDF(q_i) is the Inverse Document Frequency (IDF) weight of the query +// term q_i in the corpus with document D, and tf(q_i, D) is the weighted and +// normalized term frequency of query term q_i in the document D. +float Bm25fCalculator::ComputeScore(const DocHitInfoIterator *query_it, + const DocHitInfo &hit_info, + double default_score) { + auto status_or = + document_store_->GetDocumentAssociatedScoreData(hit_info.document_id()); + if (!status_or.ok()) { + ICING_LOG(ERROR) << "No document score data"; + return default_score; + } + DocumentAssociatedScoreData data = status_or.ValueOrDie(); + std::vector<TermMatchInfo> matched_terms_stats; + query_it->PopulateMatchedTermsStats(&matched_terms_stats); + + float score = 0; + for (const TermMatchInfo &term_match_info : matched_terms_stats) { + float idf_weight = + GetCorpusIdfWeightForTerm(term_match_info.term, data.corpus_id()); + float normalized_tf = + ComputedNormalizedTermFrequency(term_match_info, hit_info, data); + score += idf_weight * normalized_tf; + } + + ICING_VLOG(1) << IcingStringUtil::StringPrintf( + "BM25F: corpus_id:%d docid:%d score:%f\n", data.corpus_id(), + hit_info.document_id(), score); + return score; +} + +// Compute inverse document frequency (IDF) weight for query term in the given +// corpus, and cache it in the map. +// +// N - n(q_i) + 0.5 +// IDF(q_i) = log(1 + ------------------) +// n(q_i) + 0.5 +// +// where N is the number of documents in the corpus, and n(q_i) is the number +// of documents in the corpus containing the query term q_i. +float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, + CorpusId corpus_id) { + TermId term_id = term_id_map_[term]; + + CorpusTermInfo corpus_term_info(corpus_id, term_id); + auto iter = corpus_idf_map_.find(corpus_term_info.value); + if (iter != corpus_idf_map_.end()) { + return iter->second; + } + + // First, figure out corpus scoring data. + auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); + if (!status_or.ok()) { + ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( + "No scoring data for corpus [%d]", corpus_id); + return 0; + } + CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); + + uint32_t num_docs = csdata.num_docs(); + uint32_t nqi = corpus_nqi_map_[corpus_term_info.value]; + float idf = + nqi != 0 ? 
log(1.0f + (num_docs - nqi + 0.5f) / (nqi - 0.5f)) : 0.0f; + corpus_idf_map_.insert({corpus_term_info.value, idf}); + ICING_VLOG(1) << IcingStringUtil::StringPrintf( + "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, + std::string(term).c_str(), num_docs, nqi, idf); + return idf; +} + +// Get per corpus average document length and cache the result in the map. +float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { + auto iter = corpus_avgdl_map_.find(corpus_id); + if (iter != corpus_avgdl_map_.end()) { + return iter->second; + } + + // First, figure out corpus scoring data. + auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); + if (!status_or.ok()) { + ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( + "No scoring data for corpus [%d]", corpus_id); + return 0; + } + CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); + + corpus_avgdl_map_[corpus_id] = csdata.average_doc_length_in_tokens(); + return csdata.average_doc_length_in_tokens(); +} + +// Computes normalized term frequency for query term q_i in document D. +// +// f(q_i, D) * (k1 + 1) +// Normalized TF = -------------------------------------------- +// f(q_i, D) + k1 * (1 - b + b * |D| / avgdl) +// +// where f(q_i, D) is the frequency of query term q_i in document D, +// |D| is the #tokens in D, avgdl is the average document length in the corpus, +// k1 and b are smoothing parameters. +float Bm25fCalculator::ComputedNormalizedTermFrequency( + const TermMatchInfo &term_match_info, const DocHitInfo &hit_info, + const DocumentAssociatedScoreData &data) { + uint32_t dl = data.length_in_tokens(); + float avgdl = GetCorpusAvgDocLength(data.corpus_id()); + float f_q = + ComputeTermFrequencyForMatchedSections(data.corpus_id(), term_match_info); + float normalized_tf = + f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); + + ICING_VLOG(1) << IcingStringUtil::StringPrintf( + "corpus_id:%d docid:%d dl:%d avgdl:%f f_q:%f norm_tf:%f\n", + data.corpus_id(), hit_info.document_id(), dl, avgdl, f_q, normalized_tf); + return normalized_tf; +} + +// Note: once we support section weights, we should update this function to +// compute the weighted term frequency. +float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( + CorpusId corpus_id, const TermMatchInfo &term_match_info) const { + float sum = 0.0f; + SectionIdMask sections = term_match_info.section_ids_mask; + while (sections != 0) { + SectionId section_id = __builtin_ctz(sections); + sections &= ~(1u << section_id); + + Hit::TermFrequency tf = term_match_info.term_frequencies[section_id]; + if (tf != Hit::kNoTermFrequency) { + sum += tf; + } + } + return sum; +} + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/bm25f-calculator.h b/icing/scoring/bm25f-calculator.h new file mode 100644 index 0000000..91b4f24 --- /dev/null +++ b/icing/scoring/bm25f-calculator.h @@ -0,0 +1,148 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
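Putting together the two pieces implemented in bm25f-calculator.cc above, the IDF weight and the normalized term frequency, here is a minimal self-contained sketch with hypothetical corpus statistics. It is illustrative only and not part of the patch; note that the denominator follows the commented formula, n(q_i) + 0.5.

```cpp
// Minimal sketch of the scoring math in bm25f-calculator.cc. All statistics
// here are hypothetical; the real calculator reads them from the DocumentStore
// caches at query time.
#include <cmath>
#include <cstdio>

int main() {
  constexpr float k1 = 1.2f, b = 0.7f;

  // Hypothetical corpus: 100 documents, 10 of which contain the query term.
  float num_docs = 100.0f, nqi = 10.0f;
  float idf = std::log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f));

  // Hypothetical document: the term occurs 3 times, the document is 20 tokens
  // long, and the corpus average is 25 tokens.
  float f_q = 3.0f, dl = 20.0f, avgdl = 25.0f;
  float normalized_tf = f_q * (k1 + 1) / (f_q + k1 * (1 - b + b * dl / avgdl));

  // The document's BM25F score is the sum of idf * normalized_tf over all
  // matched query terms; with a single term it is just the product.
  std::printf("idf=%.3f normalized_tf=%.3f score=%.3f\n", idf, normalized_tf,
              idf * normalized_tf);
  return 0;
}
```

Rare terms (small n(q_i)) and shorter-than-average documents both drive the product up, which is exactly the behaviour the scoring tests further down exercise.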
+ +#ifndef ICING_SCORING_BM25F_CALCULATOR_H_ +#define ICING_SCORING_BM25F_CALCULATOR_H_ + +#include <cstdint> +#include <string> +#include <unordered_set> +#include <vector> + +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/legacy/index/icing-bit-util.h" +#include "icing/store/corpus-id.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +// Bm25fCalculator encapsulates the logic to compute BM25F term-weight based +// ranking function. +// +// The formula to compute BM25F is as follows: +// +// BM25F = \sum_i IDF(q_i) * tf(q_i, D) +// +// where IDF(q_i) is the Inverse Document Frequency (IDF) weight of the query +// term q_i in the corpus with document D, and tf(q_i, D) is the weighted and +// normalized term frequency of query term q_i in the document D. +// +// IDF(q_i) is computed as follows: +// +// N - n(q_i) + 0.5 +// IDF(q_i) = log(1 + ------------------) +// n(q_i) + 0.5 +// +// where N is the number of documents in the corpus, and n(q_i) is the number +// of documents in the corpus containing the query term q_i. +// +// Lastly, tf(q_i, D) is computed as follows: +// +// f(q_i, D) * (k1 + 1) +// Normalized TF = -------------------------------------------- +// f(q_i, D) + k1 * (1 - b + b * |D| / avgdl) +// +// where f(q_i, D) is the frequency of query term q_i in document D, +// |D| is the #tokens in D, avgdl is the average document length in the corpus, +// k1 and b are smoothing parameters. +// +// see: go/icing-bm25f +// see: glossary/bm25 +class Bm25fCalculator { + public: + explicit Bm25fCalculator(const DocumentStore *document_store_); + + // Precompute and cache statistics relevant to BM25F. + // Populates term_id_map_ and corpus_nqi_map_ for use while scoring other + // results. + // The query_term_iterators map is used to build the + // std::unordered_map<std::string_view, TermId> term_id_map_. It must + // outlive the bm25f-calculator otherwise the string_view key in term_id_map_, + // used later to compute a document score, will be meaningless. + void PrepareToScore( + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + *query_term_iterators); + + // Compute the BM25F relevance score for the given hit, represented by + // DocHitInfo. + // The default score will be returned only when the scorer fails to find or + // calculate a score for the document. + float ComputeScore(const DocHitInfoIterator *query_it, + const DocHitInfo &hit_info, double default_score); + + private: + // Compact ID for each query term. + using TermId = uint16_t; + + // Compact representation of <CorpusId, TermId> for use as a key in a + // hash_map. 
+ struct CorpusTermInfo { + // Layout bits: 16 bit CorpusId + 16 bit TermId + using Value = uint32_t; + + Value value; + + static constexpr int kCorpusIdBits = sizeof(CorpusId); + static constexpr int kTermIdBits = sizeof(TermId); + + explicit CorpusTermInfo(CorpusId corpus_id, TermId term_id) : value(0) { + BITFIELD_OR(value, kTermIdBits, kCorpusIdBits, + static_cast<uint64_t>(corpus_id)); + BITFIELD_OR(value, 0, kTermIdBits, term_id); + } + + bool operator==(const CorpusTermInfo &other) const { + return value == other.value; + } + }; + + float GetCorpusIdfWeightForTerm(std::string_view term, CorpusId corpus_id); + float GetCorpusAvgDocLength(CorpusId corpus_id); + float ComputedNormalizedTermFrequency( + const TermMatchInfo &term_match_info, const DocHitInfo &hit_info, + const DocumentAssociatedScoreData &data); + float ComputeTermFrequencyForMatchedSections( + CorpusId corpus_id, const TermMatchInfo &term_match_info) const; + + void Clear(); + + const DocumentStore *document_store_; // Does not own. + + // Map from query term to compact term ID. + // Necessary as a key to the other maps. + // The use of the string_view as key here means that the query_term_iterators + // map must outlive the bm25f + std::unordered_map<std::string_view, TermId> term_id_map_; + + // Map from corpus ID to average document length (avgdl). + // Necessary to calculate the normalized term frequency. + // This information is cached in the DocumentStore::CorpusScoreCache + std::unordered_map<CorpusId, float> corpus_avgdl_map_; + + // Map from <corpus ID, term ID> to number of documents containing term q_i, + // called n(q_i). + // Necessary to calculate IDF(q_i) (inverse document frequency). + // This information must be calculated by iterating through the hits for these + // terms. + std::unordered_map<CorpusTermInfo::Value, uint32_t> corpus_nqi_map_; + + // Map from <corpus ID, term ID> to IDF(q_i) (inverse document frequency). 
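The CorpusTermInfo key used by the maps in this class packs the corpus and term identifiers into a single 32-bit value so that the per-term lookups during scoring stay cheap. As a rough standalone sketch of the documented "16 bit CorpusId + 16 bit TermId" layout (the real struct composes the key with the BITFIELD_OR macro from icing-bit-util.h; the explicit shift width below is an assumption made purely for illustration):

```cpp
// Plain-shift sketch of the documented key layout. Not the BITFIELD_OR-based
// code above; it only illustrates the intent of the packing.
#include <cstdint>
#include <cstdio>

using CorpusId = int32_t;  // mirrors icing/store/corpus-id.h
using TermId = uint16_t;   // mirrors Bm25fCalculator::TermId

uint32_t MakeCorpusTermKey(CorpusId corpus_id, TermId term_id) {
  // Low 16 bits: term id. High 16 bits: (truncated) corpus id.
  return (static_cast<uint32_t>(corpus_id) << 16) |
         static_cast<uint32_t>(term_id);
}

int main() {
  std::printf("key=0x%08x\n",
              MakeCorpusTermKey(/*corpus_id=*/3, /*term_id=*/7));
  // Prints key=0x00030007.
  return 0;
}
```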
+ std::unordered_map<CorpusTermInfo::Value, float> corpus_idf_map_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_BM25F_CALCULATOR_H_ diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer.cc index 0739532..b29d8b6 100644 --- a/icing/scoring/scorer.cc +++ b/icing/scoring/scorer.cc @@ -18,8 +18,10 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" -#include "icing/store/document-associated-score-data.h" +#include "icing/scoring/bm25f-calculator.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -33,10 +35,11 @@ class DocumentScoreScorer : public Scorer { double default_score) : document_store_(*document_store), default_score_(default_score) {} - double GetScore(DocumentId document_id) override { + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator*) override { ICING_ASSIGN_OR_RETURN( DocumentAssociatedScoreData score_data, - document_store_.GetDocumentAssociatedScoreData(document_id), + document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()), default_score_); return static_cast<double>(score_data.document_score()); @@ -53,10 +56,11 @@ class DocumentCreationTimestampScorer : public Scorer { double default_score) : document_store_(*document_store), default_score_(default_score) {} - double GetScore(DocumentId document_id) override { + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator*) override { ICING_ASSIGN_OR_RETURN( DocumentAssociatedScoreData score_data, - document_store_.GetDocumentAssociatedScoreData(document_id), + document_store_.GetDocumentAssociatedScoreData(hit_info.document_id()), default_score_); return static_cast<double>(score_data.creation_timestamp_ms()); @@ -67,6 +71,33 @@ class DocumentCreationTimestampScorer : public Scorer { double default_score_; }; +class RelevanceScoreScorer : public Scorer { + public: + explicit RelevanceScoreScorer( + std::unique_ptr<Bm25fCalculator> bm25f_calculator, double default_score) + : bm25f_calculator_(std::move(bm25f_calculator)), + default_score_(default_score) {} + + void PrepareToScore( + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* + query_term_iterators) { + bm25f_calculator_->PrepareToScore(query_term_iterators); + } + + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) override { + if (!query_it) { + return default_score_; + } + return static_cast<double>( + bm25f_calculator_->ComputeScore(query_it, hit_info, default_score_)); + } + + private: + std::unique_ptr<Bm25fCalculator> bm25f_calculator_; + double default_score_; +}; + // A scorer which assigns scores to documents based on usage reports. 
class UsageScorer : public Scorer { public: @@ -77,10 +108,11 @@ class UsageScorer : public Scorer { ranking_strategy_(ranking_strategy), default_score_(default_score) {} - double GetScore(DocumentId document_id) override { - ICING_ASSIGN_OR_RETURN(UsageStore::UsageScores usage_scores, - document_store_.GetUsageScores(document_id), - default_score_); + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator*) override { + ICING_ASSIGN_OR_RETURN( + UsageStore::UsageScores usage_scores, + document_store_.GetUsageScores(hit_info.document_id()), default_score_); switch (ranking_strategy_) { case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT: @@ -113,7 +145,10 @@ class NoScorer : public Scorer { public: explicit NoScorer(double default_score) : default_score_(default_score) {} - double GetScore(DocumentId document_id) override { return default_score_; } + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator*) override { + return default_score_; + } private: double default_score_; @@ -131,6 +166,11 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP: return std::make_unique<DocumentCreationTimestampScorer>(document_store, default_score); + case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: { + auto bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store); + return std::make_unique<RelevanceScoreScorer>(std::move(bm25f_calculator), + default_score); + } case ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT: [[fallthrough]]; case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT: @@ -144,9 +184,6 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: return std::make_unique<UsageScorer>(document_store, rank_by, default_score); - case ScoringSpecProto::RankingStrategy:: - RELEVANCE_SCORE_NONFUNCTIONAL_PLACEHOLDER: - [[fallthrough]]; case ScoringSpecProto::RankingStrategy::NONE: return std::make_unique<NoScorer>(default_score); } diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h index 55c6b5c..a22db0f 100644 --- a/icing/scoring/scorer.h +++ b/icing/scoring/scorer.h @@ -18,6 +18,8 @@ #include <memory> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -46,18 +48,28 @@ class Scorer { // Returns a non-negative score of a document. The score can be a // document-associated score which comes from the DocumentProto directly, an - // accumulated score, or even an inferred score. If it fails to find or - // calculate a score, the user-provided default score will be returned. + // accumulated score, a relevance score, or even an inferred score. If it + // fails to find or calculate a score, the user-provided default score will be + // returned. // // Some examples of possible scores: // 1. Document-associated scores: document score, creation timestamp score. // 2. Accumulated scores: usage count score. // 3. Inferred scores: a score calculated by a machine learning model. + // 4. Relevance score: computed as BM25F score. // // NOTE: This method is performance-sensitive as it's called for every // potential result document. We're trying to avoid returning StatusOr<double> // to save a little more time and memory. 
- virtual double GetScore(DocumentId document_id) = 0; + virtual double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it = nullptr) = 0; + + // Currently only overriden by the RelevanceScoreScorer. + // NOTE: the query_term_iterators map must + // outlive the scorer, see bm25f-calculator for more details. + virtual void PrepareToScore( + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* + query_term_iterators) {} }; } // namespace lib diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index b669eb1..b114515 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" +#include "icing/index/hit/doc-hit-info.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/proto/scoring.pb.h" @@ -120,9 +121,10 @@ TEST_F(ScorerTest, ShouldGetDefaultScore) { Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, /*default_score=*/10, document_store())); - DocumentId non_existing_document_id = 1; + // Non existent document id + DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); // The caller-provided default score is returned - EXPECT_THAT(scorer->GetScore(non_existing_document_id), Eq(10)); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { @@ -142,7 +144,8 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, /*default_score=*/10, document_store())); - EXPECT_THAT(scorer->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); } TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { @@ -163,7 +166,32 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer->GetScore(document_id), Eq(5)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); +} + +// See scoring-processor_test.cc and icing-search-engine_test.cc for better +// Bm25F scoring tests. 
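Outside the tests, the calling convention implied by the new interface looks roughly like the sketch below: create a RELEVANCE_SCORE scorer, hand it the per-term iterators once per query, then score each candidate hit with the iterator that produced it. This is a sketch written against the signatures shown in this patch, with setup and error propagation elided, not code from the change itself.

```cpp
// Sketch of the new scoring call pattern (icing tree assumed on the include
// path; DocumentStore and iterator construction are elided).
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/scoring.pb.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-store.h"

namespace icing {
namespace lib {

double ScoreFirstHit(
    const DocumentStore* document_store,
    std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>*
        query_term_iterators,
    DocHitInfoIterator* result_iterator) {
  auto scorer_or =
      Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE,
                     /*default_score=*/0, document_store);
  if (!scorer_or.ok()) return 0;
  std::unique_ptr<Scorer> scorer = std::move(scorer_or).ValueOrDie();

  // Precomputes n(q_i) per corpus once per query; query_term_iterators must
  // outlive the scorer because the calculator keeps string_views into it.
  scorer->PrepareToScore(query_term_iterators);

  if (!result_iterator->Advance().ok()) return 0;
  // The iterator is passed through so the scorer can ask it for per-term
  // match statistics via PopulateMatchedTermsStats.
  return scorer->GetScore(result_iterator->doc_hit_info(), result_iterator);
}

}  // namespace lib
}  // namespace icing
```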
+TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { + // Creates a test document with document score 5 + DocumentProto test_document = + DocumentBuilder() + .SetScore(5) + .SetKey("icing", "email/1") + .SetSchema("email") + .AddStringProperty("subject", "subject foo") + .SetCreationTimestampMs(fake_clock2().GetSystemTimeMilliseconds()) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, + document_store()->Put(test_document)); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, + /*default_score=*/10, document_store())); + + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); } TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { @@ -193,9 +221,11 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { Scorer::Create(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer->GetScore(document_id1), + DocHitInfo docHitInfo1 = DocHitInfo(document_id1); + DocHitInfo docHitInfo2 = DocHitInfo(document_id2); + EXPECT_THAT(scorer->GetScore(docHitInfo1), Eq(fake_clock1().GetSystemTimeMilliseconds())); - EXPECT_THAT(scorer->GetScore(document_id2), + EXPECT_THAT(scorer->GetScore(docHitInfo2), Eq(fake_clock2().GetSystemTimeMilliseconds())); } @@ -224,9 +254,10 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { std::unique_ptr<Scorer> scorer3, Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report a type1 usage. UsageReport usage_report_type1 = CreateUsageReport( @@ -234,9 +265,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { UsageReport::USAGE_TYPE1); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type1)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(1)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(1)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { @@ -264,9 +295,10 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { std::unique_ptr<Scorer> scorer3, Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report a type2 usage. 
UsageReport usage_report_type2 = CreateUsageReport( @@ -274,9 +306,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { UsageReport::USAGE_TYPE2); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type2)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(1)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(1)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { @@ -304,9 +336,10 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { std::unique_ptr<Scorer> scorer3, Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report a type1 usage. UsageReport usage_report_type3 = CreateUsageReport( @@ -314,9 +347,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { UsageReport::USAGE_TYPE3); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type3)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(1)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(1)); } TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { @@ -347,35 +380,36 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { Scorer::Create( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); UsageReport usage_report_type1_time1 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/1000, UsageReport::USAGE_TYPE1); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type1_time1)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(1)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(1)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report usage with timestamp = 5000ms, score should be updated. 
UsageReport usage_report_type1_time5 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/5000, UsageReport::USAGE_TYPE1); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type1_time5)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(5)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(5)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report usage with timestamp = 3000ms, score should not be updated. UsageReport usage_report_type1_time3 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/3000, UsageReport::USAGE_TYPE1); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type1_time3)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(5)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(5)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { @@ -406,35 +440,36 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { Scorer::Create( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); UsageReport usage_report_type2_time1 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/1000, UsageReport::USAGE_TYPE2); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type2_time1)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(1)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(1)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report usage with timestamp = 5000ms, score should be updated. UsageReport usage_report_type2_time5 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/5000, UsageReport::USAGE_TYPE2); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type2_time5)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(5)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(5)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); // Report usage with timestamp = 3000ms, score should not be updated. 
UsageReport usage_report_type2_time3 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/3000, UsageReport::USAGE_TYPE2); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type2_time3)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(5)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(5)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); } TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { @@ -465,35 +500,36 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { Scorer::Create( ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, /*default_score=*/0, document_store())); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(0)); + DocHitInfo docHitInfo = DocHitInfo(document_id); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(0)); UsageReport usage_report_type3_time1 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/1000, UsageReport::USAGE_TYPE3); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type3_time1)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(1)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(1)); // Report usage with timestamp = 5000ms, score should be updated. UsageReport usage_report_type3_time5 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/5000, UsageReport::USAGE_TYPE3); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type3_time5)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(5)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(5)); // Report usage with timestamp = 3000ms, score should not be updated. 
UsageReport usage_report_type3_time3 = CreateUsageReport( /*name_space=*/"icing", /*uri=*/"email/1", /*timestamp_ms=*/3000, UsageReport::USAGE_TYPE3); ICING_ASSERT_OK(document_store()->ReportUsage(usage_report_type3_time3)); - EXPECT_THAT(scorer1->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer2->GetScore(document_id), Eq(0)); - EXPECT_THAT(scorer3->GetScore(document_id), Eq(5)); + EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); + EXPECT_THAT(scorer3->GetScore(docHitInfo), Eq(5)); } TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { @@ -502,17 +538,23 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, /*default_score=*/3, document_store())); - EXPECT_THAT(scorer->GetScore(/*document_id=*/0), Eq(3)); - EXPECT_THAT(scorer->GetScore(/*document_id=*/1), Eq(3)); - EXPECT_THAT(scorer->GetScore(/*document_id=*/2), Eq(3)); + DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); + DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); + DocHitInfo docHitInfo3 = DocHitInfo(/*document_id_in=*/2); + EXPECT_THAT(scorer->GetScore(docHitInfo1), Eq(3)); + EXPECT_THAT(scorer->GetScore(docHitInfo2), Eq(3)); + EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( scorer, Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, /*default_score=*/111, document_store())); - EXPECT_THAT(scorer->GetScore(/*document_id=*/4), Eq(111)); - EXPECT_THAT(scorer->GetScore(/*document_id=*/5), Eq(111)); - EXPECT_THAT(scorer->GetScore(/*document_id=*/6), Eq(111)); + docHitInfo1 = DocHitInfo(/*document_id_in=*/4); + docHitInfo2 = DocHitInfo(/*document_id_in=*/5); + docHitInfo3 = DocHitInfo(/*document_id_in=*/6); + EXPECT_THAT(scorer->GetScore(docHitInfo1), Eq(111)); + EXPECT_THAT(scorer->GetScore(docHitInfo2), Eq(111)); + EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(111)); } } // namespace diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index 0933094..24480ef 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -58,9 +58,11 @@ ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, } std::vector<ScoredDocumentHit> ScoringProcessor::Score( - std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator, - int num_to_score) { + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator, int num_to_score, + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* + query_term_iterators) { std::vector<ScoredDocumentHit> scored_document_hits; + scorer_->PrepareToScore(query_term_iterators); while (doc_hit_info_iterator->Advance().ok() && num_to_score-- > 0) { const DocHitInfo& doc_hit_info = doc_hit_info_iterator->doc_hit_info(); @@ -69,7 +71,8 @@ std::vector<ScoredDocumentHit> ScoringProcessor::Score( // The final score of the doc_hit_info = score of doc * demotion factor of // hit. double score = - scorer_->GetScore(doc_hit_info.document_id()) * hit_demotion_factor; + scorer_->GetScore(doc_hit_info, doc_hit_info_iterator.get()) * + hit_demotion_factor; scored_document_hits.emplace_back( doc_hit_info.document_id(), doc_hit_info.hit_section_ids_mask(), score); } diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 60c3b32..2289605 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -48,11 +48,14 @@ class ScoringProcessor { // num_to_score. 
The order of results is the same as DocHitInfos from // DocHitInfoIterator. // - // NOTE: if the scoring spec doesn't require a scoring strategy, all + // If necessary, query_term_iterators is used to compute the BM25F relevance + // score. NOTE: if the scoring spec doesn't require a scoring strategy, all // ScoredDocumentHits will be assigned a default score 0. std::vector<ScoredDocumentHit> Score( std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator, - int num_to_score); + int num_to_score, + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>>* + query_term_iterators = nullptr); private: explicit ScoringProcessor(std::unique_ptr<Scorer> scorer) diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index 14b2a20..65eecd1 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -253,6 +253,216 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { EqualsScoredDocumentHit(scored_document_hits.at(2)))); } +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_DocumentsWithDifferentLength) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document3 = + CreateDocument("icing", "email/3", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/100)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id3, + document_store()->Put(document3, /*num_tokens=*/50)); + + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + DocHitInfo doc_hit_info3(document_id3); + doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, + doc_hit_info3}; + + // Creates a dummy DocHitInfoIterator with 3 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Since the three documents all contain the query term "foo" exactly once, + // the document's length determines the final score. Document shorter than the + // average corpus length are slightly boosted. 
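The boost for shorter documents comes entirely from the b * |D| / avgdl term in the normalized-frequency denominator. The sketch below shows the direction of that effect for the three token counts used in this test; the avgdl value is a stand-in, and the expected scores that follow additionally fold in the IDF weight and the store's own average-length bookkeeping, so these numbers are not meant to match them.

```cpp
// Illustrative only: direction of the length normalization for documents of
// 10, 50 and 100 tokens. avgdl is a hypothetical stand-in value.
#include <cstdio>

int main() {
  constexpr float k1 = 1.2f, b = 0.7f, f_q = 1.0f, avgdl = 40.0f;
  const float lengths[] = {10.0f, 50.0f, 100.0f};
  for (float dl : lengths) {
    float normalized_tf =
        f_q * (k1 + 1) / (f_q + k1 * (1 - b + b * dl / avgdl));
    // Shorter documents get a larger normalized term frequency.
    std::printf("dl=%.0f -> normalized_tf=%.3f\n", dl, normalized_tf);
  }
  return 0;
}
```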
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, + /*score=*/0.255482); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, + /*score=*/0.115927); + ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, + /*score=*/0.166435); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/3, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2), + EqualsScoredDocumentHit(expected_scored_doc_hit3))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_DocumentsWithSameLength) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document3 = + CreateDocument("icing", "email/3", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/10)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id3, + document_store()->Put(document3, /*num_tokens=*/10)); + + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + DocHitInfo doc_hit_info3(document_id3); + doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, + doc_hit_info3}; + + // Creates a dummy DocHitInfoIterator with 3 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Since the three documents all contain the query term "foo" exactly once + // and they have the same length, they will have the same BM25F scoret. 
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, + /*score=*/0.16173716); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, + /*score=*/0.16173716); + ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, + /*score=*/0.16173716); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/3, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2), + EqualsScoredDocumentHit(expected_scored_doc_hit3))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_DocumentsWithDifferentQueryFrequency) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document3 = + CreateDocument("icing", "email/3", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/10)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id3, + document_store()->Put(document3, /*num_tokens=*/10)); + + DocHitInfo doc_hit_info1(document_id1); + // Document 1 contains the query term "foo" 5 times + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/5); + DocHitInfo doc_hit_info2(document_id2); + // Document 1 contains the query term "foo" 1 time + doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + DocHitInfo doc_hit_info3(document_id3); + // Document 1 contains the query term "foo" 3 times + doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); + doc_hit_info3.UpdateSection(/*section_id*/ 1, /*hit_term_frequency=*/2); + + SectionIdMask section_id_mask1 = 0b00000001; + SectionIdMask section_id_mask2 = 0b00000001; + SectionIdMask section_id_mask3 = 0b00000011; + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, + doc_hit_info3}; + + // Creates a dummy DocHitInfoIterator with 3 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Since the three documents all have the same length, the score is decided by + // the frequency of the query term "foo". 
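For the third document the relevant frequency is the sum over its matched sections, 1 + 2 = 3, which ComputeTermFrequencyForMatchedSections obtains by walking the section-id mask. A standalone sketch of that walk, using the same hypothetical mask and per-section frequencies as doc_hit_info3 above:

```cpp
// Standalone sketch mirroring ComputeTermFrequencyForMatchedSections in
// bm25f-calculator.cc: sum the per-section term frequencies for every section
// set in the mask. Values correspond to "1 hit in section 0, 2 hits in
// section 1" as set up for doc_hit_info3.
#include <cstdint>
#include <cstdio>

int main() {
  uint16_t section_ids_mask = 0b00000011;  // sections 0 and 1 matched
  uint8_t term_frequencies[16] = {1, 2};   // per-section frequencies

  float sum = 0.0f;
  while (section_ids_mask != 0) {
    int section_id = __builtin_ctz(section_ids_mask);  // lowest set section
    section_ids_mask &= ~(1u << section_id);
    sum += term_frequencies[section_id];
  }
  std::printf("f(q_i, D) = %.0f\n", sum);  // 3
  return 0;
}
```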
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, + /*score=*/0.309497); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask2, + /*score=*/0.16173716); + ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask3, + /*score=*/0.268599); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/3, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2), + EqualsScoredDocumentHit(expected_scored_doc_hit3))); +} + TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, diff --git a/icing/store/corpus-associated-scoring-data.h b/icing/store/corpus-associated-scoring-data.h new file mode 100644 index 0000000..52be5cd --- /dev/null +++ b/icing/store/corpus-associated-scoring-data.h @@ -0,0 +1,79 @@ +// Copyright (C) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_STORE_TYPE_NAMESPACE_ASSOCIATED_SCORING_DATA_H_ +#define ICING_STORE_TYPE_NAMESPACE_ASSOCIATED_SCORING_DATA_H_ + +#include <cstdint> +#include <limits> +#include <type_traits> + +#include "icing/legacy/core/icing-packed-pod.h" + +namespace icing { +namespace lib { + +// This is the cache entity of corpus-associated scores. The ground-truth data +// is stored somewhere else. The cache includes: +// 1. Number of documents contained in the corpus. +// Positive values are required. +// 2. The sum of the documents' lengths, in number of tokens. +class CorpusAssociatedScoreData { + public: + explicit CorpusAssociatedScoreData(int num_docs = 0, + int64_t sum_length_in_tokens = 0) + : sum_length_in_tokens_(sum_length_in_tokens), num_docs_(num_docs) {} + + bool operator==(const CorpusAssociatedScoreData& other) const { + return num_docs_ == other.num_docs() && + sum_length_in_tokens_ == other.sum_length_in_tokens(); + } + + uint32_t num_docs() const { return num_docs_; } + void set_num_docs(uint32_t val) { num_docs_ = val; } + + uint64_t sum_length_in_tokens() const { return sum_length_in_tokens_; } + void set_sum_length_in_tokens(uint64_t val) { sum_length_in_tokens_ = val; } + + float average_doc_length_in_tokens() const { + return sum_length_in_tokens_ / (1.0f + num_docs_); + } + + // Adds a new document. + // Adds the document's length to the total length of the corpus, + // sum_length_in_tokens_. + void AddDocument(uint32_t doc_length_in_tokens) { + ++num_docs_; + sum_length_in_tokens_ = + (std::numeric_limits<int>::max() - doc_length_in_tokens < + sum_length_in_tokens_) + ? std::numeric_limits<int>::max() + : sum_length_in_tokens_ + doc_length_in_tokens; + } + + private: + // The sum total of the length of all documents in the corpus. 
+ int sum_length_in_tokens_; + int num_docs_; +} __attribute__((packed)); + +static_assert(sizeof(CorpusAssociatedScoreData) == 8, + "Size of CorpusAssociatedScoreData should be 8"); +static_assert(icing_is_packed_pod<CorpusAssociatedScoreData>::value, + "go/icing-ubsan"); + +} // namespace lib +} // namespace icing + +#endif // ICING_STORE_TYPE_NAMESPACE_ASSOCIATED_SCORING_DATA_H_ diff --git a/icing/store/corpus-id.h b/icing/store/corpus-id.h index a8f21ba..01135b9 100644 --- a/icing/store/corpus-id.h +++ b/icing/store/corpus-id.h @@ -24,6 +24,8 @@ namespace lib { // DocumentProto. Generated in DocumentStore. using CorpusId = int32_t; +inline constexpr CorpusId kInvalidCorpusId = -1; + } // namespace lib } // namespace icing diff --git a/icing/store/document-associated-score-data.h b/icing/store/document-associated-score-data.h index b9039c5..9a711c8 100644 --- a/icing/store/document-associated-score-data.h +++ b/icing/store/document-associated-score-data.h @@ -19,6 +19,7 @@ #include <type_traits> #include "icing/legacy/core/icing-packed-pod.h" +#include "icing/store/corpus-id.h" namespace icing { namespace lib { @@ -26,33 +27,46 @@ namespace lib { // This is the cache entity of document-associated scores. It contains scores // that are related to the document itself. The ground-truth data is stored // somewhere else. The cache includes: -// 1. Document score. It's defined in and passed from DocumentProto.score. +// 1. Corpus Id. +// 2. Document score. It's defined in and passed from DocumentProto.score. // Positive values are required. -// 2. Document creation timestamp. Unix timestamp of when the document is +// 3. Document creation timestamp. Unix timestamp of when the document is // created and inserted into Icing. +// 4. Document length in number of tokens. 
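Before the per-document score data below, a quick illustration of how the corpus-level entry defined above accumulates: every added document bumps num_docs, its token count is added to the running sum, and average_doc_length_in_tokens divides by (1 + num_docs), a small smoothing term that damps tiny corpora. The token counts are hypothetical and the snippet assumes the icing headers are on the include path.

```cpp
// In-tree sketch: how CorpusAssociatedScoreData accumulates document lengths.
#include <cstdio>

#include "icing/store/corpus-associated-scoring-data.h"

int main() {
  icing::lib::CorpusAssociatedScoreData data;
  data.AddDocument(/*doc_length_in_tokens=*/10);
  data.AddDocument(/*doc_length_in_tokens=*/100);
  data.AddDocument(/*doc_length_in_tokens=*/50);

  // num_docs == 3, sum == 160, so the average is 160 / (1 + 3) = 40 tokens.
  std::printf("docs=%u avgdl=%.1f\n", data.num_docs(),
              data.average_doc_length_in_tokens());
  return 0;
}
```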
class DocumentAssociatedScoreData { public: - explicit DocumentAssociatedScoreData(int document_score, - int64_t creation_timestamp_ms) - : document_score_(document_score), - creation_timestamp_ms_(creation_timestamp_ms) {} + explicit DocumentAssociatedScoreData(CorpusId corpus_id, int document_score, + int64_t creation_timestamp_ms, + int length_in_tokens = 0) + : creation_timestamp_ms_(creation_timestamp_ms), + corpus_id_(corpus_id), + document_score_(document_score), + length_in_tokens_(length_in_tokens) {} bool operator==(const DocumentAssociatedScoreData& other) const { return document_score_ == other.document_score() && - creation_timestamp_ms_ == other.creation_timestamp_ms(); + creation_timestamp_ms_ == other.creation_timestamp_ms() && + length_in_tokens_ == other.length_in_tokens() && + corpus_id_ == other.corpus_id(); } + CorpusId corpus_id() const { return corpus_id_; } + int document_score() const { return document_score_; } int64_t creation_timestamp_ms() const { return creation_timestamp_ms_; } + int length_in_tokens() const { return length_in_tokens_; } + private: - int document_score_; int64_t creation_timestamp_ms_; + CorpusId corpus_id_; + int document_score_; + int length_in_tokens_; } __attribute__((packed)); -static_assert(sizeof(DocumentAssociatedScoreData) == 12, - "Size of DocumentAssociatedScoreData should be 12"); +static_assert(sizeof(DocumentAssociatedScoreData) == 20, + "Size of DocumentAssociatedScoreData should be 20"); static_assert(icing_is_packed_pod<DocumentAssociatedScoreData>::value, "go/icing-ubsan"); diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc index 6a664a3..72bf736 100644 --- a/icing/store/document-store.cc +++ b/icing/store/document-store.cc @@ -37,18 +37,20 @@ #include "icing/proto/document_wrapper.pb.h" #include "icing/proto/logging.pb.h" #include "icing/schema/schema-store.h" +#include "icing/store/corpus-associated-scoring-data.h" #include "icing/store/corpus-id.h" #include "icing/store/document-associated-score-data.h" #include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" -#include "icing/store/enable-bm25f.h" #include "icing/store/key-mapper.h" #include "icing/store/namespace-id.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/util/clock.h" #include "icing/util/crc32.h" #include "icing/util/data-loss.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" +#include "icing/util/tokenized-document.h" namespace icing { namespace lib { @@ -61,6 +63,7 @@ constexpr char kDocumentLogFilename[] = "document_log"; constexpr char kDocumentIdMapperFilename[] = "document_id_mapper"; constexpr char kDocumentStoreHeaderFilename[] = "document_store_header"; constexpr char kScoreCacheFilename[] = "score_cache"; +constexpr char kCorpusScoreCache[] = "corpus_score_cache"; constexpr char kFilterCacheFilename[] = "filter_cache"; constexpr char kNamespaceMapperFilename[] = "namespace_mapper"; constexpr char kUsageStoreDirectoryName[] = "usage_store"; @@ -122,6 +125,10 @@ std::string MakeScoreCacheFilename(const std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kScoreCacheFilename); } +std::string MakeCorpusScoreCache(const std::string& base_dir) { + return absl_ports::StrCat(base_dir, "/", kCorpusScoreCache); +} + std::string MakeFilterCacheFilename(const std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kFilterCacheFilename); } @@ -195,8 +202,16 @@ DocumentStore::DocumentStore(const Filesystem* filesystem, 
document_validator_(schema_store) {} libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( - const DocumentProto& document, NativePutDocumentStats* put_document_stats) { - return Put(DocumentProto(document), put_document_stats); + const DocumentProto& document, int32_t num_tokens, + NativePutDocumentStats* put_document_stats) { + return Put(DocumentProto(document), num_tokens, put_document_stats); +} + +libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( + DocumentProto&& document, int32_t num_tokens, + NativePutDocumentStats* put_document_stats) { + document.mutable_internal_fields()->set_length_in_tokens(num_tokens); + return InternalPut(document, put_document_stats); } DocumentStore::~DocumentStore() { @@ -366,12 +381,15 @@ libtextclassifier3::Status DocumentStore::InitializeDerivedFiles() { usage_store_, UsageStore::Create(filesystem_, MakeUsageStoreDirectoryName(base_dir_))); - if (enableBm25f()) { - ICING_ASSIGN_OR_RETURN( - corpus_mapper_, KeyMapper<CorpusId>::Create( - *filesystem_, MakeCorpusMapperFilename(base_dir_), - kCorpusMapperMaxSize)); - } + ICING_ASSIGN_OR_RETURN(corpus_mapper_, + KeyMapper<CorpusId>::Create( + *filesystem_, MakeCorpusMapperFilename(base_dir_), + kCorpusMapperMaxSize)); + + ICING_ASSIGN_OR_RETURN(corpus_score_cache_, + FileBackedVector<CorpusAssociatedScoreData>::Create( + *filesystem_, MakeCorpusScoreCache(base_dir_), + MemoryMappedFile::READ_WRITE_AUTO_SYNC)); // Ensure the usage store is the correct size. ICING_RETURN_IF_ERROR( @@ -392,9 +410,8 @@ libtextclassifier3::Status DocumentStore::RegenerateDerivedFiles() { ICING_RETURN_IF_ERROR(ResetDocumentAssociatedScoreCache()); ICING_RETURN_IF_ERROR(ResetFilterCache()); ICING_RETURN_IF_ERROR(ResetNamespaceMapper()); - if (enableBm25f()) { - ICING_RETURN_IF_ERROR(ResetCorpusMapper()); - } + ICING_RETURN_IF_ERROR(ResetCorpusMapper()); + ICING_RETURN_IF_ERROR(ResetCorpusAssociatedScoreCache()); // Creates a new UsageStore instance. Note that we don't reset the data in // usage store here because we're not able to regenerate the usage scores. 
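Tying the pieces together: the new num_tokens argument to Put is stored in the document's internal fields and later surfaces as DocumentAssociatedScoreData::length_in_tokens, the |D| that the BM25F calculator reads. A rough in-tree sketch of that round trip (store and schema setup elided; the token count of 42 is an arbitrary example):

```cpp
// Sketch of the num_tokens plumbing introduced in this change. Assumes an
// already-initialized DocumentStore and a document whose schema is known.
#include <cstdio>

#include "icing/proto/document.pb.h"
#include "icing/store/document-associated-score-data.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"

namespace icing {
namespace lib {

void PutAndCheckLength(DocumentStore* document_store,
                       const DocumentProto& document) {
  auto document_id_or = document_store->Put(document, /*num_tokens=*/42);
  if (!document_id_or.ok()) return;
  DocumentId document_id = document_id_or.ValueOrDie();

  auto score_data_or =
      document_store->GetDocumentAssociatedScoreData(document_id);
  if (!score_data_or.ok()) return;
  // Reports the 42 tokens passed to Put(), which BM25F later uses as |D|.
  std::printf("length_in_tokens=%d\n",
              score_data_or.ValueOrDie().length_in_tokens());
}

}  // namespace lib
}  // namespace icing
```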
@@ -506,12 +523,6 @@ libtextclassifier3::Status DocumentStore::RegenerateDerivedFiles() { ICING_RETURN_IF_ERROR( document_id_mapper_->Set(new_document_id, iterator.GetOffset())); - ICING_RETURN_IF_ERROR(UpdateDocumentAssociatedScoreCache( - new_document_id, - DocumentAssociatedScoreData( - document_wrapper.document().score(), - document_wrapper.document().creation_timestamp_ms()))); - SchemaTypeId schema_type_id; auto schema_type_id_or = schema_store_->GetSchemaTypeId(document_wrapper.document().schema()); @@ -536,13 +547,30 @@ libtextclassifier3::Status DocumentStore::RegenerateDerivedFiles() { namespace_mapper_->GetOrPut(document_wrapper.document().namespace_(), namespace_mapper_->num_keys())); - if (enableBm25f()) { - // Update corpus maps - std::string corpus = - MakeFingerprint(document_wrapper.document().namespace_(), - document_wrapper.document().schema()); - corpus_mapper_->GetOrPut(corpus, corpus_mapper_->num_keys()); - } + // Update corpus maps + std::string corpus = + MakeFingerprint(document_wrapper.document().namespace_(), + document_wrapper.document().schema()); + ICING_ASSIGN_OR_RETURN( + CorpusId corpusId, + corpus_mapper_->GetOrPut(corpus, corpus_mapper_->num_keys())); + + ICING_ASSIGN_OR_RETURN(CorpusAssociatedScoreData scoring_data, + GetCorpusAssociatedScoreDataToUpdate(corpusId)); + scoring_data.AddDocument( + document_wrapper.document().internal_fields().length_in_tokens()); + + ICING_RETURN_IF_ERROR( + UpdateCorpusAssociatedScoreCache(corpusId, scoring_data)); + + ICING_RETURN_IF_ERROR(UpdateDocumentAssociatedScoreCache( + new_document_id, + DocumentAssociatedScoreData( + corpusId, document_wrapper.document().score(), + document_wrapper.document().creation_timestamp_ms(), + document_wrapper.document() + .internal_fields() + .length_in_tokens()))); int64_t expiration_timestamp_ms = CalculateExpirationTimestampMs( document_wrapper.document().creation_timestamp_ms(), @@ -638,6 +666,18 @@ libtextclassifier3::Status DocumentStore::ResetDocumentAssociatedScoreCache() { return libtextclassifier3::Status::OK; } +libtextclassifier3::Status DocumentStore::ResetCorpusAssociatedScoreCache() { + // TODO(b/139734457): Replace ptr.reset()->Delete->Create flow with Reset(). + corpus_score_cache_.reset(); + ICING_RETURN_IF_ERROR(FileBackedVector<CorpusAssociatedScoreData>::Delete( + *filesystem_, MakeCorpusScoreCache(base_dir_))); + ICING_ASSIGN_OR_RETURN(corpus_score_cache_, + FileBackedVector<CorpusAssociatedScoreData>::Create( + *filesystem_, MakeCorpusScoreCache(base_dir_), + MemoryMappedFile::READ_WRITE_AUTO_SYNC)); + return libtextclassifier3::Status::OK; +} + libtextclassifier3::Status DocumentStore::ResetFilterCache() { // TODO(b/139734457): Replace ptr.reset()->Delete->Create flow with Reset(). filter_cache_.reset(); @@ -671,23 +711,21 @@ libtextclassifier3::Status DocumentStore::ResetNamespaceMapper() { } libtextclassifier3::Status DocumentStore::ResetCorpusMapper() { - if (enableBm25f()) { - // TODO(b/139734457): Replace ptr.reset()->Delete->Create flow with Reset(). - corpus_mapper_.reset(); - // TODO(b/144458732): Implement a more robust version of TC_RETURN_IF_ERROR - // that can support error logging. 
- libtextclassifier3::Status status = KeyMapper<CorpusId>::Delete( - *filesystem_, MakeCorpusMapperFilename(base_dir_)); - if (!status.ok()) { - ICING_LOG(ERROR) << status.error_message() - << "Failed to delete old corpus_id mapper"; - return status; - } - ICING_ASSIGN_OR_RETURN( - corpus_mapper_, KeyMapper<CorpusId>::Create( - *filesystem_, MakeCorpusMapperFilename(base_dir_), - kCorpusMapperMaxSize)); + // TODO(b/139734457): Replace ptr.reset()->Delete->Create flow with Reset(). + corpus_mapper_.reset(); + // TODO(b/144458732): Implement a more robust version of TC_RETURN_IF_ERROR + // that can support error logging. + libtextclassifier3::Status status = KeyMapper<CorpusId>::Delete( + *filesystem_, MakeCorpusMapperFilename(base_dir_)); + if (!status.ok()) { + ICING_LOG(ERROR) << status.error_message() + << "Failed to delete old corpus_id mapper"; + return status; } + ICING_ASSIGN_OR_RETURN(corpus_mapper_, + KeyMapper<CorpusId>::Create( + *filesystem_, MakeCorpusMapperFilename(base_dir_), + kCorpusMapperMaxSize)); return libtextclassifier3::Status::OK; } @@ -738,16 +776,26 @@ libtextclassifier3::StatusOr<Crc32> DocumentStore::ComputeChecksum() const { Crc32 namespace_mapper_checksum = namespace_mapper_->ComputeChecksum(); + Crc32 corpus_mapper_checksum = corpus_mapper_->ComputeChecksum(); + + // TODO(b/144458732): Implement a more robust version of TC_ASSIGN_OR_RETURN + // that can support error logging. + checksum_or = corpus_score_cache_->ComputeChecksum(); + if (!checksum_or.ok()) { + ICING_LOG(WARNING) << checksum_or.status().error_message() + << "Failed to compute checksum of score cache"; + return checksum_or.status(); + } + Crc32 corpus_score_cache_checksum = std::move(checksum_or).ValueOrDie(); + total_checksum.Append(std::to_string(document_log_checksum.Get())); total_checksum.Append(std::to_string(document_key_mapper_checksum.Get())); total_checksum.Append(std::to_string(document_id_mapper_checksum.Get())); total_checksum.Append(std::to_string(score_cache_checksum.Get())); total_checksum.Append(std::to_string(filter_cache_checksum.Get())); total_checksum.Append(std::to_string(namespace_mapper_checksum.Get())); - if (enableBm25f()) { - Crc32 corpus_mapper_checksum = corpus_mapper_->ComputeChecksum(); - total_checksum.Append(std::to_string(corpus_mapper_checksum.Get())); - } + total_checksum.Append(std::to_string(corpus_mapper_checksum.Get())); + total_checksum.Append(std::to_string(corpus_score_cache_checksum.Get())); return total_checksum; } @@ -779,8 +827,8 @@ libtextclassifier3::Status DocumentStore::UpdateHeader(const Crc32& checksum) { return libtextclassifier3::Status::OK; } -libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( - DocumentProto&& document, NativePutDocumentStats* put_document_stats) { +libtextclassifier3::StatusOr<DocumentId> DocumentStore::InternalPut( + DocumentProto& document, NativePutDocumentStats* put_document_stats) { std::unique_ptr<Timer> put_timer = clock_.GetNewTimer(); ICING_RETURN_IF_ERROR(document_validator_.Validate(document)); @@ -793,6 +841,7 @@ libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( std::string uri = document.uri(); std::string schema = document.schema(); int document_score = document.score(); + int32_t length_in_tokens = document.internal_fields().length_in_tokens(); int64_t creation_timestamp_ms = document.creation_timestamp_ms(); // Sets the creation timestamp if caller hasn't specified. 
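ComputeChecksum() above now folds two additional component CRCs, corpus_mapper_ and corpus_score_cache_, into the combined total. The composition pattern relies only on the Crc32 Append() and Get() calls that already appear in this file; CombineChecksums below is a hypothetical helper for illustration, not an Icing API.

// Requires "icing/util/crc32.h".
#include <string>
#include <vector>

Crc32 CombineChecksums(const std::vector<Crc32>& components) {
  Crc32 total_checksum;
  for (const Crc32& component : components) {
    // Appending the decimal digits of each component CRC makes the combined
    // value depend on the count and order of components, so stores written
    // before the corpus mapper and corpus score cache existed will no longer
    // match the checksum recorded in their header.
    total_checksum.Append(std::to_string(component.Get()));
  }
  return total_checksum;
}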
@@ -829,20 +878,28 @@ libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( MakeFingerprint(name_space, uri), new_document_id)); ICING_RETURN_IF_ERROR(document_id_mapper_->Set(new_document_id, file_offset)); - ICING_RETURN_IF_ERROR(UpdateDocumentAssociatedScoreCache( - new_document_id, - DocumentAssociatedScoreData(document_score, creation_timestamp_ms))); - // Update namespace maps ICING_ASSIGN_OR_RETURN( NamespaceId namespace_id, namespace_mapper_->GetOrPut(name_space, namespace_mapper_->num_keys())); - if (enableBm25f()) { - // Update corpus maps - ICING_RETURN_IF_ERROR(corpus_mapper_->GetOrPut( - MakeFingerprint(name_space, schema), corpus_mapper_->num_keys())); - } + // Update corpus maps + ICING_ASSIGN_OR_RETURN( + CorpusId corpusId, + corpus_mapper_->GetOrPut(MakeFingerprint(name_space, schema), + corpus_mapper_->num_keys())); + + ICING_ASSIGN_OR_RETURN(CorpusAssociatedScoreData scoring_data, + GetCorpusAssociatedScoreDataToUpdate(corpusId)); + scoring_data.AddDocument(length_in_tokens); + + ICING_RETURN_IF_ERROR( + UpdateCorpusAssociatedScoreCache(corpusId, scoring_data)); + + ICING_RETURN_IF_ERROR(UpdateDocumentAssociatedScoreCache( + new_document_id, + DocumentAssociatedScoreData(corpusId, document_score, + creation_timestamp_ms, length_in_tokens))); ICING_ASSIGN_OR_RETURN(SchemaTypeId schema_type_id, schema_store_->GetSchemaTypeId(schema)); @@ -876,7 +933,8 @@ libtextclassifier3::StatusOr<DocumentId> DocumentStore::Put( } libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( - const std::string_view name_space, const std::string_view uri) const { + const std::string_view name_space, const std::string_view uri, + bool clear_internal_fields) const { // TODO(b/147231617): Make a better way to replace the error message in an // existing Status. 
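The reason Put() now records length_in_tokens per document and num_docs / sum_length_in_tokens per corpus is BM25f-style relevance scoring; the scorer itself is not part of the hunks shown here. A rough sketch, under that assumption, of how a BM25 term weight could consume the cached statistics, with the usual k1 and b free parameters:

#include <cstdint>

float Bm25TermWeight(float term_frequency, float idf,
                     int32_t doc_length_in_tokens, int corpus_num_docs,
                     int64_t corpus_sum_length_in_tokens, float k1 = 1.2f,
                     float b = 0.7f) {
  // Average document length of the corpus, derived from the cached totals.
  float avg_doc_length =
      corpus_num_docs > 0
          ? static_cast<float>(corpus_sum_length_in_tokens) / corpus_num_docs
          : 1.0f;
  float length_norm = doc_length_in_tokens / avg_doc_length;
  // Standard BM25 saturation and length normalization.
  return idf * (term_frequency * (k1 + 1.0f)) /
         (term_frequency + k1 * (1.0f - b + b * length_norm));
}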
auto document_id_or = GetDocumentId(name_space, uri); @@ -903,7 +961,7 @@ libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( } libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( - DocumentId document_id) const { + DocumentId document_id, bool clear_internal_fields) const { ICING_ASSIGN_OR_RETURN(int64_t document_log_offset, DoesDocumentExistAndGetFileOffset(document_id)); @@ -917,6 +975,9 @@ libtextclassifier3::StatusOr<DocumentProto> DocumentStore::Get( } DocumentWrapper document_wrapper = std::move(document_wrapper_or).ValueOrDie(); + if (clear_internal_fields) { + document_wrapper.mutable_document()->clear_internal_fields(); + } return std::move(*document_wrapper.mutable_document()); } @@ -1088,10 +1149,7 @@ libtextclassifier3::StatusOr<NamespaceId> DocumentStore::GetNamespaceId( libtextclassifier3::StatusOr<CorpusId> DocumentStore::GetCorpusId( const std::string_view name_space, const std::string_view schema) const { - if (enableBm25f()) { - return corpus_mapper_->Get(MakeFingerprint(name_space, schema)); - } - return absl_ports::NotFoundError("corpus_mapper disabled"); + return corpus_mapper_->Get(MakeFingerprint(name_space, schema)); } libtextclassifier3::StatusOr<DocumentAssociatedScoreData> @@ -1112,6 +1170,34 @@ DocumentStore::GetDocumentAssociatedScoreData(DocumentId document_id) const { return document_associated_score_data; } +libtextclassifier3::StatusOr<CorpusAssociatedScoreData> +DocumentStore::GetCorpusAssociatedScoreData(CorpusId corpus_id) const { + auto score_data_or = corpus_score_cache_->Get(corpus_id); + if (!score_data_or.ok()) { + return score_data_or.status(); + } + + CorpusAssociatedScoreData corpus_associated_score_data = + *std::move(score_data_or).ValueOrDie(); + return corpus_associated_score_data; +} + +libtextclassifier3::StatusOr<CorpusAssociatedScoreData> +DocumentStore::GetCorpusAssociatedScoreDataToUpdate(CorpusId corpus_id) const { + auto corpus_scoring_data_or = GetCorpusAssociatedScoreData(corpus_id); + if (corpus_scoring_data_or.ok()) { + return std::move(corpus_scoring_data_or).ValueOrDie(); + } + CorpusAssociatedScoreData scoringData; + // OUT_OF_RANGE is the StatusCode returned when a corpus id is added to + // corpus_score_cache_ for the first time. + if (corpus_scoring_data_or.status().CanonicalCode() == + libtextclassifier3::StatusCode::OUT_OF_RANGE) { + return scoringData; + } + return corpus_scoring_data_or.status(); +} + libtextclassifier3::StatusOr<DocumentFilterData> DocumentStore::GetDocumentFilterData(DocumentId document_id) const { auto filter_data_or = filter_cache_->Get(document_id); @@ -1308,10 +1394,8 @@ libtextclassifier3::Status DocumentStore::PersistToDisk() { ICING_RETURN_IF_ERROR(filter_cache_->PersistToDisk()); ICING_RETURN_IF_ERROR(namespace_mapper_->PersistToDisk()); ICING_RETURN_IF_ERROR(usage_store_->PersistToDisk()); - - if (enableBm25f()) { - ICING_RETURN_IF_ERROR(corpus_mapper_->PersistToDisk()); - } + ICING_RETURN_IF_ERROR(corpus_mapper_->PersistToDisk()); + ICING_RETURN_IF_ERROR(corpus_score_cache_->PersistToDisk()); // Update the combined checksum and write to header file. 
ICING_ASSIGN_OR_RETURN(Crc32 checksum, ComputeChecksum()); @@ -1333,16 +1417,16 @@ libtextclassifier3::StatusOr<int64_t> DocumentStore::GetDiskUsage() const { filter_cache_->GetDiskUsage()); ICING_ASSIGN_OR_RETURN(const int64_t namespace_mapper_disk_usage, namespace_mapper_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(const int64_t corpus_mapper_disk_usage, + corpus_mapper_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(const int64_t corpus_score_cache_disk_usage, + corpus_score_cache_->GetDiskUsage()); int64_t disk_usage = document_log_disk_usage + document_key_mapper_disk_usage + document_id_mapper_disk_usage + score_cache_disk_usage + - filter_cache_disk_usage + namespace_mapper_disk_usage; - if (enableBm25f()) { - ICING_ASSIGN_OR_RETURN(const int64_t corpus_mapper_disk_usage, - corpus_mapper_->GetDiskUsage()); - disk_usage += corpus_mapper_disk_usage; - } + filter_cache_disk_usage + namespace_mapper_disk_usage + + corpus_mapper_disk_usage + corpus_score_cache_disk_usage; return disk_usage; } @@ -1493,7 +1577,7 @@ libtextclassifier3::Status DocumentStore::Optimize() { } libtextclassifier3::Status DocumentStore::OptimizeInto( - const std::string& new_directory) { + const std::string& new_directory, const LanguageSegmenter* lang_segmenter) { // Validates directory if (new_directory == base_dir_) { return absl_ports::InvalidArgumentError( @@ -1509,7 +1593,7 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( // Writes all valid docs into new document store (new directory) int size = document_id_mapper_->num_elements(); for (DocumentId document_id = 0; document_id < size; document_id++) { - auto document_or = Get(document_id); + auto document_or = Get(document_id, /*clear_internal_fields=*/false); if (absl_ports::IsNotFound(document_or.status())) { // Skip nonexistent documents continue; @@ -1523,9 +1607,26 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( // Guaranteed to have a document now. DocumentProto document_to_keep = document_or.ValueOrDie(); - // TODO(b/144458732): Implement a more robust version of TC_ASSIGN_OR_RETURN - // that can support error logging. - auto new_document_id_or = new_doc_store->Put(std::move(document_to_keep)); + + libtextclassifier3::StatusOr<DocumentId> new_document_id_or; + if (document_to_keep.internal_fields().length_in_tokens() == 0) { + auto tokenized_document_or = TokenizedDocument::Create( + schema_store_, lang_segmenter, document_to_keep); + if (!tokenized_document_or.ok()) { + return absl_ports::Annotate( + tokenized_document_or.status(), + IcingStringUtil::StringPrintf( + "Failed to tokenize Document for DocumentId %d", document_id)); + } + TokenizedDocument tokenized_document( + std::move(tokenized_document_or).ValueOrDie()); + new_document_id_or = + new_doc_store->Put(document_to_keep, tokenized_document.num_tokens()); + } else { + // TODO(b/144458732): Implement a more robust version of + // TC_ASSIGN_OR_RETURN that can support error logging. 
+ new_document_id_or = new_doc_store->InternalPut(document_to_keep); + } if (!new_document_id_or.ok()) { ICING_LOG(ERROR) << new_document_id_or.status().error_message() << "Failed to write into new document store"; @@ -1577,26 +1678,39 @@ DocumentStore::GetOptimizeInfo() const { score_cache_->GetElementsFileSize()); ICING_ASSIGN_OR_RETURN(const int64_t filter_cache_file_size, filter_cache_->GetElementsFileSize()); + ICING_ASSIGN_OR_RETURN(const int64_t corpus_score_cache_file_size, + corpus_score_cache_->GetElementsFileSize()); + + // Usage store might be sparse, but we'll still use file size for more + // accurate counting. + ICING_ASSIGN_OR_RETURN(const int64_t usage_store_file_size, + usage_store_->GetElementsFileSize()); // We use a combined disk usage and file size for the KeyMapper because it's // backed by a trie, which has some sparse property bitmaps. ICING_ASSIGN_OR_RETURN(const int64_t document_key_mapper_size, document_key_mapper_->GetElementsSize()); - // We don't include the namespace mapper because it's not clear if we could - // recover any space even if Optimize were called. Deleting 100s of documents - // could still leave a few documents of a namespace, and then there would be - // no change. + // We don't include the namespace_mapper or the corpus_mapper because it's not + // clear if we could recover any space even if Optimize were called. Deleting + // 100s of documents could still leave a few documents of a namespace, and + // then there would be no change. int64_t total_size = document_log_file_size + document_key_mapper_size + document_id_mapper_file_size + score_cache_file_size + - filter_cache_file_size; + filter_cache_file_size + corpus_score_cache_file_size + + usage_store_file_size; optimize_info.estimated_optimizable_bytes = total_size * optimize_info.optimizable_docs / optimize_info.total_docs; return optimize_info; } +libtextclassifier3::Status DocumentStore::UpdateCorpusAssociatedScoreCache( + CorpusId corpus_id, const CorpusAssociatedScoreData& score_data) { + return corpus_score_cache_->Set(corpus_id, score_data); +} + libtextclassifier3::Status DocumentStore::UpdateDocumentAssociatedScoreCache( DocumentId document_id, const DocumentAssociatedScoreData& score_data) { return score_cache_->Set(document_id, score_data); @@ -1617,8 +1731,10 @@ libtextclassifier3::Status DocumentStore::ClearDerivedData( // Resets the score cache entry ICING_RETURN_IF_ERROR(UpdateDocumentAssociatedScoreCache( - document_id, DocumentAssociatedScoreData(/*document_score=*/-1, - /*creation_timestamp_ms=*/-1))); + document_id, DocumentAssociatedScoreData(kInvalidCorpusId, + /*document_score=*/-1, + /*creation_timestamp_ms=*/-1, + /*length_in_tokens=*/0))); // Resets the filter cache entry ICING_RETURN_IF_ERROR(UpdateFilterCache( diff --git a/icing/store/document-store.h b/icing/store/document-store.h index 78590a5..b2908f0 100644 --- a/icing/store/document-store.h +++ b/icing/store/document-store.h @@ -30,6 +30,7 @@ #include "icing/proto/document_wrapper.pb.h" #include "icing/proto/logging.pb.h" #include "icing/schema/schema-store.h" +#include "icing/store/corpus-associated-scoring-data.h" #include "icing/store/corpus-id.h" #include "icing/store/document-associated-score-data.h" #include "icing/store/document-filter-data.h" @@ -37,6 +38,7 @@ #include "icing/store/key-mapper.h" #include "icing/store/namespace-id.h" #include "icing/store/usage-store.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/util/clock.h" #include "icing/util/crc32.h" #include 
"icing/util/data-loss.h" @@ -149,23 +151,27 @@ class DocumentStore { // exist in schema // INTERNAL_ERROR on IO error libtextclassifier3::StatusOr<DocumentId> Put( - const DocumentProto& document, + const DocumentProto& document, int32_t num_tokens = 0, NativePutDocumentStats* put_document_stats = nullptr); libtextclassifier3::StatusOr<DocumentId> Put( - DocumentProto&& document, + DocumentProto&& document, int32_t num_tokens = 0, NativePutDocumentStats* put_document_stats = nullptr); // Finds and returns the document identified by the given key (namespace + - // uri) + // uri). If 'clear_internal_fields' is true, document level data that's + // generated internally by DocumentStore is cleared. // // Returns: // The document found on success // NOT_FOUND if the key doesn't exist or document has been deleted // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<DocumentProto> Get(std::string_view name_space, - std::string_view uri) const; + libtextclassifier3::StatusOr<DocumentProto> Get( + std::string_view name_space, std::string_view uri, + bool clear_internal_fields = true) const; - // Finds and returns the document identified by the given document id + // Finds and returns the document identified by the given document id. If + // 'clear_internal_fields' is true, document level data that's generated + // internally by DocumentStore is cleared. // // Returns: // The document found on success @@ -173,7 +179,8 @@ class DocumentStore { // maximum value // NOT_FOUND if the document doesn't exist or has been deleted // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<DocumentProto> Get(DocumentId document_id) const; + libtextclassifier3::StatusOr<DocumentProto> Get( + DocumentId document_id, bool clear_internal_fields = true) const; // Returns all namespaces which have at least 1 active document (not deleted // or expired). Order of namespaces is undefined. @@ -256,6 +263,20 @@ class DocumentStore { libtextclassifier3::StatusOr<DocumentAssociatedScoreData> GetDocumentAssociatedScoreData(DocumentId document_id) const; + // Returns the CorpusAssociatedScoreData of the corpus specified by the + // corpus_id. + // + // NOTE: This does not check if the corpus exists and will return the + // CorpusAssociatedScoreData of the corpus even if all documents belonging to + // that corpus have been deleted. + // + // Returns: + // CorpusAssociatedScoreData on success + // OUT_OF_RANGE if corpus_id is negative or exceeds previously seen + // CorpusIds + libtextclassifier3::StatusOr<CorpusAssociatedScoreData> + GetCorpusAssociatedScoreData(CorpusId corpus_id) const; + // Returns the DocumentFilterData of the document specified by the DocumentId. // // NOTE: This does not check if the document exists and will return the @@ -394,7 +415,9 @@ class DocumentStore { // OK on success // INVALID_ARGUMENT if new_directory is same as current base directory // INTERNAL_ERROR on IO error - libtextclassifier3::Status OptimizeInto(const std::string& new_directory); + libtextclassifier3::Status OptimizeInto( + const std::string& new_directory, + const LanguageSegmenter* lang_segmenter); // Calculates status for a potential Optimize call. Includes how many docs // there are vs how many would be optimized away. And also includes an @@ -441,8 +464,10 @@ class DocumentStore { // A cache of document associated scores. The ground truth of the scores is // DocumentProto stored in document_log_. 
This cache contains: + // - CorpusId // - Document score // - Document creation timestamp in seconds + // - Document length in number of tokens std::unique_ptr<FileBackedVector<DocumentAssociatedScoreData>> score_cache_; // A cache of data, indexed by DocumentId, used to filter documents. Currently @@ -452,6 +477,13 @@ class DocumentStore { // - Expiration timestamp in seconds std::unique_ptr<FileBackedVector<DocumentFilterData>> filter_cache_; + // A cache of corpus associated scores. The ground truth of the scores is + // DocumentProto stored in document_log_. This cache contains: + // - Number of documents belonging to the corpus score + // - The sum of the documents' lengths, in number of tokens. + std::unique_ptr<FileBackedVector<CorpusAssociatedScoreData>> + corpus_score_cache_; + // Maps namespaces to a densely-assigned unique id. Namespaces are assigned an // id when the first document belonging to that namespace is added to the // DocumentStore. Namespaces may be removed from the mapper during compaction. @@ -516,6 +548,12 @@ class DocumentStore { // Returns OK or any IO errors. libtextclassifier3::Status ResetDocumentAssociatedScoreCache(); + // Resets the unique_ptr to the corpus_score_cache, deletes the underlying + // file, and re-creates a new instance of the corpus_score_cache. + // + // Returns OK or any IO errors. + libtextclassifier3::Status ResetCorpusAssociatedScoreCache(); + // Resets the unique_ptr to the filter_cache, deletes the underlying file, and // re-creates a new instance of the filter_cache. // @@ -546,6 +584,10 @@ class DocumentStore { // INTERNAL on I/O error libtextclassifier3::Status UpdateHeader(const Crc32& checksum); + libtextclassifier3::StatusOr<DocumentId> InternalPut( + DocumentProto& document, + NativePutDocumentStats* put_document_stats = nullptr); + // Helper function to do batch deletes. Documents with the given // "namespace_id" and "schema_type_id" will be deleted. If callers don't need // to specify the namespace or schema type, pass in kInvalidNamespaceId or @@ -597,6 +639,21 @@ class DocumentStore { libtextclassifier3::StatusOr<DocumentId> GetDocumentId( std::string_view name_space, std::string_view uri) const; + // Returns the CorpusAssociatedScoreData of the corpus specified by the + // corpus_id. + // + // If the corpus_id has never been seen before, it returns a + // CorpusAssociatedScoreData with properties set to default values. + // + // NOTE: This does not check if the corpus exists and will return the + // CorpusAssociatedScoreData of the corpus even if all documents belonging to + // that corpus have been deleted. + // + // Returns: + // CorpusAssociatedScoreData on success + libtextclassifier3::StatusOr<CorpusAssociatedScoreData> + GetCorpusAssociatedScoreDataToUpdate(CorpusId corpus_id) const; + // Helper method to validate the document id and return the file offset of the // associated document in document_log_. // @@ -617,6 +674,10 @@ class DocumentStore { libtextclassifier3::Status UpdateDocumentAssociatedScoreCache( DocumentId document_id, const DocumentAssociatedScoreData& score_data); + // Updates the entry in the corpus score cache for corpus_id. + libtextclassifier3::Status UpdateCorpusAssociatedScoreCache( + CorpusId corpus_id, const CorpusAssociatedScoreData& score_data); + // Updates the entry in the filter cache for document_id. 
libtextclassifier3::Status UpdateFilterCache( DocumentId document_id, const DocumentFilterData& filter_data); diff --git a/icing/store/document-store_test.cc b/icing/store/document-store_test.cc index 29bf8bb..7754373 100644 --- a/icing/store/document-store_test.cc +++ b/icing/store/document-store_test.cc @@ -27,20 +27,25 @@ #include "icing/file/filesystem.h" #include "icing/file/memory-mapped-file.h" #include "icing/file/mock-filesystem.h" +#include "icing/helpers/icu/icu-data-file-helper.h" #include "icing/portable/equals-proto.h" #include "icing/proto/document.pb.h" #include "icing/proto/schema.pb.h" #include "icing/schema/schema-store.h" +#include "icing/store/corpus-associated-scoring-data.h" +#include "icing/store/corpus-id.h" #include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" -#include "icing/store/enable-bm25f.h" #include "icing/store/namespace-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/platform.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" #include "icing/util/crc32.h" +#include "unicode/uloc.h" namespace icing { namespace lib { @@ -101,7 +106,19 @@ class DocumentStoreTest : public ::testing::Test { } void SetUp() override { - setEnableBm25f(true); + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + // If we've specified using the reverse-JNI method for segmentation (i.e. + // not ICU), then we won't have the ICU data file included to set up. + // Technically, we could choose to use reverse-JNI for segmentation AND + // include an ICU data file, but that seems unlikely and our current BUILD + // setup doesn't do this. + // File generated via icu_data_file rule in //icing/BUILD. 
+ std::string icu_data_file_path = + GetTestFilePath("icing/icu.dat"); + ICING_ASSERT_OK( + icu_data_file_helper::SetUpICUDataFile(icu_data_file_path)); + } + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); filesystem_.CreateDirectoryRecursively(document_store_dir_.c_str()); @@ -133,6 +150,11 @@ class DocumentStoreTest : public ::testing::Test { schema_store_, SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + lang_segmenter_, + language_segmenter_factory::Create(std::move(segmenter_options))); } void TearDown() override { @@ -147,6 +169,7 @@ class DocumentStoreTest : public ::testing::Test { DocumentProto test_document1_; DocumentProto test_document2_; std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<LanguageSegmenter> lang_segmenter_; // Document1 values const int document1_score_ = 1; @@ -1184,9 +1207,10 @@ TEST_F(DocumentStoreTest, OptimizeInto) { filesystem_.GetFileSize(original_document_log.c_str()); // Optimizing into the same directory is not allowed - EXPECT_THAT(doc_store->OptimizeInto(document_store_dir_), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT, - HasSubstr("directory is the same"))); + EXPECT_THAT( + doc_store->OptimizeInto(document_store_dir_, lang_segmenter_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT, + HasSubstr("directory is the same"))); std::string optimized_dir = document_store_dir_ + "_optimize"; std::string optimized_document_log = optimized_dir + "/document_log"; @@ -1195,7 +1219,8 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // deleted ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK(doc_store->OptimizeInto(optimized_dir)); + ICING_ASSERT_OK( + doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); int64_t optimized_size1 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_EQ(original_size, optimized_size1); @@ -1205,7 +1230,8 @@ TEST_F(DocumentStoreTest, OptimizeInto) { ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); ICING_ASSERT_OK(doc_store->Delete("namespace", "uri1")); - ICING_ASSERT_OK(doc_store->OptimizeInto(optimized_dir)); + ICING_ASSERT_OK( + doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); int64_t optimized_size2 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(original_size, Gt(optimized_size2)); @@ -1218,7 +1244,8 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // expired ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK(doc_store->OptimizeInto(optimized_dir)); + ICING_ASSERT_OK( + doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); int64_t optimized_size3 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(optimized_size2, Gt(optimized_size3)); @@ -1235,14 +1262,32 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromDataLoss) { std::unique_ptr<DocumentStore> doc_store = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(document_id1, - doc_store->Put(DocumentProto(test_document1_))); - 
ICING_ASSERT_OK_AND_ASSIGN(document_id2, - doc_store->Put(DocumentProto(test_document2_))); + ICING_ASSERT_OK_AND_ASSIGN( + document_id1, + doc_store->Put(DocumentProto(test_document1_), /*num_tokens=*/4)); + ICING_ASSERT_OK_AND_ASSIGN( + document_id2, + doc_store->Put(DocumentProto(test_document2_), /*num_tokens=*/4)); EXPECT_THAT(doc_store->Get(document_id1), IsOkAndHolds(EqualsProto(test_document1_))); EXPECT_THAT(doc_store->Get(document_id2), IsOkAndHolds(EqualsProto(test_document2_))); + // Checks derived score cache + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id1), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document1_score_, document1_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/2, /*sum_length_in_tokens=*/8))); + + // Delete document 1 EXPECT_THAT(doc_store->Delete("icing", "email/1"), IsOk()); EXPECT_THAT(doc_store->Get(document_id1), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); @@ -1281,9 +1326,14 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromDataLoss) { /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); // Checks derived score cache - EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id2), - IsOkAndHolds(DocumentAssociatedScoreData( - document2_score_, document2_creation_timestamp_))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/1, /*sum_length_in_tokens=*/4))); } TEST_F(DocumentStoreTest, ShouldRecoverFromCorruptDerivedFile) { @@ -1297,14 +1347,31 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromCorruptDerivedFile) { std::unique_ptr<DocumentStore> doc_store = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(document_id1, - doc_store->Put(DocumentProto(test_document1_))); - ICING_ASSERT_OK_AND_ASSIGN(document_id2, - doc_store->Put(DocumentProto(test_document2_))); + ICING_ASSERT_OK_AND_ASSIGN( + document_id1, + doc_store->Put(DocumentProto(test_document1_), /*num_tokens=*/4)); + ICING_ASSERT_OK_AND_ASSIGN( + document_id2, + doc_store->Put(DocumentProto(test_document2_), /*num_tokens=*/4)); EXPECT_THAT(doc_store->Get(document_id1), IsOkAndHolds(EqualsProto(test_document1_))); EXPECT_THAT(doc_store->Get(document_id2), IsOkAndHolds(EqualsProto(test_document2_))); + // Checks derived score cache + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id1), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document1_score_, document1_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/2, /*sum_length_in_tokens=*/8))); + // Delete document 1 EXPECT_THAT(doc_store->Delete("icing", 
"email/1"), IsOk()); EXPECT_THAT(doc_store->Get(document_id1), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); @@ -1328,6 +1395,7 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromCorruptDerivedFile) { IsOk()); // Successfully recover from a corrupt derived file issue. + // NOTE: this doesn't trigger RegenerateDerivedFiles. ICING_ASSERT_OK_AND_ASSIGN( DocumentStore::CreateResult create_result, DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, @@ -1345,10 +1413,16 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromCorruptDerivedFile) { IsOkAndHolds(DocumentFilterData( /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); - // Checks derived score cache - EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id2), - IsOkAndHolds(DocumentAssociatedScoreData( - document2_score_, document2_creation_timestamp_))); + // Checks derived score cache - note that they aren't regenerated from + // scratch. + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/2, /*sum_length_in_tokens=*/8))); } TEST_F(DocumentStoreTest, ShouldRecoverFromBadChecksum) { @@ -1362,14 +1436,30 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromBadChecksum) { std::unique_ptr<DocumentStore> doc_store = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(document_id1, - doc_store->Put(DocumentProto(test_document1_))); - ICING_ASSERT_OK_AND_ASSIGN(document_id2, - doc_store->Put(DocumentProto(test_document2_))); + ICING_ASSERT_OK_AND_ASSIGN( + document_id1, + doc_store->Put(DocumentProto(test_document1_), /*num_tokens=*/4)); + ICING_ASSERT_OK_AND_ASSIGN( + document_id2, + doc_store->Put(DocumentProto(test_document2_), /*num_tokens=*/4)); EXPECT_THAT(doc_store->Get(document_id1), IsOkAndHolds(EqualsProto(test_document1_))); EXPECT_THAT(doc_store->Get(document_id2), IsOkAndHolds(EqualsProto(test_document2_))); + // Checks derived score cache + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id1), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document1_score_, document1_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/2, /*sum_length_in_tokens=*/8))); EXPECT_THAT(doc_store->Delete("icing", "email/1"), IsOk()); EXPECT_THAT(doc_store->Get(document_id1), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); @@ -1407,9 +1497,14 @@ TEST_F(DocumentStoreTest, ShouldRecoverFromBadChecksum) { /*namespace_id=*/0, /*schema_type_id=*/0, document2_expiration_timestamp_))); // Checks derived score cache - EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id2), - IsOkAndHolds(DocumentAssociatedScoreData( - document2_score_, document2_creation_timestamp_))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/4))); + 
EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/1, /*sum_length_in_tokens=*/4))); } TEST_F(DocumentStoreTest, GetDiskUsage) { @@ -1544,28 +1639,6 @@ TEST_F(DocumentStoreTest, NonexistentNamespaceNotFound) { StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } -TEST_F(DocumentStoreTest, GetCorpusIdReturnsNotFoundWhenFeatureIsDisabled) { - setEnableBm25f(false); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, - schema_store_.get())); - std::unique_ptr<DocumentStore> doc_store = - std::move(create_result.document_store); - - DocumentProto document1 = - DocumentBuilder().SetKey("namespace", "1").SetSchema("email").Build(); - DocumentProto document2 = - DocumentBuilder().SetKey("namespace", "2").SetSchema("email").Build(); - - ICING_ASSERT_OK(doc_store->Put(document1)); - ICING_ASSERT_OK(doc_store->Put(document2)); - - EXPECT_THAT(doc_store->GetCorpusId("namespace", "email"), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND, - HasSubstr("corpus_mapper disabled"))); -} - TEST_F(DocumentStoreTest, GetCorpusDuplicateCorpusId) { ICING_ASSERT_OK_AND_ASSIGN( DocumentStore::CreateResult create_result, @@ -1582,7 +1655,7 @@ TEST_F(DocumentStoreTest, GetCorpusDuplicateCorpusId) { ICING_ASSERT_OK(doc_store->Put(document1)); ICING_ASSERT_OK(doc_store->Put(document2)); - // NamespaceId of 0 since it was the first namespace seen by the DocumentStore + // CorpusId of 0 since it was the first namespace seen by the DocumentStore EXPECT_THAT(doc_store->GetCorpusId("namespace", "email"), IsOkAndHolds(Eq(0))); } @@ -1642,6 +1715,183 @@ TEST_F(DocumentStoreTest, NonexistentCorpusNotFound) { StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); EXPECT_THAT(doc_store->GetCorpusId("namespace1", "nonexistent_schema"), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + +TEST_F(DocumentStoreTest, GetCorpusAssociatedScoreDataSameCorpus) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + DocumentProto document1 = + DocumentBuilder().SetKey("namespace", "1").SetSchema("email").Build(); + DocumentProto document2 = + DocumentBuilder().SetKey("namespace", "2").SetSchema("email").Build(); + + ICING_ASSERT_OK(doc_store->Put(document1, /*num_tokens=*/5)); + ICING_ASSERT_OK(doc_store->Put(document2, /*num_tokens=*/7)); + + // CorpusId of 0 since it was the first namespace seen by the DocumentStore + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/2, /*sum_length_in_tokens=*/12))); + // Only one corpus exists + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + +TEST_F(DocumentStoreTest, GetCorpusAssociatedScoreData) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + DocumentProto document_corpus1 = + 
DocumentBuilder().SetKey("namespace1", "1").SetSchema("email").Build(); + DocumentProto document_corpus2 = + DocumentBuilder().SetKey("namespace2", "2").SetSchema("email").Build(); + + ICING_ASSERT_OK( + doc_store->Put(DocumentProto(document_corpus1), /*num_tokens=*/5)); + ICING_ASSERT_OK( + doc_store->Put(DocumentProto(document_corpus2), /*num_tokens=*/7)); + + // CorpusId of 0 since it was the first corpus seen by the DocumentStore + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/1, /*sum_length_in_tokens=*/5))); + + // CorpusId of 1 since it was the second corpus seen by the + // DocumentStore + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/1), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/1, /*sum_length_in_tokens=*/7))); + + // DELETE namespace1 - document_corpus1 is deleted. + ICING_EXPECT_OK(doc_store->DeleteByNamespace("namespace1").status); + + // Corpus score cache doesn't care if the document has been deleted + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + IsOkAndHolds(CorpusAssociatedScoreData( + /*num_docs=*/1, /*sum_length_in_tokens=*/5))); +} + +TEST_F(DocumentStoreTest, NonexistentCorpusAssociatedScoreDataOutOfRange) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + EXPECT_THAT(doc_store->GetCorpusAssociatedScoreData(/*corpus_id=*/0), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + +TEST_F(DocumentStoreTest, GetDocumentAssociatedScoreDataSameCorpus) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + DocumentProto document1 = + DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("email") + .SetScore(document1_score_) + .SetCreationTimestampMs( + document1_creation_timestamp_) // A random timestamp + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("namespace", "2") + .SetSchema("email") + .SetScore(document2_score_) + .SetCreationTimestampMs( + document2_creation_timestamp_) // A random timestamp + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + doc_store->Put(DocumentProto(document1), /*num_tokens=*/5)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + doc_store->Put(DocumentProto(document2), /*num_tokens=*/7)); + + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id1), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document1_score_, document1_creation_timestamp_, + /*length_in_tokens=*/5))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/7))); +} + +TEST_F(DocumentStoreTest, GetCorpusAssociatedScoreDataDifferentCorpus) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + DocumentProto document1 = + DocumentBuilder() + .SetKey("namespace1", 
"1") + .SetSchema("email") + .SetScore(document1_score_) + .SetCreationTimestampMs( + document1_creation_timestamp_) // A random timestamp + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("namespace2", "2") + .SetSchema("email") + .SetScore(document2_score_) + .SetCreationTimestampMs( + document2_creation_timestamp_) // A random timestamp + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + doc_store->Put(DocumentProto(document1), /*num_tokens=*/5)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + doc_store->Put(DocumentProto(document2), /*num_tokens=*/7)); + + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id1), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, document1_score_, document1_creation_timestamp_, + /*length_in_tokens=*/5))); + EXPECT_THAT( + doc_store->GetDocumentAssociatedScoreData(document_id2), + IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/1, document2_score_, document2_creation_timestamp_, + /*length_in_tokens=*/7))); +} + +TEST_F(DocumentStoreTest, NonexistentDocumentAssociatedScoreDataOutOfRange) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(/*document_id=*/0), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); } TEST_F(DocumentStoreTest, SoftDeletionDoesNotClearFilterCache) { @@ -1700,12 +1950,13 @@ TEST_F(DocumentStoreTest, SoftDeletionDoesNotClearScoreCache) { std::move(create_result.document_store); ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, - doc_store->Put(test_document1_)); + doc_store->Put(test_document1_, /*num_tokens=*/4)); EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id), IsOkAndHolds(DocumentAssociatedScoreData( - /*document_score=*/document1_score_, - /*creation_timestamp_ms=*/document1_creation_timestamp_))); + /*corpus_id=*/0, /*document_score=*/document1_score_, + /*creation_timestamp_ms=*/document1_creation_timestamp_, + /*length_in_tokens=*/4))); ICING_ASSERT_OK(doc_store->Delete("icing", "email/1", /*soft_delete=*/true)); // Associated entry of the deleted document is removed. @@ -1722,12 +1973,14 @@ TEST_F(DocumentStoreTest, HardDeleteClearsScoreCache) { std::move(create_result.document_store); ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, - doc_store->Put(test_document1_)); + doc_store->Put(test_document1_, /*num_tokens=*/4)); EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id), IsOkAndHolds(DocumentAssociatedScoreData( + /*corpus_id=*/0, /*document_score=*/document1_score_, - /*creation_timestamp_ms=*/document1_creation_timestamp_))); + /*creation_timestamp_ms=*/document1_creation_timestamp_, + /*length_in_tokens=*/4))); ICING_ASSERT_OK(doc_store->Delete("icing", "email/1", /*soft_delete=*/false)); // Associated entry of the deleted document is removed. 
@@ -1931,11 +2184,15 @@ TEST_F(DocumentStoreTest, ShouldWriteAndReadScoresCorrectly) { EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id1), IsOkAndHolds(DocumentAssociatedScoreData( - /*document_score=*/0, /*creation_timestamp_ms=*/0))); + /*corpus_id=*/0, + /*document_score=*/0, /*creation_timestamp_ms=*/0, + /*length_in_tokens=*/0))); EXPECT_THAT(doc_store->GetDocumentAssociatedScoreData(document_id2), IsOkAndHolds(DocumentAssociatedScoreData( - /*document_score=*/5, /*creation_timestamp_ms=*/0))); + /*corpus_id=*/0, + /*document_score=*/5, /*creation_timestamp_ms=*/0, + /*length_in_tokens=*/0))); } TEST_F(DocumentStoreTest, ComputeChecksumSameBetweenCalls) { @@ -2636,7 +2893,8 @@ TEST_F(DocumentStoreTest, GetOptimizeInfo) { std::string optimized_dir = document_store_dir_ + "_optimize"; EXPECT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); EXPECT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK(document_store->OptimizeInto(optimized_dir)); + ICING_ASSERT_OK( + document_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); document_store.reset(); ICING_ASSERT_OK_AND_ASSIGN( create_result, DocumentStore::Create(&filesystem_, optimized_dir, @@ -3046,7 +3304,8 @@ TEST_F(DocumentStoreTest, UsageScoresShouldPersistOnOptimize) { // Run optimize std::string optimized_dir = document_store_dir_ + "/optimize_test"; filesystem_.CreateDirectoryRecursively(optimized_dir.c_str()); - ICING_ASSERT_OK(document_store->OptimizeInto(optimized_dir)); + ICING_ASSERT_OK( + document_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); // Get optimized document store ICING_ASSERT_OK_AND_ASSIGN( @@ -3149,9 +3408,9 @@ TEST_F(DocumentStoreTest, LoadScoreCacheAndInitializeSuccessfully) { // the current code is compatible with the format of the v0 scoring_cache, // then an empty document store should be initialized, but the non-empty // scoring_cache should be retained. - // Since the current document-asscoiated-score-data is compatible with the - // score_cache in testdata/v0/document_store, the document store should be - // initialized without having to re-generate the derived files. + // The current document-associated-score-data has a new field with respect to + // the ones stored in testdata/v0, hence the document store's initialization + // requires regenerating its derived files. // Create dst directory ASSERT_THAT(filesystem_.CreateDirectory(document_store_dir_.c_str()), true); @@ -3186,9 +3445,10 @@ TEST_F(DocumentStoreTest, LoadScoreCacheAndInitializeSuccessfully) { schema_store_.get(), &initializeStats)); std::unique_ptr<DocumentStore> doc_store = std::move(create_result.document_store); - // Regeneration never happens. - EXPECT_EQ(initializeStats.document_store_recovery_cause(), - NativeInitializeStats::NONE); + // The score_cache triggers regeneration because its element size is + // inconsistent: expected 20 (current new size), actual 12 (as per the v0 + // score_cache).
+ EXPECT_TRUE(initializeStats.has_document_store_recovery_cause()); } } // namespace } // namespace lib diff --git a/icing/store/usage-store.cc b/icing/store/usage-store.cc index 7a0af9c..54896dc 100644 --- a/icing/store/usage-store.cc +++ b/icing/store/usage-store.cc @@ -214,6 +214,10 @@ libtextclassifier3::StatusOr<Crc32> UsageStore::ComputeChecksum() { return usage_score_cache_->ComputeChecksum(); } +libtextclassifier3::StatusOr<int64_t> UsageStore::GetElementsFileSize() const { + return usage_score_cache_->GetElementsFileSize(); +} + libtextclassifier3::Status UsageStore::TruncateTo(DocumentId num_documents) { if (num_documents >= usage_score_cache_->num_elements()) { // No need to truncate diff --git a/icing/store/usage-store.h b/icing/store/usage-store.h index 0a622a0..b7de970 100644 --- a/icing/store/usage-store.h +++ b/icing/store/usage-store.h @@ -148,6 +148,15 @@ class UsageStore { // INTERNAL_ERROR if the internal state is inconsistent libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); + // Returns the file size of all the elements held in the UsageStore. File + // size is in bytes. This excludes the size of any internal metadata, e.g. any + // internal headers. + // + // Returns: + // File size on success + // INTERNAL_ERROR on IO error + libtextclassifier3::StatusOr<int64_t> GetElementsFileSize() const; + // Resizes the storage so that only the usage scores of and before // last_document_id are stored. // diff --git a/icing/store/usage-store_test.cc b/icing/store/usage-store_test.cc index f7fa778..220c226 100644 --- a/icing/store/usage-store_test.cc +++ b/icing/store/usage-store_test.cc @@ -24,6 +24,7 @@ namespace lib { namespace { using ::testing::Eq; +using ::testing::Gt; using ::testing::Not; class UsageStoreTest : public testing::Test { @@ -560,6 +561,22 @@ TEST_F(UsageStoreTest, StoreShouldBeResetOnHeaderChecksumMismatch) { IsOkAndHolds(UsageStore::UsageScores())); } +TEST_F(UsageStoreTest, GetElementsFileSize) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<UsageStore> usage_store, + UsageStore::Create(&filesystem_, test_dir_)); + + ICING_ASSERT_OK_AND_ASSIGN(int64_t empty_file_size, + usage_store->GetElementsFileSize()); + EXPECT_THAT(empty_file_size, Eq(0)); + + UsageReport usage_report = CreateUsageReport( + "namespace", "uri", /*timestamp_ms=*/1000, UsageReport::USAGE_TYPE1); + usage_store->AddUsageReport(usage_report, /*document_id=*/1); + + EXPECT_THAT(usage_store->GetElementsFileSize(), + IsOkAndHolds(Gt(empty_file_size))); +} + } // namespace } // namespace lib diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index a15e64e..b7f54ba 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -15,6 +15,8 @@ #ifndef ICING_TESTING_COMMON_MATCHERS_H_ #define ICING_TESTING_COMMON_MATCHERS_H_ +#include <cmath> + #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/status_macros.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -103,7 +105,7 @@ MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") { if (arg.document_id() != expected_scored_document_hit.document_id() || arg.hit_section_id_mask() != expected_scored_document_hit.hit_section_id_mask() || - arg.score() != expected_scored_document_hit.score()) { + std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) { *result_listener << IcingStringUtil::StringPrintf( "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f.
Actual: " "document_id=%d, hit_section_id_mask=%d, score=%.2f", diff --git a/icing/util/document-validator.cc b/icing/util/document-validator.cc index fb1fc4b..8d6d51a 100644 --- a/icing/util/document-validator.cc +++ b/icing/util/document-validator.cc @@ -32,12 +32,13 @@ DocumentValidator::DocumentValidator(const SchemaStore* schema_store) : schema_store_(schema_store) {} libtextclassifier3::Status DocumentValidator::Validate( - const DocumentProto& document) { + const DocumentProto& document, int depth) { if (document.namespace_().empty()) { return absl_ports::InvalidArgumentError("Field 'namespace' is empty."); } - if (document.uri().empty()) { + // Only require a non-empty uri on top-level documents. + if (depth == 0 && document.uri().empty()) { return absl_ports::InvalidArgumentError("Field 'uri' is empty."); } @@ -160,7 +161,7 @@ libtextclassifier3::Status DocumentValidator::Validate( nested_document.schema(), "' for key: (", document.namespace_(), ", ", document.uri(), ").")); } - ICING_RETURN_IF_ERROR(Validate(nested_document)); + ICING_RETURN_IF_ERROR(Validate(nested_document, depth + 1)); } } } diff --git a/icing/util/document-validator.h b/icing/util/document-validator.h index 036d1fa..8542283 100644 --- a/icing/util/document-validator.h +++ b/icing/util/document-validator.h @@ -32,7 +32,8 @@ class DocumentValidator { // This function validates: // 1. DocumentProto.namespace is not empty - // 2. DocumentProto.uri is not empty + // 2. DocumentProto.uri is not empty in top-level documents. Nested documents + // may have empty uris. // 3. DocumentProto.schema is not empty // 4. DocumentProto.schema matches one of SchemaTypeConfigProto.schema_type // in the given SchemaProto in constructor @@ -56,6 +57,9 @@ class DocumentValidator { // In addition, all nested DocumentProto will also be validated towards the // requirements above. // + // 'depth' indicates what nesting level the document may be at. A top-level + // document has a nesting depth of 0. 
+ // // Returns: // OK on success // FAILED_PRECONDITION if no schema is set yet @@ -63,7 +67,8 @@ class DocumentValidator { // NOT_FOUND if case 4 or 7 fails // ALREADY_EXISTS if case 6 fails // INTERNAL on any I/O error - libtextclassifier3::Status Validate(const DocumentProto& document); + libtextclassifier3::Status Validate(const DocumentProto& document, + int depth = 0); void UpdateSchemaStore(const SchemaStore* schema_store) { schema_store_ = schema_store; diff --git a/icing/util/document-validator_test.cc b/icing/util/document-validator_test.cc index ad5a93e..f05e8a6 100644 --- a/icing/util/document-validator_test.cc +++ b/icing/util/document-validator_test.cc @@ -141,13 +141,27 @@ TEST_F(DocumentValidatorTest, ValidateEmptyNamespaceInvalid) { HasSubstr("'namespace' is empty"))); } -TEST_F(DocumentValidatorTest, ValidateEmptyUriInvalid) { +TEST_F(DocumentValidatorTest, ValidateTopLevelEmptyUriInvalid) { DocumentProto email = SimpleEmailBuilder().SetUri("").Build(); EXPECT_THAT(document_validator_->Validate(email), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT, HasSubstr("'uri' is empty"))); } +TEST_F(DocumentValidatorTest, ValidateNestedEmptyUriValid) { + DocumentProto conversation = + SimpleConversationBuilder() + .ClearProperties() + .AddStringProperty(kPropertyName, kDefaultString) + .AddDocumentProperty(kPropertyEmails, + SimpleEmailBuilder() + .SetUri("") // Empty nested uri + .Build()) + .Build(); + + EXPECT_THAT(document_validator_->Validate(conversation), IsOk()); +} + TEST_F(DocumentValidatorTest, ValidateEmptySchemaInvalid) { DocumentProto email = SimpleEmailBuilder().SetSchema("").Build(); EXPECT_THAT(document_validator_->Validate(email), diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc new file mode 100644 index 0000000..02ee459 --- /dev/null +++ b/icing/util/tokenized-document.cc @@ -0,0 +1,74 @@ +// Copyright (C) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/util/tokenized-document.h" + +#include <string> +#include <string_view> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/proto/document.proto.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/tokenization/tokenizer-factory.h" +#include "icing/tokenization/tokenizer.h" +#include "icing/util/document-validator.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<TokenizedDocument> TokenizedDocument::Create( + const SchemaStore* schema_store, + const LanguageSegmenter* language_segmenter, DocumentProto document) { + TokenizedDocument tokenized_document(std::move(document)); + ICING_RETURN_IF_ERROR( + tokenized_document.Tokenize(schema_store, language_segmenter)); + return tokenized_document; +} + +TokenizedDocument::TokenizedDocument(DocumentProto document) + : document_(std::move(document)) {} + +libtextclassifier3::Status TokenizedDocument::Tokenize( + const SchemaStore* schema_store, + const LanguageSegmenter* language_segmenter) { + DocumentValidator validator(schema_store); + ICING_RETURN_IF_ERROR(validator.Validate(document_)); + + ICING_ASSIGN_OR_RETURN(std::vector<Section> sections, + schema_store->ExtractSections(document_)); + for (const Section& section : sections) { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> tokenizer, + tokenizer_factory::CreateIndexingTokenizer( + section.metadata.tokenizer, language_segmenter)); + std::vector<std::string_view> token_sequence; + for (std::string_view subcontent : section.content) { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> itr, + tokenizer->Tokenize(subcontent)); + while (itr->Advance()) { + token_sequence.push_back(itr->GetToken().text); + } + } + tokenized_sections_.emplace_back(SectionMetadata(section.metadata), + std::move(token_sequence)); + } + + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h new file mode 100644 index 0000000..5283195 --- /dev/null +++ b/icing/util/tokenized-document.h @@ -0,0 +1,76 @@ +// Copyright (C) 2020 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
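A usage sketch of the new class (the helper below is hypothetical; the callers added in this change are IcingSearchEngine::Put and DocumentStore::OptimizeInto): Create() validates the document, segments every indexable section, and exposes the total token count that DocumentStore::Put() now persists. The token text is held as string_views into the DocumentProto owned by the TokenizedDocument, so the object must stay alive while sections() is in use.

// Requires "icing/util/tokenized-document.h" and "icing/util/status-macros.h".
libtextclassifier3::StatusOr<int32_t> CountDocumentTokens(
    const SchemaStore* schema_store,
    const LanguageSegmenter* language_segmenter, DocumentProto document) {
  ICING_ASSIGN_OR_RETURN(
      TokenizedDocument tokenized_document,
      TokenizedDocument::Create(schema_store, language_segmenter,
                                std::move(document)));
  // num_tokens() sums token_sequence.size() over all tokenized sections.
  return tokenized_document.num_tokens();
}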
diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h
new file mode 100644
index 0000000..5283195
--- /dev/null
+++ b/icing/util/tokenized-document.h
@@ -0,0 +1,76 @@
+// Copyright (C) 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_STORE_TOKENIZED_DOCUMENT_H_
+#define ICING_STORE_TOKENIZED_DOCUMENT_H_
+
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/proto/document.proto.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/tokenization/language-segmenter.h"
+
+namespace icing {
+namespace lib {
+
+struct TokenizedSection {
+  SectionMetadata metadata;
+  std::vector<std::string_view> token_sequence;
+
+  TokenizedSection(SectionMetadata&& metadata_in,
+                   std::vector<std::string_view>&& token_sequence_in)
+      : metadata(std::move(metadata_in)),
+        token_sequence(std::move(token_sequence_in)) {}
+};
+
+class TokenizedDocument {
+ public:
+  static libtextclassifier3::StatusOr<TokenizedDocument> Create(
+      const SchemaStore* schema_store,
+      const LanguageSegmenter* language_segmenter, DocumentProto document);
+
+  const DocumentProto& document() const { return document_; }
+
+  int32_t num_tokens() const {
+    int32_t num_tokens = 0;
+    for (const TokenizedSection& section : tokenized_sections_) {
+      num_tokens += section.token_sequence.size();
+    }
+    return num_tokens;
+  }
+
+  const std::vector<TokenizedSection>& sections() const {
+    return tokenized_sections_;
+  }
+
+ private:
+  // Use TokenizedDocument::Create() to instantiate.
+  explicit TokenizedDocument(DocumentProto document);
+
+  DocumentProto document_;
+  std::vector<TokenizedSection> tokenized_sections_;
+
+  libtextclassifier3::Status Tokenize(
+      const SchemaStore* schema_store,
+      const LanguageSegmenter* language_segmenter);
+};
+
+}  // namespace lib
+}  // namespace icing
+
+#endif  // ICING_STORE_TOKENIZED_DOCUMENT_H_
diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java
index 22c607c..88d0578 100644
--- a/java/src/com/google/android/icing/IcingSearchEngine.java
+++ b/java/src/com/google/android/icing/IcingSearchEngine.java
@@ -24,6 +24,7 @@ import com.google.android.icing.proto.DocumentProto;
 import com.google.android.icing.proto.GetAllNamespacesResultProto;
 import com.google.android.icing.proto.GetOptimizeInfoResultProto;
 import com.google.android.icing.proto.GetResultProto;
+import com.google.android.icing.proto.GetResultSpecProto;
 import com.google.android.icing.proto.GetSchemaResultProto;
 import com.google.android.icing.proto.GetSchemaTypeResultProto;
 import com.google.android.icing.proto.IcingSearchEngineOptions;
@@ -41,8 +42,8 @@ import com.google.android.icing.proto.SearchSpecProto;
 import com.google.android.icing.proto.SetSchemaResultProto;
 import com.google.android.icing.proto.StatusProto;
 import com.google.android.icing.proto.UsageReport;
-import com.google.android.icing.protobuf.ExtensionRegistryLite;
-import com.google.android.icing.protobuf.InvalidProtocolBufferException;
+import com.google.protobuf.ExtensionRegistryLite;
+import com.google.protobuf.InvalidProtocolBufferException;
 import java.io.Closeable;
 
 /**
@@ -84,7 +85,9 @@ public final class IcingSearchEngine implements Closeable {
 
   @Override
   public void close() {
-    throwIfClosed();
+    if (closed) {
+      return;
+    }
 
     if (nativePointer != 0) {
       nativeDestroy(this);
@@ -95,8 +98,8 @@
 
   @Override
   protected void finalize() throws Throwable {
-    super.finalize();
     close();
+    super.finalize();
   }
 
   @NonNull
@@ -217,10 +220,11 @@
   }
 
   @NonNull
-  public GetResultProto get(@NonNull String namespace, @NonNull String uri) {
+  public GetResultProto get(
+      @NonNull String namespace, @NonNull String uri, @NonNull GetResultSpecProto getResultSpec) {
     throwIfClosed();
 
-    byte[] getResultBytes = nativeGet(this, namespace, uri);
+    byte[] getResultBytes = nativeGet(this, namespace, uri, getResultSpec.toByteArray());
     if (getResultBytes == null) {
       Log.e(TAG, "Received null GetResultProto from native.");
       return GetResultProto.newBuilder()
@@ -533,7 +537,8 @@ public final class IcingSearchEngine implements Closeable {
 
   private static native byte[] nativePut(IcingSearchEngine instance, byte[] documentBytes);
 
-  private static native byte[] nativeGet(IcingSearchEngine instance, String namespace, String uri);
+  private static native byte[] nativeGet(
+      IcingSearchEngine instance, String namespace, String uri, byte[] getResultSpecBytes);
 
   private static native byte[] nativeReportUsage(
       IcingSearchEngine instance, byte[] usageReportBytes);
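
get() now takes a GetResultSpecProto (defined at the bottom of search.proto in this change), which lets a caller project the returned document down to specific properties per schema type. A hedged sketch of building such a spec in C++; the schema type "Email" and the property paths are illustrative only, and the include path of the generated proto header is an assumption.

#include "icing/proto/search.pb.h"  // assumed path of the generated header

namespace icing {
namespace lib {

// Restrict retrieved "Email" documents to two properties; schema types with
// no mask of their own keep all of their properties.
GetResultSpecProto MakeEmailProjectionSpec() {
  GetResultSpecProto result_spec;
  TypePropertyMask* mask = result_spec.add_type_property_masks();
  mask->set_schema_type("Email");
  mask->add_paths("subject");
  mask->add_paths("recipients.name");
  return result_spec;
}

}  // namespace lib
}  // namespace icing

A mask whose schema_type is the wildcard "*" acts as the fallback for types without their own entry, and passing a default instance (as the updated Java tests below do with GetResultSpecProto.getDefaultInstance()) specifies no masks at all, preserving the old behaviour of returning every property.
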
icingSearchEngine.get("namespace", "uri", GetResultSpecProto.getDefaultInstance()); assertThat(getResultProto.getStatus().getCode()).isEqualTo(StatusProto.Code.NOT_FOUND); } @@ -377,9 +382,11 @@ public final class IcingSearchEngineTest { DeleteByQueryResultProto deleteResultProto = icingSearchEngine.deleteByQuery(searchSpec); assertStatusOk(deleteResultProto.getStatus()); - GetResultProto getResultProto = icingSearchEngine.get("namespace", "uri1"); + GetResultProto getResultProto = + icingSearchEngine.get("namespace", "uri1", GetResultSpecProto.getDefaultInstance()); assertThat(getResultProto.getStatus().getCode()).isEqualTo(StatusProto.Code.NOT_FOUND); - getResultProto = icingSearchEngine.get("namespace", "uri2"); + getResultProto = + icingSearchEngine.get("namespace", "uri2", GetResultSpecProto.getDefaultInstance()); assertStatusOk(getResultProto.getStatus()); } diff --git a/proto/icing/proto/document.proto b/proto/icing/proto/document.proto index ae73917..d55b7e2 100644 --- a/proto/icing/proto/document.proto +++ b/proto/icing/proto/document.proto @@ -24,7 +24,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Defines a unit of data understood by the IcingSearchEngine. -// Next tag: 9 +// Next tag: 10 message DocumentProto { // REQUIRED: Namespace that this Document resides in. // Namespaces can affect read/write permissions. @@ -65,6 +65,15 @@ message DocumentProto { // in terms of space/time efficiency. Both for ttl_ms and timestamp fields optional int64 ttl_ms = 8 [default = 0]; + // Defines document level data that's generated internally by Icing. + message InternalFields { + // The length of the document as a count of tokens (or terms) in all indexed + // text properties. This field is used in the computation of BM25F relevance + // score. + optional int32 length_in_tokens = 1; + } + optional InternalFields internal_fields = 9; + reserved 6; } diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto index 0298f65..4188a8c 100644 --- a/proto/icing/proto/schema.proto +++ b/proto/icing/proto/schema.proto @@ -34,7 +34,7 @@ option objc_class_prefix = "ICNG"; // TODO(cassiewang) Define a sample proto file that can be used by tests and for // documentation. // -// Next tag: 5 +// Next tag: 6 message SchemaTypeConfigProto { // REQUIRED: Named type that uniquely identifies the structured, logical // schema being defined. @@ -51,6 +51,15 @@ message SchemaTypeConfigProto { // easier. repeated PropertyConfigProto properties = 4; + // Version is an arbitrary number that the client may use to keep track of + // different incarnations of the schema. Icing library imposes no requirements + // on this field and will not validate it in anyway. If a client calls + // SetSchema with a schema that contains one or more new version numbers, then + // those version numbers will be updated so long as the SetSchema call + // succeeds. Clients are free to leave the version number unset, in which case + // it will default to value == 0. + optional int32 version = 5; + reserved 2, 3; } diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto index bfa7aec..6186fde 100644 --- a/proto/icing/proto/scoring.proto +++ b/proto/icing/proto/scoring.proto @@ -64,11 +64,8 @@ message ScoringSpecProto { // compared in seconds. USAGE_TYPE3_LAST_USED_TIMESTAMP = 8; - // Placeholder for ranking by relevance score, currently computed as BM25F - // score. - // TODO(b/173156803): one the implementation is ready, rename to - // RELEVANCE_SCORE. 
diff --git a/proto/icing/proto/schema.proto b/proto/icing/proto/schema.proto
index 0298f65..4188a8c 100644
--- a/proto/icing/proto/schema.proto
+++ b/proto/icing/proto/schema.proto
@@ -34,7 +34,7 @@ option objc_class_prefix = "ICNG";
 // TODO(cassiewang) Define a sample proto file that can be used by tests and for
 // documentation.
 //
-// Next tag: 5
+// Next tag: 6
 message SchemaTypeConfigProto {
   // REQUIRED: Named type that uniquely identifies the structured, logical
   // schema being defined.
@@ -51,6 +51,15 @@ message SchemaTypeConfigProto {
   // easier.
   repeated PropertyConfigProto properties = 4;
 
+  // Version is an arbitrary number that the client may use to keep track of
+  // different incarnations of the schema. Icing library imposes no requirements
+  // on this field and will not validate it in any way. If a client calls
+  // SetSchema with a schema that contains one or more new version numbers, then
+  // those version numbers will be updated so long as the SetSchema call
+  // succeeds. Clients are free to leave the version number unset, in which case
+  // it will default to value == 0.
+  optional int32 version = 5;
+
   reserved 2, 3;
 }
 
diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto
index bfa7aec..6186fde 100644
--- a/proto/icing/proto/scoring.proto
+++ b/proto/icing/proto/scoring.proto
@@ -64,11 +64,8 @@ message ScoringSpecProto {
       // compared in seconds.
       USAGE_TYPE3_LAST_USED_TIMESTAMP = 8;
 
-      // Placeholder for ranking by relevance score, currently computed as BM25F
-      // score.
-      // TODO(b/173156803): one the implementation is ready, rename to
-      // RELEVANCE_SCORE.
-      RELEVANCE_SCORE_NONFUNCTIONAL_PLACEHOLDER = 9;
+      // Ranked by relevance score, currently computed as BM25F score.
+      RELEVANCE_SCORE = 9;
     }
   }
   optional RankingStrategy.Code rank_by = 1;
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index f63acfa..6c4e3c9 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -101,20 +101,6 @@ message ResultSpecProto {
   // How to specify a subset of properties to retrieve. If no type property mask
   // has been specified for a schema type, then *all* properties of that schema
   // type will be retrieved.
-  // Next tag: 3
-  message TypePropertyMask {
-    // The schema type to which these property masks should apply.
-    // If the schema type is the wildcard ("*"), then the type property masks
-    // will apply to all results of types that don't have their own, specific
-    // type property mask entry.
-    optional string schema_type = 1;
-
-    // The property masks specifying the property to be retrieved. Property
-    // masks must be composed only of property names, property separators (the
-    // '.' character). For example, "subject", "recipients.name". Specifying no
-    // property masks will result in *no* properties being retrieved.
-    repeated string paths = 2;
-  }
   repeated TypePropertyMask type_property_masks = 4;
 }
 
@@ -214,3 +200,26 @@ message SearchResultProto {
   // Stats for query execution performance.
   optional NativeQueryStats query_stats = 5;
 }
+
+// Next tag: 3
+message TypePropertyMask {
+  // The schema type to which these property masks should apply.
+  // If the schema type is the wildcard ("*"), then the type property masks
+  // will apply to all results of types that don't have their own, specific
+  // type property mask entry.
+  optional string schema_type = 1;
+
+  // The property masks specifying the property to be retrieved. Property
+  // masks must be composed only of property names, property separators (the
+  // '.' character). For example, "subject", "recipients.name". Specifying no
+  // property masks will result in *no* properties being retrieved.
+  repeated string paths = 2;
+}
+
+// Next tag: 2
+message GetResultSpecProto {
+  // How to specify a subset of properties to retrieve. If no type property mask
+  // has been specified for a schema type, then *all* properties of that schema
+  // type will be retrieved.
+  repeated TypePropertyMask type_property_masks = 1;
+}
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index be9e98c..af8248d 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=349594076)
+set(synced_AOSP_CL_number=351841227)
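
With the placeholder renamed, relevance ranking can now be requested directly through ScoringSpecProto. A hedged sketch of a scoring spec asking for it; the include path of the generated header is an assumption, and the spec would accompany the search and result specs passed to the engine's Search call.

#include "icing/proto/scoring.pb.h"  // assumed path of the generated header

namespace icing {
namespace lib {

// Request results ordered by the BM25F-based relevance score. The
// length_in_tokens field added to document.proto above supplies the
// document-length statistic that BM25F uses.
ScoringSpecProto MakeRelevanceScoringSpec() {
  ScoringSpecProto scoring_spec;
  scoring_spec.set_rank_by(
      ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE);
  return scoring_spec;
}

}  // namespace lib
}  // namespace icing
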