-rw-r--r-- icing/file/file-backed-vector.h 19
-rw-r--r-- icing/file/persistent-hash-map.cc 179
-rw-r--r-- icing/file/persistent-hash-map.h 107
-rw-r--r-- icing/file/persistent-hash-map_test.cc 651
-rw-r--r-- icing/file/posting_list/posting-list-accessor.cc 110
-rw-r--r-- icing/file/posting_list/posting-list-accessor.h (renamed from icing/index/main/posting-list-accessor.h) 87
-rw-r--r-- icing/file/posting_list/posting-list-common.h 2
-rw-r--r-- icing/file/posting_list/posting-list-used.h 22
-rw-r--r-- icing/icing-search-engine.cc 235
-rw-r--r-- icing/icing-search-engine.h 42
-rw-r--r-- icing/icing-search-engine_test.cc 419
-rw-r--r-- icing/index/index-processor.cc 124
-rw-r--r-- icing/index/index-processor.h 30
-rw-r--r-- icing/index/index-processor_benchmark.cc 25
-rw-r--r-- icing/index/index-processor_test.cc 189
-rw-r--r-- icing/index/integer-section-indexing-handler.cc 70
-rw-r--r-- icing/index/integer-section-indexing-handler.h 55
-rw-r--r-- icing/index/iterator/doc-hit-info-iterator-test-util.h 1
-rw-r--r-- icing/index/main/doc-hit-info-iterator-term-main.cc 2
-rw-r--r-- icing/index/main/doc-hit-info-iterator-term-main.h 4
-rw-r--r-- icing/index/main/main-index.cc 86
-rw-r--r-- icing/index/main/main-index.h 24
-rw-r--r-- icing/index/main/posting-list-accessor.cc 215
-rw-r--r-- icing/index/main/posting-list-hit-accessor.cc 126
-rw-r--r-- icing/index/main/posting-list-hit-accessor.h 103
-rw-r--r-- icing/index/main/posting-list-hit-accessor_test.cc (renamed from icing/index/main/posting-list-accessor_test.cc) 142
-rw-r--r-- icing/index/main/posting-list-used-hit-serializer.cc 1
-rw-r--r-- icing/index/main/posting-list-used-hit-serializer.h 11
-rw-r--r-- icing/index/main/posting-list-used-hit-serializer_test.cc 6
-rw-r--r-- icing/index/numeric/doc-hit-info-iterator-numeric.h 63
-rw-r--r-- icing/index/numeric/dummy-numeric-index.h 239
-rw-r--r-- icing/index/numeric/integer-index-data.h 59
-rw-r--r-- icing/index/numeric/numeric-index.h 146
-rw-r--r-- icing/index/numeric/numeric-index_test.cc 361
-rw-r--r-- icing/index/numeric/posting-list-integer-index-data-accessor.cc 136
-rw-r--r-- icing/index/numeric/posting-list-integer-index-data-accessor.h 108
-rw-r--r-- icing/index/numeric/posting-list-integer-index-data-accessor_test.cc 410
-rw-r--r-- icing/index/numeric/posting-list-used-integer-index-data-serializer.cc 514
-rw-r--r-- icing/index/numeric/posting-list-used-integer-index-data-serializer.h 338
-rw-r--r-- icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc 523
-rw-r--r-- icing/index/section-indexing-handler.h 60
-rw-r--r-- icing/index/string-section-indexing-handler.cc 146
-rw-r--r-- icing/index/string-section-indexing-handler.h 67
-rw-r--r-- icing/jni/icing-search-engine-jni.cc 267
-rw-r--r-- icing/join/aggregate-scorer.cc 117
-rw-r--r-- icing/join/aggregate-scorer.h 41
-rw-r--r-- icing/join/join-processor.cc 180
-rw-r--r-- icing/join/join-processor.h 57
-rw-r--r-- icing/monkey_test/icing-monkey-test-runner.cc 63
-rw-r--r-- icing/monkey_test/icing-monkey-test-runner.h 36
-rw-r--r-- icing/monkey_test/icing-search-engine_monkey_test.cc 62
-rw-r--r-- icing/monkey_test/monkey-test-generators.cc 15
-rw-r--r-- icing/monkey_test/monkey-test-generators.h 7
-rw-r--r-- icing/portable/platform.h 23
-rw-r--r-- icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h 108
-rw-r--r-- icing/query/advanced_query_parser/abstract-syntax-tree.h 168
-rw-r--r-- icing/query/advanced_query_parser/abstract-syntax-tree_test.cc 141
-rw-r--r-- icing/query/advanced_query_parser/lexer.cc 228
-rw-r--r-- icing/query/advanced_query_parser/lexer.h 153
-rw-r--r-- icing/query/advanced_query_parser/lexer_fuzz_test.cc 37
-rw-r--r-- icing/query/advanced_query_parser/lexer_test.cc 613
-rw-r--r-- icing/query/advanced_query_parser/parser.cc 414
-rw-r--r-- icing/query/advanced_query_parser/parser.h 140
-rw-r--r-- icing/query/advanced_query_parser/parser_integration_test.cc 945
-rw-r--r-- icing/query/advanced_query_parser/parser_test.cc 1043
-rw-r--r-- icing/query/advanced_query_parser/query-visitor.cc 228
-rw-r--r-- icing/query/advanced_query_parser/query-visitor.h 119
-rw-r--r-- icing/query/advanced_query_parser/query-visitor_test.cc 557
-rw-r--r-- icing/query/query-processor.cc 71
-rw-r--r-- icing/query/query-processor.h 20
-rw-r--r-- icing/query/query-processor_benchmark.cc 38
-rw-r--r-- icing/query/query-processor_test.cc 887
-rw-r--r-- icing/result/result-retriever-v2.cc 39
-rw-r--r-- icing/result/result-retriever-v2_group-result-limiter_test.cc 24
-rw-r--r-- icing/result/result-retriever-v2_projection_test.cc 36
-rw-r--r-- icing/result/result-retriever-v2_snippet_test.cc 20
-rw-r--r-- icing/result/result-retriever-v2_test.cc 30
-rw-r--r-- icing/result/result-retriever_test.cc 2
-rw-r--r-- icing/result/result-state-manager_test.cc 147
-rw-r--r-- icing/result/result-state-manager_thread-safety_test.cc 9
-rw-r--r-- icing/result/result-state-v2_test.cc 33
-rw-r--r-- icing/result/snippet-retriever_test.cc 2
-rw-r--r-- icing/scoring/advanced_scoring/advanced-scorer.cc 58
-rw-r--r-- icing/scoring/advanced_scoring/advanced-scorer.h 66
-rw-r--r-- icing/scoring/advanced_scoring/advanced-scorer_test.cc 404
-rw-r--r-- icing/scoring/advanced_scoring/score-expression.cc 203
-rw-r--r-- icing/scoring/advanced_scoring/score-expression.h 154
-rw-r--r-- icing/scoring/advanced_scoring/scoring-visitor.cc 159
-rw-r--r-- icing/scoring/advanced_scoring/scoring-visitor.h 77
-rw-r--r-- icing/scoring/priority-queue-scored-document-hits-ranker.cc 53
-rw-r--r-- icing/scoring/priority-queue-scored-document-hits-ranker.h 69
-rw-r--r-- icing/scoring/priority-queue-scored-document-hits-ranker_test.cc 112
-rw-r--r-- icing/scoring/scored-document-hit.cc 30
-rw-r--r-- icing/scoring/scored-document-hit.h 67
-rw-r--r-- icing/scoring/scored-document-hit_test.cc 77
-rw-r--r-- icing/scoring/scored-document-hits-ranker.h 17
-rw-r--r-- icing/scoring/scorer-factory.cc (renamed from icing/scoring/scorer.cc) 26
-rw-r--r-- icing/scoring/scorer-factory.h 46
-rw-r--r-- icing/scoring/scorer.h 18
-rw-r--r-- icing/scoring/scorer_test.cc 238
-rw-r--r-- icing/scoring/scoring-processor.cc 11
-rw-r--r-- icing/store/document-store.cc 4
-rw-r--r-- icing/store/dynamic-trie-key-mapper.h 7
-rw-r--r-- icing/store/dynamic-trie-key-mapper_test.cc 167
-rw-r--r-- icing/store/key-mapper_benchmark.cc 316
-rw-r--r-- icing/store/key-mapper_test.cc 215
-rw-r--r-- icing/store/persistent-hash-map-key-mapper.h 209
-rw-r--r-- icing/store/persistent-hash-map-key-mapper_test.cc 52
-rw-r--r-- icing/testing/common-matchers.h 88
-rw-r--r-- icing/tokenization/combined-tokenizer_test.cc 84
-rw-r--r-- icing/tokenization/icu/icu-language-segmenter_test.cc 91
-rw-r--r-- icing/tokenization/raw-query-tokenizer_test.cc 26
-rw-r--r-- icing/util/snippet-helpers.cc (renamed from icing/testing/snippet-helpers.cc) 2
-rw-r--r-- icing/util/snippet-helpers.h (renamed from icing/testing/snippet-helpers.h) 0
-rw-r--r-- icing/util/tokenized-document.cc 56
-rw-r--r-- icing/util/tokenized-document.h 33
-rw-r--r-- icing/util/tokenized-document_test.cc 335
-rw-r--r-- java/src/com/google/android/icing/IcingSearchEngine.java 607
-rw-r--r-- java/src/com/google/android/icing/IcingSearchEngineImpl.java 330
-rw-r--r-- java/src/com/google/android/icing/IcingSearchEngineInterface.java 153
-rw-r--r-- java/src/com/google/android/icing/IcingSearchEngineUtils.java 471
-rw-r--r-- proto/icing/proto/scoring.proto 9
-rw-r--r-- proto/icing/proto/search.proto 44
-rw-r--r-- synced_AOSP_CL_number.txt 2
124 files changed, 15644 insertions, 2919 deletions
diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h
index 183c091..1d99e24 100644
--- a/icing/file/file-backed-vector.h
+++ b/icing/file/file-backed-vector.h
@@ -76,6 +76,7 @@
#include "icing/file/filesystem.h"
#include "icing/file/memory-mapped-file.h"
#include "icing/legacy/core/icing-string-util.h"
+#include "icing/portable/platform.h"
#include "icing/util/crc32.h"
#include "icing/util/logging.h"
#include "icing/util/math-util.h"
@@ -147,10 +148,17 @@ class FileBackedVector {
}
};
- // Absolute max file size for FileBackedVector. Note that Android has a
- // (2^31-1)-byte single file size limit, so kMaxFileSize is 2^31-1.
+ // Absolute max file size for FileBackedVector.
+ // - We memory map the whole file, so file size ~= memory size.
+  // - On 32-bit platforms, the virtual memory address space is only 4GB. To
+  //   avoid exhausting it, use a smaller file size limit on 32-bit platforms.
+#ifdef ICING_ARCH_BIT_64
static constexpr int32_t kMaxFileSize =
- std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB;
+ std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB
+#else
+ static constexpr int32_t kMaxFileSize =
+ (1 << 28) + Header::kHeaderSize; // 2^28 + 12 Bytes, ~256 MiB
+#endif
// Size of element type T. The value is same as sizeof(T), while we should
// avoid using sizeof(T) in our codebase to prevent unexpected unsigned
@@ -461,7 +469,8 @@ class FileBackedVector {
// Absolute max # of elements allowed. Since we are using int32_t to store
// num_elements, max value is 2^31-1. Still the actual max # of elements are
- // determined by max_file_size, kElementTypeSize, and Header::kHeaderSize.
+ // determined by max_file_size, kMaxFileSize, kElementTypeSize, and
+ // Header::kHeaderSize.
static constexpr int32_t kMaxNumElements =
std::numeric_limits<int32_t>::max();
@@ -887,7 +896,7 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary(
kElementTypeSize) {
return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf(
"%d elements total size exceed maximum bytes of elements allowed, "
- "%d bytes",
+ "%" PRId64 " bytes",
num_elements, mmapped_file_->max_file_size() - Header::kHeaderSize));
}
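
Note: the guard above relies on ICING_ARCH_BIT_64 from the newly included
icing/portable/platform.h. A minimal sketch of how such a guard could be
derived, assuming nothing about the actual contents of platform.h (which may
detect the architecture differently), is to key off the pointer width:

  #include <cstdint>

  // Hypothetical sketch only, not the real icing/portable/platform.h.
  #if UINTPTR_MAX == 0xFFFFFFFFFFFFFFFFu
  #define ICING_ARCH_BIT_64  // 64-bit: mapping up to 2^31-1 bytes is fine.
  #else
  #define ICING_ARCH_BIT_32  // 32-bit: cap files at ~256 MiB to spare
                             // virtual address space.
  #endif
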
diff --git a/icing/file/persistent-hash-map.cc b/icing/file/persistent-hash-map.cc
index 0c9fd7f..0af5e2f 100644
--- a/icing/file/persistent-hash-map.cc
+++ b/icing/file/persistent-hash-map.cc
@@ -180,13 +180,66 @@ std::string GetKeyValueStorageFilePath(std::string_view base_dir) {
".k");
}
+// Calculates how many buckets we need given num_entries and
+// max_load_factor_percent, rounding the result up to a power of 2.
+//
+// REQUIRES: 0 < num_entries <= Entry::kMaxNumEntries &&
+// max_load_factor_percent > 0
+int32_t CalculateNumBucketsRequired(int32_t num_entries,
+ int32_t max_load_factor_percent) {
+ // Calculate ceil(num_entries * 100 / max_load_factor_percent)
+ int32_t num_entries_100 = num_entries * 100;
+ int32_t num_buckets_required =
+ num_entries_100 / max_load_factor_percent +
+ (num_entries_100 % max_load_factor_percent == 0 ? 0 : 1);
+ if ((num_buckets_required & (num_buckets_required - 1)) != 0) {
+    // Not a power of 2: round up to the next one.
+ return 1 << (32 - __builtin_clz(num_buckets_required));
+ }
+ return num_buckets_required;
+}
+
} // namespace
+bool PersistentHashMap::Options::IsValid() const {
+ if (!(value_type_size > 0 && value_type_size <= kMaxValueTypeSize &&
+ max_num_entries > 0 && max_num_entries <= Entry::kMaxNumEntries &&
+ max_load_factor_percent > 0 && average_kv_byte_size > 0 &&
+ init_num_buckets > 0 && init_num_buckets <= Bucket::kMaxNumBuckets)) {
+ return false;
+ }
+
+ // We've ensured (static_assert) that storing kMaxNumBuckets buckets won't
+ // exceed FileBackedVector::kMaxFileSize, so only need to verify # of buckets
+ // required won't exceed kMaxNumBuckets.
+ if (CalculateNumBucketsRequired(max_num_entries, max_load_factor_percent) >
+ Bucket::kMaxNumBuckets) {
+ return false;
+ }
+
+ // Verify # of key value pairs can fit into kv_storage.
+ if (average_kv_byte_size > kMaxKVTotalByteSize / max_num_entries) {
+ return false;
+ }
+
+  // Verify init_num_buckets is a power of 2. Requiring init_num_buckets to be
+  // 2^n guarantees that num_buckets will eventually grow to be exactly
+  // max_num_buckets, since CalculateNumBucketsRequired rounds up to 2^n.
+ if ((init_num_buckets & (init_num_buckets - 1)) != 0) {
+ return false;
+ }
+
+ return true;
+}
+
/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
PersistentHashMap::Create(const Filesystem& filesystem,
- std::string_view base_dir, int32_t value_type_size,
- int32_t max_load_factor_percent,
- int32_t init_num_buckets) {
+ std::string_view base_dir, const Options& options) {
+ if (!options.IsValid()) {
+ return absl_ports::InvalidArgumentError(
+ "Invalid PersistentHashMap options");
+ }
+
if (!filesystem.FileExists(
GetMetadataFilePath(base_dir, kSubDirectory).c_str()) ||
!filesystem.FileExists(
@@ -195,11 +248,9 @@ PersistentHashMap::Create(const Filesystem& filesystem,
GetEntryStorageFilePath(base_dir, kSubDirectory).c_str()) ||
!filesystem.FileExists(GetKeyValueStorageFilePath(base_dir).c_str())) {
// TODO: erase all files if missing any.
- return InitializeNewFiles(filesystem, base_dir, value_type_size,
- max_load_factor_percent, init_num_buckets);
+ return InitializeNewFiles(filesystem, base_dir, options);
}
- return InitializeExistingFiles(filesystem, base_dir, value_type_size,
- max_load_factor_percent);
+ return InitializeExistingFiles(filesystem, base_dir, options);
}
PersistentHashMap::~PersistentHashMap() {
@@ -398,9 +449,7 @@ libtextclassifier3::StatusOr<Crc32> PersistentHashMap::ComputeChecksum() {
/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem,
std::string_view base_dir,
- int32_t value_type_size,
- int32_t max_load_factor_percent,
- int32_t init_num_buckets) {
+ const Options& options) {
// Create directory.
const std::string dir_path = absl_ports::StrCat(base_dir, "/", kSubDirectory);
if (!filesystem.CreateDirectoryRecursively(dir_path.c_str())) {
@@ -408,32 +457,54 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem,
absl_ports::StrCat("Failed to create directory: ", dir_path));
}
- // Initialize 3 storages
+ int32_t max_num_buckets_required =
+ std::max(options.init_num_buckets,
+ CalculateNumBucketsRequired(options.max_num_entries,
+ options.max_load_factor_percent));
+
+ // Initialize bucket_storage
+ int32_t pre_mapping_mmap_size = sizeof(Bucket) * max_num_buckets_required;
+ int32_t max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<Bucket>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<FileBackedVector<Bucket>> bucket_storage,
FileBackedVector<Bucket>::Create(
filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size,
+ pre_mapping_mmap_size));
+
+ // Initialize entry_storage
+ pre_mapping_mmap_size = sizeof(Entry) * options.max_num_entries;
+ max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<Entry>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<FileBackedVector<Entry>> entry_storage,
FileBackedVector<Entry>::Create(
filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size,
+ pre_mapping_mmap_size));
+
+ // Initialize kv_storage
+ pre_mapping_mmap_size =
+ options.average_kv_byte_size * options.max_num_entries;
+ max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<char>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage,
FileBackedVector<char>::Create(
filesystem, GetKeyValueStorageFilePath(base_dir),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
+ max_file_size, pre_mapping_mmap_size));
// Initialize buckets.
- ICING_RETURN_IF_ERROR(
- bucket_storage->Set(/*idx=*/0, /*len=*/init_num_buckets, Bucket()));
+ ICING_RETURN_IF_ERROR(bucket_storage->Set(
+ /*idx=*/0, /*len=*/options.init_num_buckets, Bucket()));
ICING_RETURN_IF_ERROR(bucket_storage->PersistToDisk());
// Create and initialize new info
Info new_info;
new_info.version = kVersion;
- new_info.value_type_size = value_type_size;
- new_info.max_load_factor_percent = max_load_factor_percent;
+ new_info.value_type_size = options.value_type_size;
+ new_info.max_load_factor_percent = options.max_load_factor_percent;
new_info.num_deleted_entries = 0;
new_info.num_deleted_key_value_bytes = 0;
@@ -458,7 +529,7 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem,
/*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info)));
return std::unique_ptr<PersistentHashMap>(new PersistentHashMap(
- filesystem, base_dir, std::move(metadata_mmapped_file),
+ filesystem, base_dir, options, std::move(metadata_mmapped_file),
std::move(bucket_storage), std::move(entry_storage),
std::move(kv_storage)));
}
@@ -466,8 +537,7 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem,
/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem,
std::string_view base_dir,
- int32_t value_type_size,
- int32_t max_load_factor_percent) {
+ const Options& options) {
// Mmap the content of the crcs and info.
ICING_ASSIGN_OR_RETURN(
MemoryMappedFile metadata_mmapped_file,
@@ -477,21 +547,41 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem,
ICING_RETURN_IF_ERROR(metadata_mmapped_file.Remap(
/*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info)));
- // Initialize 3 storages
+ int32_t max_num_buckets_required = CalculateNumBucketsRequired(
+ options.max_num_entries, options.max_load_factor_percent);
+
+ // Initialize bucket_storage
+ int32_t pre_mapping_mmap_size = sizeof(Bucket) * max_num_buckets_required;
+ int32_t max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<Bucket>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<FileBackedVector<Bucket>> bucket_storage,
FileBackedVector<Bucket>::Create(
filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size,
+ pre_mapping_mmap_size));
+
+ // Initialize entry_storage
+ pre_mapping_mmap_size = sizeof(Entry) * options.max_num_entries;
+ max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<Entry>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<FileBackedVector<Entry>> entry_storage,
FileBackedVector<Entry>::Create(
filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size,
+ pre_mapping_mmap_size));
+
+ // Initialize kv_storage
+ pre_mapping_mmap_size =
+ options.average_kv_byte_size * options.max_num_entries;
+ max_file_size =
+ pre_mapping_mmap_size + FileBackedVector<char>::Header::kHeaderSize;
ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage,
FileBackedVector<char>::Create(
filesystem, GetKeyValueStorageFilePath(base_dir),
- MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
+ MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC,
+ max_file_size, pre_mapping_mmap_size));
Crcs* crcs_ptr = reinterpret_cast<Crcs*>(
metadata_mmapped_file.mutable_region() + Crcs::kFileOffset);
@@ -499,22 +589,38 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem,
metadata_mmapped_file.mutable_region() + Info::kFileOffset);
// Value type size should be consistent.
- if (value_type_size != info_ptr->value_type_size) {
+ if (options.value_type_size != info_ptr->value_type_size) {
return absl_ports::FailedPreconditionError("Incorrect value type size");
}
+  // The current # of entries should not exceed options.max_num_entries.
+  // We compute max_file_size of all 3 storages from options.max_num_entries.
+  // Since the space of deleted entries (and their key-value bytes) is never
+  // recycled, deleted entries still occupy space in the storages. Even if the
+  // # of "active" entries doesn't exceed options.max_num_entries, inserting a
+  // new key-value pair could still push a storage past max_file_size.
+  // Therefore, we check entry_storage->num_elements() instead of the # of
+  // "active" entries
+  // (i.e. entry_storage->num_elements() - info_ptr->num_deleted_entries).
+  // This prevents the storages from growing extremely large when there are
+  // many Delete() and Put() operations.
+ if (entry_storage->num_elements() > options.max_num_entries) {
+ return absl_ports::FailedPreconditionError(
+ "Current # of entries exceeds max num entries");
+ }
+
// Validate checksums of info and 3 storages.
ICING_RETURN_IF_ERROR(
ValidateChecksums(crcs_ptr, info_ptr, bucket_storage.get(),
entry_storage.get(), kv_storage.get()));
// Allow max_load_factor_percent_ change.
- if (max_load_factor_percent != info_ptr->max_load_factor_percent) {
+ if (options.max_load_factor_percent != info_ptr->max_load_factor_percent) {
ICING_VLOG(2) << "Changing max_load_factor_percent from "
<< info_ptr->max_load_factor_percent << " to "
- << max_load_factor_percent;
+ << options.max_load_factor_percent;
- info_ptr->max_load_factor_percent = max_load_factor_percent;
+ info_ptr->max_load_factor_percent = options.max_load_factor_percent;
crcs_ptr->component_crcs.info_crc = info_ptr->ComputeChecksum().Get();
crcs_ptr->all_crc = crcs_ptr->component_crcs.ComputeChecksum().Get();
ICING_RETURN_IF_ERROR(metadata_mmapped_file.PersistToDisk());
@@ -522,7 +628,7 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem,
auto persistent_hash_map =
std::unique_ptr<PersistentHashMap>(new PersistentHashMap(
- filesystem, base_dir, std::move(metadata_mmapped_file),
+ filesystem, base_dir, options, std::move(metadata_mmapped_file),
std::move(bucket_storage), std::move(entry_storage),
std::move(kv_storage)));
ICING_RETURN_IF_ERROR(
@@ -576,8 +682,17 @@ libtextclassifier3::Status PersistentHashMap::CopyEntryValue(
libtextclassifier3::Status PersistentHashMap::Insert(int32_t bucket_idx,
std::string_view key,
const void* value) {
- // If size() + 1 exceeds Entry::kMaxNumEntries, then return error.
- if (size() > Entry::kMaxNumEntries - 1) {
+  // If entry_storage_->num_elements() + 1 exceeds options_.max_num_entries,
+  // then return an error.
+  // We compute max_file_size of all 3 storages from options_.max_num_entries.
+  // Since the space of deleted entries (and their key-value bytes) is never
+  // recycled, deleted entries still occupy space in the storages. Even if the
+  // # of "active" entries (i.e. size()) doesn't exceed
+  // options_.max_num_entries, inserting a new key-value pair could still push
+  // a storage past max_file_size.
+  // Therefore, we check entry_storage_->num_elements() instead of size().
+  // This prevents the storages from growing extremely large when there are
+  // many Delete() and Put() operations.
+ if (entry_storage_->num_elements() > options_.max_num_entries - 1) {
return absl_ports::ResourceExhaustedError("Cannot insert new entry");
}
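
Note: for reference, a self-contained sketch of the bucket-count math
introduced above. It mirrors CalculateNumBucketsRequired (which lives in an
anonymous namespace in persistent-hash-map.cc); the wrapper name and the
sample inputs are illustrative, and __builtin_clz is a GCC/Clang builtin:

  #include <cstdint>
  #include <cstdio>

  // Mirrors CalculateNumBucketsRequired: ceiling division by the load factor,
  // then round the result up to the next power of 2 if needed.
  int32_t NumBucketsRequired(int32_t num_entries,
                             int32_t max_load_factor_percent) {
    int32_t num_entries_100 = num_entries * 100;
    int32_t num_buckets_required =
        num_entries_100 / max_load_factor_percent +
        (num_entries_100 % max_load_factor_percent == 0 ? 0 : 1);
    if ((num_buckets_required & (num_buckets_required - 1)) != 0) {
      // Not a power of 2: __builtin_clz finds the highest set bit.
      return 1 << (32 - __builtin_clz(num_buckets_required));
    }
    return num_buckets_required;
  }

  int main() {
    // 1000 entries at a 75% max load factor: ceil(1000 * 100 / 75) = 1334,
    // rounded up to the next power of 2.
    std::printf("%d\n", NumBucketsRequired(1000, 75));   // prints 2048
    // 1024 entries at 100% is already a power of 2 and stays unchanged.
    std::printf("%d\n", NumBucketsRequired(1024, 100));  // prints 1024
    return 0;
  }
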
diff --git a/icing/file/persistent-hash-map.h b/icing/file/persistent-hash-map.h
index ef3995c..57fa070 100644
--- a/icing/file/persistent-hash-map.h
+++ b/icing/file/persistent-hash-map.h
@@ -134,12 +134,10 @@ class PersistentHashMap {
// Bucket
class Bucket {
public:
- // Absolute max # of buckets allowed. Since max file size on Android is
- // 2^31-1, we can at most have ~2^29 buckets. To make it power of 2, round
- // it down to 2^28. Also since we're using FileBackedVector to store
- // buckets, add some static_asserts to ensure numbers here are compatible
- // with FileBackedVector.
- static constexpr int32_t kMaxNumBuckets = 1 << 28;
+ // Absolute max # of buckets allowed. Since we're using FileBackedVector to
+ // store buckets, add some static_asserts to ensure numbers here are
+ // compatible with FileBackedVector.
+ static constexpr int32_t kMaxNumBuckets = 1 << 24;
explicit Bucket(int32_t head_entry_index = Entry::kInvalidIndex)
: head_entry_index_(head_entry_index) {}
@@ -170,16 +168,14 @@ class PersistentHashMap {
// Entry
class Entry {
public:
- // Absolute max # of entries allowed. Since max file size on Android is
- // 2^31-1, we can at most have ~2^28 entries. To make it power of 2, round
- // it down to 2^27. Also since we're using FileBackedVector to store
- // entries, add some static_asserts to ensure numbers here are compatible
- // with FileBackedVector.
+ // Absolute max # of entries allowed. Since we're using FileBackedVector to
+ // store entries, add some static_asserts to ensure numbers here are
+ // compatible with FileBackedVector.
//
// Still the actual max # of entries are determined by key-value storage,
// since length of the key varies and affects # of actual key-value pairs
// that can be stored.
- static constexpr int32_t kMaxNumEntries = 1 << 27;
+ static constexpr int32_t kMaxNumEntries = 1 << 23;
static constexpr int32_t kMaxIndex = kMaxNumEntries - 1;
static constexpr int32_t kInvalidIndex = -1;
@@ -217,19 +213,64 @@ class PersistentHashMap {
"Max # of entries cannot fit into FileBackedVector");
// Key-value serialized type
- static constexpr int32_t kMaxKVTotalByteSize =
- (FileBackedVector<char>::kMaxFileSize -
- FileBackedVector<char>::Header::kHeaderSize) /
- FileBackedVector<char>::kElementTypeSize;
+ static constexpr int32_t kMaxKVTotalByteSize = 1 << 28;
static constexpr int32_t kMaxKVIndex = kMaxKVTotalByteSize - 1;
static constexpr int32_t kInvalidKVIndex = -1;
static_assert(sizeof(char) == FileBackedVector<char>::kElementTypeSize,
"Char type size is inconsistent with FileBackedVector element "
"type size");
+ static_assert(kMaxKVTotalByteSize <=
+ FileBackedVector<char>::kMaxFileSize -
+ FileBackedVector<char>::Header::kHeaderSize,
+ "Max total byte size of key value pairs cannot fit into "
+ "FileBackedVector");
+
+ static constexpr int32_t kMaxValueTypeSize = 1 << 10;
+
+ struct Options {
+ static constexpr int32_t kDefaultMaxLoadFactorPercent = 100;
+ static constexpr int32_t kDefaultAverageKVByteSize = 32;
+ static constexpr int32_t kDefaultInitNumBuckets = 1 << 13;
+
+ explicit Options(
+ int32_t value_type_size_in,
+ int32_t max_num_entries_in = Entry::kMaxNumEntries,
+ int32_t max_load_factor_percent_in = kDefaultMaxLoadFactorPercent,
+ int32_t average_kv_byte_size_in = kDefaultAverageKVByteSize,
+ int32_t init_num_buckets_in = kDefaultInitNumBuckets)
+ : value_type_size(value_type_size_in),
+ max_num_entries(max_num_entries_in),
+ max_load_factor_percent(max_load_factor_percent_in),
+ average_kv_byte_size(average_kv_byte_size_in),
+ init_num_buckets(init_num_buckets_in) {}
+
+ bool IsValid() const;
+
+ // (fixed) size of the serialized value type for hash map.
+ int32_t value_type_size;
+
+ // Max # of entries, default Entry::kMaxNumEntries.
+ int32_t max_num_entries;
+
+ // Percentage of the max loading for the hash map. If load_factor_percent
+ // exceeds max_load_factor_percent, then rehash will be invoked (and # of
+ // buckets will be doubled).
+ // load_factor_percent = 100 * num_keys / num_buckets
+ //
+ // Note that load_factor_percent exceeding 100 is considered valid.
+ int32_t max_load_factor_percent;
+
+ // Average byte size of a key value pair. It is used to estimate kv_storage_
+ // pre_mapping_mmap_size.
+ int32_t average_kv_byte_size;
+
+    // Initial # of buckets for the persistent hash map. It should be a power
+    // of 2. It is used when creating a new persistent hash map and ignored
+    // when creating the instance from existing files.
+ int32_t init_num_buckets;
+ };
static constexpr int32_t kVersion = 1;
- static constexpr int32_t kDefaultMaxLoadFactorPercent = 100;
- static constexpr int32_t kDefaultInitNumBuckets = 8192;
static constexpr std::string_view kFilePrefix = "persistent_hash_map";
// Only metadata, bucket, entry files are stored under this sub-directory, for
@@ -243,28 +284,17 @@ class PersistentHashMap {
// sub-directory and files to be stored. If base_dir doesn't exist,
// then PersistentHashMap will automatically create it. If files
// exist, then it will initialize the hash map from existing files.
- // value_type_size: (fixed) size of the serialized value type for hash map.
- // max_load_factor_percent: percentage of the max loading for the hash map.
- // load_factor_percent = 100 * num_keys / num_buckets
- // If load_factor_percent exceeds
- // max_load_factor_percent, then rehash will be
- // invoked (and # of buckets will be doubled).
- // Note that load_factor_percent exceeding 100 is
- // considered valid.
- // init_num_buckets: initial # of buckets for the persistent hash map. It is
- // used when creating new persistent hash map and ignored
- // when creating the instance from existing files.
+ // options: Options instance.
//
// Returns:
+ // INVALID_ARGUMENT_ERROR if any value in options is invalid.
// FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored
// checksum.
// INTERNAL_ERROR on I/O errors.
// Any FileBackedVector errors.
static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
Create(const Filesystem& filesystem, std::string_view base_dir,
- int32_t value_type_size,
- int32_t max_load_factor_percent = kDefaultMaxLoadFactorPercent,
- int32_t init_num_buckets = kDefaultInitNumBuckets);
+ const Options& options);
~PersistentHashMap();
@@ -275,7 +305,7 @@ class PersistentHashMap {
//
// Returns:
// OK on success
- // RESOURCE_EXHAUSTED_ERROR if # of entries reach kMaxNumEntries
+  //   RESOURCE_EXHAUSTED_ERROR if # of entries reaches options_.max_num_entries
// INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0')
// INTERNAL_ERROR on I/O error or any data inconsistency
// Any FileBackedVector errors
@@ -373,12 +403,13 @@ class PersistentHashMap {
explicit PersistentHashMap(
const Filesystem& filesystem, std::string_view base_dir,
- MemoryMappedFile&& metadata_mmapped_file,
+ const Options& options, MemoryMappedFile&& metadata_mmapped_file,
std::unique_ptr<FileBackedVector<Bucket>> bucket_storage,
std::unique_ptr<FileBackedVector<Entry>> entry_storage,
std::unique_ptr<FileBackedVector<char>> kv_storage)
: filesystem_(&filesystem),
base_dir_(base_dir),
+ options_(options),
metadata_mmapped_file_(std::make_unique<MemoryMappedFile>(
std::move(metadata_mmapped_file))),
bucket_storage_(std::move(bucket_storage)),
@@ -387,13 +418,11 @@ class PersistentHashMap {
static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
InitializeNewFiles(const Filesystem& filesystem, std::string_view base_dir,
- int32_t value_type_size, int32_t max_load_factor_percent,
- int32_t init_num_buckets);
+ const Options& options);
static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
InitializeExistingFiles(const Filesystem& filesystem,
- std::string_view base_dir, int32_t value_type_size,
- int32_t max_load_factor_percent);
+ std::string_view base_dir, const Options& options);
// Find the index of the target entry (that contains the key) from a bucket
// (specified by bucket index). Also return the previous entry index, since
@@ -457,6 +486,8 @@ class PersistentHashMap {
const Filesystem* filesystem_;
std::string base_dir_;
+ Options options_;
+
std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_;
// Storages
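
Note: to illustrate the shape of the new Options-based Create() API, a
hypothetical caller might look like the sketch below. The directory path, the
chosen limits, and the helper name are illustrative only;
ICING_ASSIGN_OR_RETURN and ICING_RETURN_IF_ERROR are the codebase's existing
status macros:

  libtextclassifier3::Status BuildMap(const icing::lib::Filesystem& filesystem) {
    using icing::lib::PersistentHashMap;
    // Fixed-size int values; all other fields keep their defaults.
    PersistentHashMap::Options options(/*value_type_size_in=*/sizeof(int));
    options.max_num_entries = 1 << 16;      // caps total inserted entries
    options.max_load_factor_percent = 100;  // rehash once num_keys > num_buckets
    // Create() validates options and returns INVALID_ARGUMENT if they are bad.
    ICING_ASSIGN_OR_RETURN(
        std::unique_ptr<PersistentHashMap> map,
        PersistentHashMap::Create(filesystem, "/data/phm_dir", options));
    int value = 1;
    ICING_RETURN_IF_ERROR(map->Put("default-google.com", &value));
    return map->PersistToDisk();
  }
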
diff --git a/icing/file/persistent-hash-map_test.cc b/icing/file/persistent-hash-map_test.cc
index 8024388..8fde4a8 100644
--- a/icing/file/persistent-hash-map_test.cc
+++ b/icing/file/persistent-hash-map_test.cc
@@ -24,20 +24,11 @@
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
-#include "icing/file/file-backed-vector.h"
#include "icing/file/filesystem.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
#include "icing/util/crc32.h"
-namespace icing {
-namespace lib {
-
-namespace {
-
-static constexpr int32_t kCorruptedValueOffset = 3;
-static constexpr int32_t kTestInitNumBuckets = 1;
-
using ::testing::Contains;
using ::testing::Eq;
using ::testing::Gt;
@@ -51,10 +42,19 @@ using ::testing::Pointee;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
+namespace icing {
+namespace lib {
+
+namespace {
+
using Bucket = PersistentHashMap::Bucket;
using Crcs = PersistentHashMap::Crcs;
using Entry = PersistentHashMap::Entry;
using Info = PersistentHashMap::Info;
+using Options = PersistentHashMap::Options;
+
+static constexpr int32_t kCorruptedValueOffset = 3;
+static constexpr int32_t kTestInitNumBuckets = 1;
class PersistentHashMapTest : public ::testing::Test {
protected:
@@ -95,10 +95,110 @@ class PersistentHashMapTest : public ::testing::Test {
std::string base_dir_;
};
+TEST_F(PersistentHashMapTest, OptionsInvalidValueTypeSize) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.value_type_size = -1;
+ EXPECT_FALSE(options.IsValid());
+
+ options.value_type_size = 0;
+ EXPECT_FALSE(options.IsValid());
+
+ options.value_type_size = PersistentHashMap::kMaxValueTypeSize + 1;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest, OptionsInvalidMaxNumEntries) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.max_num_entries = -1;
+ EXPECT_FALSE(options.IsValid());
+
+ options.max_num_entries = 0;
+ EXPECT_FALSE(options.IsValid());
+
+ options.max_num_entries = Entry::kMaxNumEntries + 1;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest, OptionsInvalidMaxLoadFactorPercent) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.max_load_factor_percent = -1;
+ EXPECT_FALSE(options.IsValid());
+
+ options.max_load_factor_percent = 0;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest, OptionsInvalidAverageKVByteSize) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.average_kv_byte_size = -1;
+ EXPECT_FALSE(options.IsValid());
+
+ options.average_kv_byte_size = 0;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest, OptionsInvalidInitNumBuckets) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.init_num_buckets = -1;
+ EXPECT_FALSE(options.IsValid());
+
+ options.init_num_buckets = 0;
+ EXPECT_FALSE(options.IsValid());
+
+ options.init_num_buckets = Bucket::kMaxNumBuckets + 1;
+ EXPECT_FALSE(options.IsValid());
+
+  // Not a power of 2.
+ options.init_num_buckets = 3;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest, OptionsNumBucketsRequiredExceedsMaxNumBuckets) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.max_num_entries = Entry::kMaxNumEntries;
+ options.max_load_factor_percent = 30;
+ EXPECT_FALSE(options.IsValid());
+}
+
+TEST_F(PersistentHashMapTest,
+ OptionsEstimatedNumKeyValuePairExceedsStorageMaxSize) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+ ASSERT_TRUE(options.IsValid());
+
+ options.max_num_entries = 1 << 20;
+ options.average_kv_byte_size = 1 << 20;
+ ASSERT_THAT(static_cast<int64_t>(options.max_num_entries) *
+ options.average_kv_byte_size,
+ Gt(PersistentHashMap::kMaxKVTotalByteSize));
+ EXPECT_FALSE(options.IsValid());
+}
+
TEST_F(PersistentHashMapTest, InvalidBaseDir) {
- EXPECT_THAT(PersistentHashMap::Create(filesystem_, "/dev/null",
- /*value_type_size=*/sizeof(int)),
- StatusIs(libtextclassifier3::StatusCode::INTERNAL));
+ EXPECT_THAT(
+ PersistentHashMap::Create(filesystem_, "/dev/null",
+ Options(/*value_type_size_in=*/sizeof(int))),
+ StatusIs(libtextclassifier3::StatusCode::INTERNAL));
+}
+
+TEST_F(PersistentHashMapTest, CreateWithInvalidOptionsShouldFail) {
+ Options invalid_options(/*value_type_size_in=*/-1);
+ ASSERT_FALSE(invalid_options.IsValid());
+
+ EXPECT_THAT(
+ PersistentHashMap::Create(filesystem_, base_dir_, invalid_options),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
TEST_F(PersistentHashMapTest, InitializeNewFiles) {
@@ -107,7 +207,7 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ Options(/*value_type_size_in=*/sizeof(int))));
EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -128,7 +228,7 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) {
EXPECT_THAT(info.version, Eq(PersistentHashMap::kVersion));
EXPECT_THAT(info.value_type_size, Eq(sizeof(int)));
EXPECT_THAT(info.max_load_factor_percent,
- Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent));
+ Eq(Options::kDefaultMaxLoadFactorPercent));
EXPECT_THAT(info.num_deleted_entries, Eq(0));
EXPECT_THAT(info.num_deleted_key_value_bytes, Eq(0));
@@ -153,52 +253,81 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) {
.Get()));
}
-TEST_F(PersistentHashMapTest, InitializeNewFilesWithCustomInitBucketSize) {
+TEST_F(PersistentHashMapTest, InitializeNewFilesWithCustomInitNumBuckets) {
+ int custom_init_num_buckets = 128;
+
// Create new persistent hash map
- int custom_init_bucket_size = 123;
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- custom_init_bucket_size));
- EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(custom_init_bucket_size));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/custom_init_num_buckets)));
+ EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(custom_init_num_buckets));
+}
+
+TEST_F(PersistentHashMapTest,
+ InitializeNewFilesWithInitNumBucketsSmallerThanNumBucketsRequired) {
+ int init_num_buckets = 65536;
+
+ // Create new persistent hash map
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PersistentHashMap> persistent_hash_map,
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+            /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/init_num_buckets)));
+ EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_num_buckets));
}
-TEST_F(PersistentHashMapTest, InitBucketSizeShouldNotAffectExistingFiles) {
- int init_bucket_size1 = 4;
+TEST_F(PersistentHashMapTest, InitNumBucketsShouldNotAffectExistingFiles) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
+ int original_init_num_buckets = 4;
{
+ options.init_num_buckets = original_init_num_buckets;
+ ASSERT_TRUE(options.IsValid());
+
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(
- filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- init_bucket_size1));
- EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_bucket_size1));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
+ EXPECT_THAT(persistent_hash_map->num_buckets(),
+ Eq(original_init_num_buckets));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
}
- int init_bucket_size2 = 8;
+ // Set new init_num_buckets.
+ options.init_num_buckets = 8;
+ ASSERT_TRUE(options.IsValid());
+
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- init_bucket_size2));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
// # of buckets should still be the original value.
- EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_bucket_size1));
+ EXPECT_THAT(persistent_hash_map->num_buckets(),
+ Eq(original_init_num_buckets));
}
TEST_F(PersistentHashMapTest,
- TestInitializationFailsWithoutPersistToDiskOrDestruction) {
+ InitializationShouldFailWithoutPersistToDiskOrDestruction) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
// Put some key value pairs.
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
@@ -214,17 +343,17 @@ TEST_F(PersistentHashMapTest,
// Without calling PersistToDisk, checksums will not be recomputed or synced
// to disk, so initializing another instance on the same files should fail.
- EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)),
+ EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
-TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) {
+TEST_F(PersistentHashMapTest, InitializationShouldSucceedWithPersistToDisk) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map1,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
// Put some key value pairs.
ICING_ASSERT_OK(persistent_hash_map1->Put("a", Serialize(1).data()));
@@ -245,20 +374,20 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map2,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
EXPECT_THAT(persistent_hash_map2, Pointee(SizeIs(2)));
EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "a"), IsOkAndHolds(1));
EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "b"), IsOkAndHolds(2));
}
-TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) {
+TEST_F(PersistentHashMapTest, InitializationShouldSucceedAfterDestruction) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data()));
ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data()));
@@ -278,8 +407,7 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) {
// we should be able to get the same contents.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2)));
EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1));
EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2));
@@ -293,7 +421,7 @@ TEST_F(PersistentHashMapTest,
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ Options(/*value_type_size_in=*/sizeof(int))));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -305,7 +433,8 @@ TEST_F(PersistentHashMapTest,
ASSERT_THAT(sizeof(char), Not(Eq(sizeof(int))));
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(char));
+ filesystem_, base_dir_,
+ Options(/*value_type_size_in=*/sizeof(char)));
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(persistent_hash_map_or.status().error_message(),
@@ -313,13 +442,55 @@ TEST_F(PersistentHashMapTest,
}
}
+TEST_F(PersistentHashMapTest,
+ InitializeExistingFilesWithMaxNumEntriesSmallerThanSizeShouldFail) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
+ // Create new persistent hash map
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PersistentHashMap> persistent_hash_map,
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
+ ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
+ ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data()));
+
+ ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
+
+ {
+ // Attempt to create the persistent hash map with max num entries smaller
+ // than the current size. This should fail.
+ options.max_num_entries = 1;
+ ASSERT_TRUE(options.IsValid());
+
+    EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options),
+                StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ }
+
+ // Delete 1 kvp.
+ ICING_ASSERT_OK(persistent_hash_map->Delete("a"));
+ ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1)));
+ ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
+
+ {
+ // Attempt to create the persistent hash map with max num entries:
+ // - Not smaller than current # of active kvps.
+ // - Smaller than # of all inserted kvps (regardless of activeness).
+ // This should fail.
+ options.max_num_entries = 1;
+ ASSERT_TRUE(options.IsValid());
+
+    EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options),
+                StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ }
+}
+
TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -345,8 +516,8 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) {
// Attempt to create the persistent hash map with metadata containing
// corrupted all_crc. This should fail.
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
- persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(int));
+ persistent_hash_map_or =
+ PersistentHashMap::Create(filesystem_, base_dir_, options);
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(persistent_hash_map_or.status().error_message(),
@@ -356,12 +527,13 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) {
TEST_F(PersistentHashMapTest,
InitializeExistingFilesWithCorruptedInfoShouldFail) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -386,8 +558,8 @@ TEST_F(PersistentHashMapTest,
// Attempt to create the persistent hash map with info that doesn't match
// its checksum and confirm that it fails.
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
- persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(int));
+ persistent_hash_map_or =
+ PersistentHashMap::Create(filesystem_, base_dir_, options);
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(persistent_hash_map_or.status().error_message(),
@@ -397,12 +569,13 @@ TEST_F(PersistentHashMapTest,
TEST_F(PersistentHashMapTest,
InitializeExistingFilesWithWrongBucketStorageCrc) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -430,8 +603,8 @@ TEST_F(PersistentHashMapTest,
// Attempt to create the persistent hash map with metadata containing
// corrupted bucket_storage_crc. This should fail.
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
- persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(int));
+ persistent_hash_map_or =
+ PersistentHashMap::Create(filesystem_, base_dir_, options);
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
@@ -441,12 +614,13 @@ TEST_F(PersistentHashMapTest,
}
TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -474,8 +648,8 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) {
// Attempt to create the persistent hash map with metadata containing
// corrupted entry_storage_crc. This should fail.
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
- persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(int));
+ persistent_hash_map_or =
+ PersistentHashMap::Create(filesystem_, base_dir_, options);
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(persistent_hash_map_or.status().error_message(),
@@ -485,12 +659,13 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) {
TEST_F(PersistentHashMapTest,
InitializeExistingFilesWithWrongKeyValueStorageCrc) {
+ Options options(/*value_type_size_in=*/sizeof(int));
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
@@ -518,8 +693,8 @@ TEST_F(PersistentHashMapTest,
// Attempt to create the persistent hash map with metadata containing
// corrupted kv_storage_crc. This should fail.
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>>
- persistent_hash_map_or = PersistentHashMap::Create(
- filesystem_, base_dir_, /*value_type_size=*/sizeof(int));
+ persistent_hash_map_or =
+ PersistentHashMap::Create(filesystem_, base_dir_, options);
EXPECT_THAT(persistent_hash_map_or,
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
@@ -530,15 +705,18 @@ TEST_F(PersistentHashMapTest,
TEST_F(PersistentHashMapTest,
InitializeExistingFilesAllowDifferentMaxLoadFactorPercent) {
+ Options options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets);
+
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(
- filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data()));
@@ -549,18 +727,19 @@ TEST_F(PersistentHashMapTest,
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
}
- int32_t new_max_load_factor_percent = 200;
{
- ASSERT_THAT(new_max_load_factor_percent,
- Not(Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent)));
+ // Set new max_load_factor_percent.
+ options.max_load_factor_percent = 200;
+ ASSERT_TRUE(options.IsValid());
+ ASSERT_THAT(options.max_load_factor_percent,
+ Not(Eq(Options::kDefaultMaxLoadFactorPercent)));
+
// Attempt to create the persistent hash map with different max load factor
// percent. This should succeed and metadata should be modified correctly.
// Also verify all entries should remain unchanged.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- new_max_load_factor_percent));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2)));
EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1));
@@ -578,17 +757,15 @@ TEST_F(PersistentHashMapTest,
Info info;
ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info),
Info::kFileOffset));
- EXPECT_THAT(info.max_load_factor_percent, Eq(new_max_load_factor_percent));
+ EXPECT_THAT(info.max_load_factor_percent,
+ Eq(options.max_load_factor_percent));
// Also should update crcs correctly. We test it by creating instance again
// and make sure it won't get corrupted crcs/info errors.
{
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- new_max_load_factor_percent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
}
@@ -596,17 +773,20 @@ TEST_F(PersistentHashMapTest,
TEST_F(PersistentHashMapTest,
InitializeExistingFilesWithDifferentMaxLoadFactorPercentShouldRehash) {
+ Options options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets);
+
double prev_loading_percent;
int prev_num_buckets;
{
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(
- filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data()));
ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data()));
ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data()));
@@ -620,47 +800,47 @@ TEST_F(PersistentHashMapTest,
persistent_hash_map->num_buckets();
prev_num_buckets = persistent_hash_map->num_buckets();
ASSERT_THAT(prev_loading_percent,
- Not(Gt(PersistentHashMap::kDefaultMaxLoadFactorPercent)));
+ Not(Gt(Options::kDefaultMaxLoadFactorPercent)));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
}
- int32_t greater_max_load_factor_percent = 150;
{
- ASSERT_THAT(greater_max_load_factor_percent, Gt(prev_loading_percent));
+ // Set greater max_load_factor_percent.
+ options.max_load_factor_percent = 150;
+ ASSERT_TRUE(options.IsValid());
+ ASSERT_THAT(options.max_load_factor_percent, Gt(prev_loading_percent));
+
// Attempt to create the persistent hash map with max load factor greater
// than previous loading. There should be no rehashing and # of buckets
// should remain the same.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- greater_max_load_factor_percent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(prev_num_buckets));
ICING_ASSERT_OK(persistent_hash_map->PersistToDisk());
}
- int32_t smaller_max_load_factor_percent = 25;
{
- ASSERT_THAT(smaller_max_load_factor_percent, Lt(prev_loading_percent));
+ // Set smaller max_load_factor_percent.
+ options.max_load_factor_percent = 50;
+ ASSERT_TRUE(options.IsValid());
+ ASSERT_THAT(options.max_load_factor_percent, Lt(prev_loading_percent));
+
// Attempt to create the persistent hash map with max load factor smaller
// than previous loading. There should be rehashing since the loading
// exceeds the limit.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- smaller_max_load_factor_percent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(filesystem_, base_dir_, options));
// After changing max_load_factor_percent, there should be rehashing and the
// new loading should not be greater than the new max load factor.
EXPECT_THAT(persistent_hash_map->size() * 100.0 /
persistent_hash_map->num_buckets(),
- Not(Gt(smaller_max_load_factor_percent)));
+ Not(Gt(options.max_load_factor_percent)));
EXPECT_THAT(persistent_hash_map->num_buckets(), Not(Eq(prev_num_buckets)));
EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1));
@@ -675,10 +855,15 @@ TEST_F(PersistentHashMapTest, PutAndGet) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty()));
EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"),
@@ -706,10 +891,15 @@ TEST_F(PersistentHashMapTest, PutShouldOverwriteValueIfKeyExists) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com", Serialize(100).data()));
@@ -734,10 +924,15 @@ TEST_F(PersistentHashMapTest, ShouldRehash) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
int original_num_buckets = persistent_hash_map->num_buckets();
// Insert 100 key value pairs. There should be rehashing so the loading of
@@ -749,7 +944,7 @@ TEST_F(PersistentHashMapTest, ShouldRehash) {
EXPECT_THAT(persistent_hash_map->size() * 100.0 /
persistent_hash_map->num_buckets(),
- Not(Gt(PersistentHashMap::kDefaultMaxLoadFactorPercent)));
+ Not(Gt(Options::kDefaultMaxLoadFactorPercent)));
}
EXPECT_THAT(persistent_hash_map->num_buckets(),
Not(Eq(original_num_buckets)));
@@ -765,10 +960,15 @@ TEST_F(PersistentHashMapTest, GetOrPutShouldPutIfKeyDoesNotExist) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"),
StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
@@ -786,10 +986,15 @@ TEST_F(PersistentHashMapTest, GetOrPutShouldGetIfKeyExists) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ASSERT_THAT(
persistent_hash_map->Put("default-google.com", Serialize(1).data()),
@@ -810,10 +1015,15 @@ TEST_F(PersistentHashMapTest, Delete) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
  // Deleting a non-existing key should return a NOT_FOUND error
EXPECT_THAT(persistent_hash_map->Delete("default-google.com"),
@@ -856,10 +1066,15 @@ TEST_F(PersistentHashMapTest, DeleteMultiple) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
std::unordered_map<std::string, int> existing_keys;
std::unordered_set<std::string> deleted_keys;
@@ -909,10 +1124,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketHeadElement) {
// Preventing rehashing makes it much easier to test collisions.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- /*max_load_factor_percent=*/1000,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/1000,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -943,10 +1162,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketIntermediateElement) {
// Preventing rehashing makes it much easier to test collisions.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- /*max_load_factor_percent=*/1000,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/1000,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -976,10 +1199,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketTailElement) {
// Preventing rehashing makes it much easier to test collisions.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- /*max_load_factor_percent=*/1000,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/1000,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -1010,10 +1237,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketOnlySingleElement) {
// Preventing rehashing makes it much easier to test collisions.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- /*max_load_factor_percent=*/1000,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/1000,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com", Serialize(100).data()));
@@ -1026,12 +1257,48 @@ TEST_F(PersistentHashMapTest, DeleteBucketOnlySingleElement) {
StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
}
+TEST_F(PersistentHashMapTest, OperationsWhenReachingMaxNumEntries) {
+ // Create new persistent hash map
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PersistentHashMap> persistent_hash_map,
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/1,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/1)));
+
+ ICING_ASSERT_OK(
+ persistent_hash_map->Put("default-google.com", Serialize(100).data()));
+ ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1)));
+
+  // Putting a new key should fail.
+ EXPECT_THAT(
+ persistent_hash_map->Put("default-youtube.com", Serialize(50).data()),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+  // Modifying an existing key should succeed.
+ EXPECT_THAT(
+ persistent_hash_map->Put("default-google.com", Serialize(200).data()),
+ IsOk());
+
+  // Putting a new key after a deletion should still fail. See the comment in
+ // PersistentHashMap::Insert for more details.
+ ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com"));
+ ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(0)));
+ EXPECT_THAT(
+ persistent_hash_map->Put("default-youtube.com", Serialize(50).data()),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int)));
+ Options(/*value_type_size_in=*/sizeof(int))));
const char invalid_key[] = "a\0bc";
std::string_view invalid_key_view(invalid_key, 4);
@@ -1051,10 +1318,15 @@ TEST_F(PersistentHashMapTest, EmptyHashMapIterator) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
EXPECT_FALSE(persistent_hash_map->GetIterator().Advance());
}
@@ -1063,10 +1335,15 @@ TEST_F(PersistentHashMapTest, Iterator) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
std::unordered_map<std::string, int> kvps;
// Insert 100 key value pairs
@@ -1085,10 +1362,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingFirstKeyValuePair) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -1109,10 +1391,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingIntermediateKeyValuePair) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -1133,10 +1420,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingLastKeyValuePair) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
@@ -1157,10 +1449,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingAllKeyValuePairs) {
// Create new persistent hash map
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PersistentHashMap> persistent_hash_map,
- PersistentHashMap::Create(filesystem_, base_dir_,
- /*value_type_size=*/sizeof(int),
- PersistentHashMap::kDefaultMaxLoadFactorPercent,
- kTestInitNumBuckets));
+ PersistentHashMap::Create(
+ filesystem_, base_dir_,
+ Options(
+ /*value_type_size_in=*/sizeof(int),
+ /*max_num_entries_in=*/Entry::kMaxNumEntries,
+ /*max_load_factor_percent_in=*/
+ Options::kDefaultMaxLoadFactorPercent,
+ /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize,
+ /*init_num_buckets_in=*/kTestInitNumBuckets)));
ICING_ASSERT_OK(
persistent_hash_map->Put("default-google.com-0", Serialize(0).data()));
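A minimal sketch of the Options-based Create() pattern the updated tests follow, assuming the fixture's kTestInitNumBuckets constant and Serialize() helper, and that Options and Entry are the nested PersistentHashMap types referenced throughout the test file:

    // Sketch only; mirrors the call sites above.
    PersistentHashMap::Options options(
        /*value_type_size_in=*/sizeof(int),
        /*max_num_entries_in=*/PersistentHashMap::Entry::kMaxNumEntries,
        /*max_load_factor_percent_in=*/
        PersistentHashMap::Options::kDefaultMaxLoadFactorPercent,
        /*average_kv_byte_size_in=*/
        PersistentHashMap::Options::kDefaultAverageKVByteSize,
        /*init_num_buckets_in=*/kTestInitNumBuckets);
    ASSERT_TRUE(options.IsValid());
    ICING_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<PersistentHashMap> persistent_hash_map,
        PersistentHashMap::Create(filesystem_, base_dir_, options));
    ICING_ASSERT_OK(persistent_hash_map->Put("key", Serialize(1).data()));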
diff --git a/icing/file/posting_list/posting-list-accessor.cc b/icing/file/posting_list/posting-list-accessor.cc
new file mode 100644
index 0000000..00f4417
--- /dev/null
+++ b/icing/file/posting_list/posting-list-accessor.cc
@@ -0,0 +1,110 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/file/posting_list/posting-list-accessor.h"
+
+#include <cstdint>
+#include <memory>
+
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/index-block.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+void PostingListAccessor::FlushPreexistingPostingList() {
+ if (preexisting_posting_list_->block.max_num_posting_lists() == 1) {
+ // If this is a max-sized posting list, then just keep track of the id for
+ // chaining. It'll be flushed to disk when preexisting_posting_list_ is
+ // destructed.
+ prev_block_identifier_ = preexisting_posting_list_->id;
+ } else {
+ // If this is NOT a max-sized posting list, then our data have outgrown this
+ // particular posting list. Move the data into the in-memory posting list
+ // and free this posting list.
+ //
+    // Move will always succeed since posting_list_buffer_ is max_pl_bytes in
+    // size.
+ GetSerializer()->MoveFrom(/*dst=*/&posting_list_buffer_,
+ /*src=*/&preexisting_posting_list_->posting_list);
+
+ // Now that all the contents of this posting list have been copied, there's
+ // no more use for it. Make it available to be used for another posting
+ // list.
+ storage_->FreePostingList(std::move(*preexisting_posting_list_));
+ }
+ preexisting_posting_list_.reset();
+}
+
+libtextclassifier3::Status PostingListAccessor::FlushInMemoryPostingList() {
+ // We exceeded max_pl_bytes(). Need to flush posting_list_buffer_ and update
+ // the chain.
+ uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
+ storage_->block_size(), GetSerializer()->GetDataTypeBytes());
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage_->AllocatePostingList(max_posting_list_bytes));
+ holder.block.set_next_block_index(prev_block_identifier_.block_index());
+ prev_block_identifier_ = holder.id;
+ return GetSerializer()->MoveFrom(/*dst=*/&holder.posting_list,
+ /*src=*/&posting_list_buffer_);
+}
+
+PostingListAccessor::FinalizeResult PostingListAccessor::Finalize() && {
+ if (preexisting_posting_list_ != nullptr) {
+    // Our data are already in an existing posting list. Nothing else to do but
+ // return its id.
+ return FinalizeResult(libtextclassifier3::Status::OK,
+ preexisting_posting_list_->id);
+ }
+ if (GetSerializer()->GetBytesUsed(&posting_list_buffer_) <= 0) {
+ return FinalizeResult(absl_ports::InvalidArgumentError(
+ "Can't finalize an empty PostingListAccessor. "
+ "There's nothing to Finalize!"),
+ PostingListIdentifier::kInvalid);
+ }
+ uint32_t posting_list_bytes =
+ GetSerializer()->GetMinPostingListSizeToFit(&posting_list_buffer_);
+ if (prev_block_identifier_.is_valid()) {
+ posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
+ storage_->block_size(), GetSerializer()->GetDataTypeBytes());
+ }
+ auto holder_or = storage_->AllocatePostingList(posting_list_bytes);
+ if (!holder_or.ok()) {
+ return FinalizeResult(std::move(holder_or).status(),
+ prev_block_identifier_);
+ }
+ PostingListHolder holder = std::move(holder_or).ValueOrDie();
+ if (prev_block_identifier_.is_valid()) {
+ holder.block.set_next_block_index(prev_block_identifier_.block_index());
+ }
+
+  // Move to the allocated area. This should never actually return an error. We
+  // know that holder.posting_list is valid because AllocatePostingList wouldn't
+  // have returned successfully if it weren't. We know posting_list_buffer_ is
+  // valid because we created it in-memory. And finally, we know that the data
+  // from posting_list_buffer_ will fit in holder.posting_list because we
+  // requested it be at least posting_list_bytes large.
+ auto status = GetSerializer()->MoveFrom(/*dst=*/&holder.posting_list,
+ /*src=*/&posting_list_buffer_);
+ if (!status.ok()) {
+ return FinalizeResult(std::move(status), prev_block_identifier_);
+ }
+ return FinalizeResult(libtextclassifier3::Status::OK, holder.id);
+}
+
+} // namespace lib
+} // namespace icing
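One detail of Finalize() worth spelling out: once a chain exists (prev_block_identifier_ is valid), the final allocation must be max-sized, since only max-sized posting lists occupy a whole block and can therefore carry a next-block index; otherwise the smallest posting list that fits the buffered data is chosen. A restatement of that decision under those assumptions, with names mirroring the code above:

    // Sketch of the size choice made in PostingListAccessor::Finalize().
    uint32_t ChooseAllocationBytes(bool has_chain, uint32_t min_size_to_fit,
                                   uint32_t max_posting_list_bytes) {
      // Chained posting lists must be max-sized; unchained ones only need
      // to fit their buffered data.
      return has_chain ? max_posting_list_bytes : min_size_to_fit;
    }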
diff --git a/icing/index/main/posting-list-accessor.h b/icing/file/posting_list/posting-list-accessor.h
index 3f93c3a..c7d614f 100644
--- a/icing/index/main/posting-list-accessor.h
+++ b/icing/file/posting_list/posting-list-accessor.h
@@ -1,4 +1,4 @@
-// Copyright (C) 2019 Google LLC
+// Copyright (C) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,20 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef ICING_INDEX_POSTING_LIST_ACCESSOR_H_
-#define ICING_INDEX_POSTING_LIST_ACCESSOR_H_
+#ifndef ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_
+#define ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_
#include <cstdint>
#include <memory>
-#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
-#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/file/posting_list/flash-index-storage.h"
#include "icing/file/posting_list/posting-list-identifier.h"
#include "icing/file/posting_list/posting-list-used.h"
-#include "icing/index/hit/hit.h"
-#include "icing/index/main/posting-list-used-hit-serializer.h"
namespace icing {
namespace lib {
@@ -38,62 +34,14 @@ namespace lib {
// 3. Ensure that PostingListUseds can only be freed by calling methods which
// will also properly maintain the FlashIndexStorage free list and prevent
// callers from modifying the Posting List after freeing.
-
-// This class is used to provide a simple abstraction for adding hits to posting
-// lists. PostingListAccessor handles 1) selection of properly-sized posting
-// lists for the accumulated hits during Finalize() and 2) chaining of max-sized
-// posting lists.
class PostingListAccessor {
public:
- // Creates an empty PostingListAccessor.
- //
- // RETURNS:
- // - On success, a valid instance of PostingListAccessor
- // - INVALID_ARGUMENT error if storage has an invalid block_size.
- static libtextclassifier3::StatusOr<PostingListAccessor> Create(
- FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer);
-
- // Create a PostingListAccessor with an existing posting list identified by
- // existing_posting_list_id.
- //
- // The PostingListAccessor will add hits to this posting list until it is
- // necessary either to 1) chain the posting list (if it is max-sized) or 2)
- // move its hits to a larger posting list.
- //
- // RETURNS:
- // - On success, a valid instance of PostingListAccessor
- // - INVALID_ARGUMENT if storage has an invalid block_size.
- static libtextclassifier3::StatusOr<PostingListAccessor> CreateFromExisting(
- FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer,
- PostingListIdentifier existing_posting_list_id);
-
- // Retrieve the next batch of hits for the posting list chain
- //
- // RETURNS:
- // - On success, a vector of hits in the posting list chain
- // - INTERNAL if called on an instance of PostingListAccessor that was
- // created via PostingListAccessor::Create, if unable to read the next
- // posting list in the chain or if the posting list has been corrupted
- // somehow.
- libtextclassifier3::StatusOr<std::vector<Hit>> GetNextHitsBatch();
-
- // Prepend one hit. This may result in flushing the posting list to disk (if
- // the PostingListAccessor holds a max-sized posting list that is full) or
- // freeing a pre-existing posting list if it is too small to fit all hits
- // necessary.
- //
- // RETURNS:
- // - OK, on success
- // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
- // previously added hit.
- // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new
- // posting list.
- libtextclassifier3::Status PrependHit(const Hit& hit);
+ virtual ~PostingListAccessor() = default;
struct FinalizeResult {
// - OK on success
// - INVALID_ARGUMENT if there was no pre-existing posting list and no
- // hits were added
+ // data were added
// - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a
// new posting list.
libtextclassifier3::Status status;
@@ -101,22 +49,27 @@ class PostingListAccessor {
// if status is OK. May be valid if status is non-OK, but previous blocks
// were written.
PostingListIdentifier id;
+
+ explicit FinalizeResult(libtextclassifier3::Status status_in,
+ PostingListIdentifier id_in)
+ : status(std::move(status_in)), id(std::move(id_in)) {}
};
- // Write all accumulated hits to storage.
+ // Write all accumulated data to storage.
//
// If accessor points to a posting list chain with multiple posting lists in
// the chain and unable to write the last posting list in the chain, Finalize
// will return the error and also populate id with the id of the
// second-to-last posting list.
- static FinalizeResult Finalize(PostingListAccessor accessor);
+ FinalizeResult Finalize() &&;
- private:
+ virtual PostingListUsedSerializer* GetSerializer() = 0;
+
+ protected:
explicit PostingListAccessor(
- FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer,
+ FlashIndexStorage* storage,
std::unique_ptr<uint8_t[]> posting_list_buffer_array,
PostingListUsed posting_list_buffer)
: storage_(storage),
- serializer_(serializer),
prev_block_identifier_(PostingListIdentifier::kInvalid),
posting_list_buffer_array_(std::move(posting_list_buffer_array)),
posting_list_buffer_(std::move(posting_list_buffer)),
@@ -141,23 +94,21 @@ class PostingListAccessor {
FlashIndexStorage* storage_; // Does not own.
- PostingListUsedHitSerializer* serializer_; // Does not own.
-
// The PostingListIdentifier of the first max-sized posting list in the
// posting list chain or PostingListIdentifier::kInvalid if there is no
// posting list chain.
PostingListIdentifier prev_block_identifier_;
// An editor to an existing posting list on disk. If available (non-NULL),
- // we'll try to add all hits to this posting list. Once this posting list
+ // we'll try to add all data to this posting list. Once this posting list
// fills up, we'll either 1) chain it (if a max-sized posting list) and put
- // future hits in posting_list_buffer_ or 2) copy all of its hits into
+ // future data in posting_list_buffer_ or 2) copy all of its data into
// posting_list_buffer_ and free this pl (if not a max-sized posting list).
// TODO(tjbarron) provide a benchmark to demonstrate the effects that re-using
// existing posting lists has on latency.
std::unique_ptr<PostingListHolder> preexisting_posting_list_;
- // In-memory posting list used to buffer hits before writing them to the
+ // In-memory posting list used to buffer data before writing them to the
// smallest on-disk posting list that will fit them.
// posting_list_buffer_array_ owns the memory region that posting_list_buffer_
// interprets. Therefore, posting_list_buffer_array_ must have the same
@@ -171,4 +122,4 @@ class PostingListAccessor {
} // namespace lib
} // namespace icing
-#endif // ICING_INDEX_POSTING_LIST_ACCESSOR_H_
+#endif // ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_
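With PostingListAccessor now abstract, a concrete accessor only has to supply its serializer; the real subclasses in this change are PostingListHitAccessor and PostingListIntegerIndexDataAccessor. A hedged sketch of the minimal subclass shape (MyAccessor and its constructor wiring are hypothetical) and of the new rvalue-qualified Finalize():

    // Hypothetical subclass for illustration only.
    class MyAccessor : public PostingListAccessor {
     public:
      MyAccessor(FlashIndexStorage* storage,
                 std::unique_ptr<uint8_t[]> posting_list_buffer_array,
                 PostingListUsed posting_list_buffer,
                 PostingListUsedSerializer* serializer)
          : PostingListAccessor(storage, std::move(posting_list_buffer_array),
                                std::move(posting_list_buffer)),
            serializer_(serializer) {}

      PostingListUsedSerializer* GetSerializer() override { return serializer_; }

     private:
      PostingListUsedSerializer* serializer_;  // Does not own.
    };

    // Finalize() now consumes the accessor rather than taking it by value:
    PostingListAccessor::FinalizeResult result = std::move(accessor).Finalize();
    // On OK, result.id identifies the finalized chain; on error it may still
    // identify previously written blocks.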
diff --git a/icing/file/posting_list/posting-list-common.h b/icing/file/posting_list/posting-list-common.h
index cbe2ddf..44c6dd2 100644
--- a/icing/file/posting_list/posting-list-common.h
+++ b/icing/file/posting_list/posting-list-common.h
@@ -25,8 +25,6 @@ namespace lib {
using PostingListIndex = int32_t;
inline constexpr PostingListIndex kInvalidPostingListIndex = ~0U;
-inline constexpr uint32_t kNumSpecialData = 2;
-
inline constexpr uint32_t kInvalidBlockIndex = 0;
} // namespace lib
diff --git a/icing/file/posting_list/posting-list-used.h b/icing/file/posting_list/posting-list-used.h
index ec4b067..5821880 100644
--- a/icing/file/posting_list/posting-list-used.h
+++ b/icing/file/posting_list/posting-list-used.h
@@ -45,6 +45,28 @@ class PostingListUsed;
// posting list.
class PostingListUsedSerializer {
public:
+ // Special data is either a DataType instance or data_start_offset.
+ template <typename DataType>
+ union SpecialData {
+ explicit SpecialData(const DataType& data) : data_(data) {}
+
+ explicit SpecialData(uint32_t data_start_offset)
+ : data_start_offset_(data_start_offset) {}
+
+ const DataType& data() const { return data_; }
+
+ uint32_t data_start_offset() const { return data_start_offset_; }
+ void set_data_start_offset(uint32_t data_start_offset) {
+ data_start_offset_ = data_start_offset;
+ }
+
+ private:
+ DataType data_;
+ uint32_t data_start_offset_;
+ } __attribute__((packed));
+
+ static constexpr uint32_t kNumSpecialData = 2;
+
virtual ~PostingListUsedSerializer() = default;
// Returns byte size of the data type.
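The new SpecialData union makes the dual role of a posting list's first kNumSpecialData slots explicit: each slot holds either an inline payload or the offset at which ordinary data begins. A small sketch, assuming a hypothetical trivially copyable payload type:

    struct FakeData {  // hypothetical payload; stands in for Hit, etc.
      int64_t value;
    };
    using SD = PostingListUsedSerializer::SpecialData<FakeData>;

    SD as_payload(FakeData{42});             // slot interpreted as a payload
    SD as_offset(/*data_start_offset=*/16);  // slot interpreted as an offset
    as_offset.set_data_start_offset(24);
    // Packed union: exactly as large as its bigger member.
    static_assert(sizeof(SD) == sizeof(FakeData), "no padding expected");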
diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc
index 4c4bf65..60e347e 100644
--- a/icing/icing-search-engine.cc
+++ b/icing/icing-search-engine.cc
@@ -35,6 +35,8 @@
#include "icing/index/index-processor.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/join/join-processor.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/endian.h"
#include "icing/proto/debug.pb.h"
@@ -477,6 +479,7 @@ void IcingSearchEngine::ResetMembers() {
language_segmenter_.reset();
normalizer_.reset();
index_.reset();
+ integer_index_.reset();
}
libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile(
@@ -598,6 +601,11 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers(
std::string marker_filepath =
MakeSetSchemaMarkerFilePath(options_.base_dir());
+
+ // TODO(b/249829533): switch to use persistent numeric index after
+ // implementing and initialize numeric index.
+ integer_index_ = std::make_unique<DummyNumericIndex<int64_t>>();
+
libtextclassifier3::Status index_init_status;
if (absl_ports::IsNotFound(schema_store_->GetSchema().status())) {
// The schema was either lost or never set before. Wipe out the doc store
@@ -864,6 +872,12 @@ SetSchemaResultProto IcingSearchEngine::SetSchema(
return result_proto;
}
+ status = integer_index_->Reset();
+ if (!status.ok()) {
+ TransformStatus(status, result_status);
+ return result_proto;
+ }
+
IndexRestorationResult restore_result = RestoreIndexIfNeeded();
// DATA_LOSS means that we have successfully re-added content to the
// index. Some indexed content was lost, but otherwise the index is in a
@@ -963,17 +977,17 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) {
TokenizedDocument tokenized_document(
std::move(tokenized_document_or).ValueOrDie());
- auto document_id_or =
- document_store_->Put(tokenized_document.document(),
- tokenized_document.num_tokens(), put_document_stats);
+ auto document_id_or = document_store_->Put(
+ tokenized_document.document(), tokenized_document.num_string_tokens(),
+ put_document_stats);
if (!document_id_or.ok()) {
TransformStatus(document_id_or.status(), result_status);
return result_proto;
}
DocumentId document_id = document_id_or.ValueOrDie();
- auto index_processor_or =
- IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get());
+ auto index_processor_or = IndexProcessor::Create(
+ normalizer_.get(), index_.get(), integer_index_.get(), clock_.get());
if (!index_processor_or.ok()) {
TransformStatus(index_processor_or.status(), result_status);
return result_proto;
@@ -1243,8 +1257,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery(
std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
// Gets unordered results from query processor
auto query_processor_or = QueryProcessor::Create(
- index_.get(), language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get());
+ index_.get(), integer_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(), schema_store_.get());
if (!query_processor_or.ok()) {
TransformStatus(query_processor_or.status(), result_status);
delete_stats->set_parse_query_latency_ms(
@@ -1419,6 +1433,8 @@ OptimizeResultProto IcingSearchEngine::Optimize() {
optimize_stats->set_index_restoration_mode(
OptimizeStatsProto::FULL_INDEX_REBUILD);
ICING_LOG(WARNING) << "Resetting the entire index!";
+
+ // Reset string index
libtextclassifier3::Status index_reset_status = index_->Reset();
if (!index_reset_status.ok()) {
status = absl_ports::Annotate(
@@ -1430,6 +1446,18 @@ OptimizeResultProto IcingSearchEngine::Optimize() {
return result_proto;
}
+ // Reset integer index
+ index_reset_status = integer_index_->Reset();
+ if (!index_reset_status.ok()) {
+ status = absl_ports::Annotate(
+ absl_ports::InternalError("Failed to reset integer index."),
+ index_reset_status.error_message());
+ TransformStatus(status, result_status);
+ optimize_stats->set_index_restoration_latency_ms(
+ optimize_index_timer->GetElapsedMilliseconds());
+ return result_proto;
+ }
+
IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded();
// DATA_LOSS means that we have successfully re-added content to the index.
// Some indexed content was lost, but otherwise the index is in a valid
@@ -1539,6 +1567,8 @@ GetOptimizeInfoResultProto IcingSearchEngine::GetOptimizeInfo() {
}
int64_t index_elements_size = index_elements_size_or.ValueOrDie();
+ // TODO(b/259744228): add stats for integer index
+
// Sum up the optimizable sizes from DocumentStore and Index
result_proto.set_estimated_optimizable_bytes(
index_elements_size * doc_store_optimize_info.optimizable_docs /
@@ -1568,6 +1598,7 @@ StorageInfoResultProto IcingSearchEngine::GetStorageInfo() {
schema_store_->GetStorageInfo();
*result.mutable_storage_info()->mutable_index_storage_info() =
index_->GetStorageInfo();
+ // TODO(b/259744228): add stats for integer index
result.mutable_status()->set_code(StatusProto::OK);
return result;
}
@@ -1588,6 +1619,8 @@ DebugInfoResultProto IcingSearchEngine::GetDebugInfo(
*debug_info.mutable_debug_info()->mutable_index_info() =
index_->GetDebugInfo(verbosity);
+ // TODO(b/259744228): add debug info for integer index
+
// Document Store
libtextclassifier3::StatusOr<DocumentDebugInfoProto> document_debug_info =
document_store_->GetDebugInfo(verbosity);
@@ -1620,6 +1653,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk(
ICING_RETURN_IF_ERROR(schema_store_->PersistToDisk());
ICING_RETURN_IF_ERROR(document_store_->PersistToDisk(PersistType::FULL));
ICING_RETURN_IF_ERROR(index_->PersistToDisk());
+ ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk());
return libtextclassifier3::Status::OK;
}
@@ -1664,77 +1698,82 @@ SearchResultProto IcingSearchEngine::Search(
query_stats->set_is_first_page(true);
query_stats->set_requested_page_size(result_spec.num_per_page());
- std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
- // Gets unordered results from query processor
- auto query_processor_or = QueryProcessor::Create(
- index_.get(), language_segmenter_.get(), normalizer_.get(),
- document_store_.get(), schema_store_.get());
- if (!query_processor_or.ok()) {
- TransformStatus(query_processor_or.status(), result_status);
- query_stats->set_parse_query_latency_ms(
- component_timer->GetElapsedMilliseconds());
- return result_proto;
- }
- std::unique_ptr<QueryProcessor> query_processor =
- std::move(query_processor_or).ValueOrDie();
-
- auto query_results_or =
- query_processor->ParseSearch(search_spec, scoring_spec.rank_by());
- if (!query_results_or.ok()) {
- TransformStatus(query_results_or.status(), result_status);
- query_stats->set_parse_query_latency_ms(
- component_timer->GetElapsedMilliseconds());
- return result_proto;
- }
- QueryResults query_results = std::move(query_results_or).ValueOrDie();
- query_stats->set_parse_query_latency_ms(
- component_timer->GetElapsedMilliseconds());
-
+ // Process query and score
+ QueryScoringResults query_scoring_results =
+ ProcessQueryAndScore(search_spec, scoring_spec, result_spec);
int term_count = 0;
- for (const auto& section_and_terms : query_results.query_terms) {
+ for (const auto& section_and_terms : query_scoring_results.query_terms) {
term_count += section_and_terms.second.size();
}
query_stats->set_num_terms(term_count);
-
- component_timer = clock_->GetNewTimer();
- // Scores but does not rank the results.
- libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
- scoring_processor_or = ScoringProcessor::Create(
- scoring_spec, document_store_.get(), schema_store_.get());
- if (!scoring_processor_or.ok()) {
- TransformStatus(scoring_processor_or.status(), result_status);
- query_stats->set_scoring_latency_ms(
- component_timer->GetElapsedMilliseconds());
+ query_stats->set_parse_query_latency_ms(
+ query_scoring_results.parse_query_latency_ms);
+ query_stats->set_scoring_latency_ms(query_scoring_results.scoring_latency_ms);
+ if (!query_scoring_results.status.ok()) {
+ TransformStatus(query_scoring_results.status, result_status);
return result_proto;
}
- std::unique_ptr<ScoringProcessor> scoring_processor =
- std::move(scoring_processor_or).ValueOrDie();
- std::vector<ScoredDocumentHit> result_document_hits =
- scoring_processor->Score(std::move(query_results.root_iterator),
- performance_configuration_.num_to_score,
- &query_results.query_term_iterators);
- query_stats->set_scoring_latency_ms(
- component_timer->GetElapsedMilliseconds());
- query_stats->set_num_documents_scored(result_document_hits.size());
+ query_stats->set_num_documents_scored(
+ query_scoring_results.scored_document_hits.size());
// Returns early for empty result
- if (result_document_hits.empty()) {
+ if (query_scoring_results.scored_document_hits.empty()) {
result_status->set_code(StatusProto::OK);
return result_proto;
}
- component_timer = clock_->GetNewTimer();
- // Ranks results
- std::unique_ptr<ScoredDocumentHitsRanker> ranker =
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
- std::move(result_document_hits),
- /*is_descending=*/scoring_spec.order_by() ==
- ScoringSpecProto::Order::DESC);
- query_stats->set_ranking_latency_ms(
- component_timer->GetElapsedMilliseconds());
+ std::unique_ptr<ScoredDocumentHitsRanker> ranker;
+ if (search_spec.has_join_spec()) {
+ // Process 2nd query
+ QueryScoringResults nested_query_scoring_results = ProcessQueryAndScore(
+ search_spec.join_spec().nested_spec().search_spec(),
+ search_spec.join_spec().nested_spec().scoring_spec(),
+ search_spec.join_spec().nested_spec().result_spec());
+      // TODO(b/256022027): set different kinds of latency for the 2nd query.
+ if (!nested_query_scoring_results.status.ok()) {
+ TransformStatus(nested_query_scoring_results.status, result_status);
+ return result_proto;
+ }
- component_timer = clock_->GetNewTimer();
- // RanksAndPaginates and retrieves the document protos and snippets if
+    // Join the two sets of scored document hits
+ JoinProcessor join_processor(document_store_.get());
+ libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>>
+ joined_result_document_hits_or = join_processor.Join(
+ search_spec.join_spec(),
+ std::move(query_scoring_results.scored_document_hits),
+ std::move(nested_query_scoring_results.scored_document_hits));
+ if (!joined_result_document_hits_or.ok()) {
+ TransformStatus(joined_result_document_hits_or.status(), result_status);
+ return result_proto;
+ }
+ std::vector<JoinedScoredDocumentHit> joined_result_document_hits =
+ std::move(joined_result_document_hits_or).ValueOrDie();
+ // TODO(b/256022027): set join latency
+
+ std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
+ // Ranks results
+ ranker = std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<JoinedScoredDocumentHit>>(
+ std::move(joined_result_document_hits),
+ /*is_descending=*/scoring_spec.order_by() ==
+ ScoringSpecProto::Order::DESC);
+ query_stats->set_ranking_latency_ms(
+ component_timer->GetElapsedMilliseconds());
+ } else {
+ // Non-join query
+ std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
+ // Ranks results
+ ranker = std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
+ std::move(query_scoring_results.scored_document_hits),
+ /*is_descending=*/scoring_spec.order_by() ==
+ ScoringSpecProto::Order::DESC);
+ query_stats->set_ranking_latency_ms(
+ component_timer->GetElapsedMilliseconds());
+ }
+
+ std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
+  // Caches the ranked results and retrieves the first page of document protos and snippets if
// requested
auto result_retriever_or =
ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(),
@@ -1750,8 +1789,9 @@ SearchResultProto IcingSearchEngine::Search(
libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>>
page_result_info_or = result_state_manager_->CacheAndRetrieveFirstPage(
- std::move(ranker), std::move(query_results.query_terms), search_spec,
- scoring_spec, result_spec, *document_store_, *result_retriever);
+ std::move(ranker), std::move(query_scoring_results.query_terms),
+ search_spec, scoring_spec, result_spec, *document_store_,
+ *result_retriever);
if (!page_result_info_or.ok()) {
TransformStatus(page_result_info_or.status(), result_status);
query_stats->set_document_retrieval_latency_ms(
@@ -1783,6 +1823,63 @@ SearchResultProto IcingSearchEngine::Search(
return result_proto;
}
+IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore(
+ const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec,
+ const ResultSpecProto& result_spec) {
+ std::unique_ptr<Timer> component_timer = clock_->GetNewTimer();
+
+ // Gets unordered results from query processor
+ auto query_processor_or = QueryProcessor::Create(
+ index_.get(), integer_index_.get(), language_segmenter_.get(),
+ normalizer_.get(), document_store_.get(), schema_store_.get());
+ if (!query_processor_or.ok()) {
+ return QueryScoringResults(
+ std::move(query_processor_or).status(), /*query_terms_in=*/{},
+ /*scored_document_hits_in=*/{},
+ /*parse_query_latency_ms_in=*/component_timer->GetElapsedMilliseconds(),
+ /*scoring_latency_ms_in=*/0);
+ }
+ std::unique_ptr<QueryProcessor> query_processor =
+ std::move(query_processor_or).ValueOrDie();
+
+ auto query_results_or =
+ query_processor->ParseSearch(search_spec, scoring_spec.rank_by());
+ if (!query_results_or.ok()) {
+ return QueryScoringResults(
+ std::move(query_results_or).status(), /*query_terms_in=*/{},
+ /*scored_document_hits_in=*/{},
+ /*parse_query_latency_ms_in=*/component_timer->GetElapsedMilliseconds(),
+ /*scoring_latency_ms_in=*/0);
+ }
+ QueryResults query_results = std::move(query_results_or).ValueOrDie();
+ int64_t parse_query_latency_ms = component_timer->GetElapsedMilliseconds();
+
+ component_timer = clock_->GetNewTimer();
+ // Scores but does not rank the results.
+ libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>>
+ scoring_processor_or = ScoringProcessor::Create(
+ scoring_spec, document_store_.get(), schema_store_.get());
+ if (!scoring_processor_or.ok()) {
+ return QueryScoringResults(std::move(scoring_processor_or).status(),
+ std::move(query_results.query_terms),
+ /*scored_document_hits_in=*/{},
+ parse_query_latency_ms,
+ /*scoring_latency_ms_in=*/0);
+ }
+ std::unique_ptr<ScoringProcessor> scoring_processor =
+ std::move(scoring_processor_or).ValueOrDie();
+ std::vector<ScoredDocumentHit> scored_document_hits =
+ scoring_processor->Score(std::move(query_results.root_iterator),
+ performance_configuration_.num_to_score,
+ &query_results.query_term_iterators);
+ int64_t scoring_latency_ms = component_timer->GetElapsedMilliseconds();
+
+ return QueryScoringResults(libtextclassifier3::Status::OK,
+ std::move(query_results.query_terms),
+ std::move(scored_document_hits),
+ parse_query_latency_ms, scoring_latency_ms);
+}
+
SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) {
SearchResultProto result_proto;
StatusProto* result_status = result_proto.mutable_status();
@@ -2006,8 +2103,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() {
return {libtextclassifier3::Status::OK, false};
}
- auto index_processor_or =
- IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get());
+ auto index_processor_or = IndexProcessor::Create(
+ normalizer_.get(), index_.get(), integer_index_.get(), clock_.get());
if (!index_processor_or.ok()) {
return {index_processor_or.status(), true};
}
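The Search() refactor factors query processing and scoring into ProcessQueryAndScore so the join path can run it twice, once per side of the join. A condensed sketch of the resulting control flow (error handling and stats are elided, and /*is_descending=*/ is shown as true for brevity; the real code derives it from scoring_spec.order_by()):

    QueryScoringResults parent_results =
        ProcessQueryAndScore(search_spec, scoring_spec, result_spec);
    std::unique_ptr<ScoredDocumentHitsRanker> ranker;
    if (search_spec.has_join_spec()) {
      // Run the nested child query, then join child hits onto parents.
      QueryScoringResults child_results = ProcessQueryAndScore(
          search_spec.join_spec().nested_spec().search_spec(),
          search_spec.join_spec().nested_spec().scoring_spec(),
          search_spec.join_spec().nested_spec().result_spec());
      JoinProcessor join_processor(document_store_.get());
      auto joined_or = join_processor.Join(
          search_spec.join_spec(),
          std::move(parent_results.scored_document_hits),
          std::move(child_results.scored_document_hits));
      ranker = std::make_unique<
          PriorityQueueScoredDocumentHitsRanker<JoinedScoredDocumentHit>>(
          std::move(joined_or).ValueOrDie(), /*is_descending=*/true);
    } else {
      ranker = std::make_unique<
          PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
          std::move(parent_results.scored_document_hits),
          /*is_descending=*/true);
    }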
diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h
index 4b0576f..221d86c 100644
--- a/icing/icing-search-engine.h
+++ b/icing/icing-search-engine.h
@@ -19,6 +19,7 @@
#include <memory>
#include <string>
#include <string_view>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
@@ -26,6 +27,7 @@
#include "icing/absl_ports/thread_annotations.h"
#include "icing/file/filesystem.h"
#include "icing/index/index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/jni/jni-cache.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/performance-configuration.h"
@@ -41,8 +43,10 @@
#include "icing/proto/search.pb.h"
#include "icing/proto/storage.pb.h"
#include "icing/proto/usage.pb.h"
+#include "icing/query/query-terms.h"
#include "icing/result/result-state-manager.h"
#include "icing/schema/schema-store.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/store/document-store.h"
#include "icing/tokenization/language-segmenter.h"
#include "icing/transform/normalizer.h"
@@ -464,9 +468,14 @@ class IcingSearchEngine {
std::unique_ptr<const Normalizer> normalizer_ ICING_GUARDED_BY(mutex_);
- // Storage for all hits of content from the document store.
+ // Storage for all hits of string contents from the document store.
std::unique_ptr<Index> index_ ICING_GUARDED_BY(mutex_);
+ // Storage for all hits of numeric contents from the document store.
+ // TODO(b/249829533): integrate more functions with integer_index_.
+ std::unique_ptr<NumericIndex<int64_t>> integer_index_
+ ICING_GUARDED_BY(mutex_);
+
// Pointer to JNI class references
const std::unique_ptr<const JniCache> jni_cache_;
@@ -552,6 +561,37 @@ class IcingSearchEngine {
InitializeStatsProto* initialize_stats)
ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+  // Processes the query and scores results according to the specs. It is a
+  // helper function (called by Search) to process and score both the main
+  // query and the nested child query for join search.
+ //
+  // Returns a QueryScoringResults whose status is:
+  //   OK on success, together with a vector of ScoredDocumentHits, a
+  //   SectionRestrictQueryTermsMap, and latency fields for logging.
+  //   Any other error encountered while processing the query or scoring.
+ struct QueryScoringResults {
+ libtextclassifier3::Status status;
+ SectionRestrictQueryTermsMap query_terms;
+ std::vector<ScoredDocumentHit> scored_document_hits;
+ int64_t parse_query_latency_ms;
+ int64_t scoring_latency_ms;
+
+ explicit QueryScoringResults(
+ libtextclassifier3::Status status_in,
+ SectionRestrictQueryTermsMap&& query_terms_in,
+ std::vector<ScoredDocumentHit>&& scored_document_hits_in,
+ int64_t parse_query_latency_ms_in, int64_t scoring_latency_ms_in)
+ : status(std::move(status_in)),
+ query_terms(std::move(query_terms_in)),
+ scored_document_hits(std::move(scored_document_hits_in)),
+ parse_query_latency_ms(parse_query_latency_ms_in),
+ scoring_latency_ms(scoring_latency_ms_in) {}
+ };
+ QueryScoringResults ProcessQueryAndScore(const SearchSpecProto& search_spec,
+ const ScoringSpecProto& scoring_spec,
+ const ResultSpecProto& result_spec)
+ ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
// Many of the internal components rely on other components' derived data.
// Check that everything is consistent with each other so that we're not
// using outdated derived data in some parts of our system.
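Because QueryScoringResults carries its own per-phase latencies, Search() forwards them into the stats proto instead of timing around the call; this mirrors the consumption shown in the .cc diff above:

    QueryScoringResults results =
        ProcessQueryAndScore(search_spec, scoring_spec, result_spec);
    query_stats->set_parse_query_latency_ms(results.parse_query_latency_ms);
    query_stats->set_scoring_latency_ms(results.scoring_latency_ms);
    if (!results.status.ok()) {
      TransformStatus(results.status, result_status);
      return result_proto;
    }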
diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc
index 7a60101..8cb7e7f 100644
--- a/icing/icing-search-engine_test.cc
+++ b/icing/icing-search-engine_test.cc
@@ -55,9 +55,9 @@
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
#include "icing/testing/random-string.h"
-#include "icing/testing/snippet-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
+#include "icing/util/snippet-helpers.h"
namespace icing {
namespace lib {
@@ -9895,6 +9895,423 @@ TEST_F(IcingSearchEngineTest, IcingShouldWorkFor64Sections) {
EqualsSearchResultIgnoreStatsAndScores(expected_no_documents));
}
+TEST_F(IcingSearchEngineTest, SimpleJoin) {
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Person")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("firstName")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("lastName")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("emailAddress")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty(
+ PropertyConfigBuilder()
+ .SetName("subjectId")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ DocumentProto person1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("firstName", "first1")
+ .AddStringProperty("lastName", "last1")
+ .AddStringProperty("emailAddress", "email1@gmail.com")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ DocumentProto person2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person2")
+ .SetSchema("Person")
+ .AddStringProperty("firstName", "first2")
+ .AddStringProperty("lastName", "last2")
+ .AddStringProperty("emailAddress", "email2@gmail.com")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ DocumentProto person3 =
+ DocumentBuilder()
+ .SetKey("pkg$db/name#space\\\\", "person3")
+ .SetSchema("Person")
+ .AddStringProperty("firstName", "first3")
+ .AddStringProperty("lastName", "last3")
+ .AddStringProperty("emailAddress", "email3@gmail.com")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subjectId", "pkg$db/namespace#person1")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subjectId", "pkg$db/namespace#person2")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ DocumentProto email3 =
+ DocumentBuilder()
+ .SetKey("namespace", "email3")
+ .SetSchema("Email")
+ .AddStringProperty("subjectId", "pkg$db/name\\#space\\\\#person3")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(person3).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(email3).status(), ProtoIsOk());
+
+ // Parent SearchSpec
+ SearchSpecProto search_spec;
+ search_spec.set_term_match_type(TermMatchType::PREFIX);
+ search_spec.set_query("first");
+
+ // JoinSpec
+ JoinSpecProto* join_spec = search_spec.mutable_join_spec();
+  // Set max_joined_child_count to 2, so only email3 and email2 will be
+  // included in the nested results and email1 will be truncated.
+ join_spec->set_max_joined_child_count(2);
+ join_spec->set_parent_property_expression("this.fullyQualifiedId()");
+ join_spec->set_child_property_expression("subjectId");
+ JoinSpecProto::NestedSpecProto* nested_spec =
+ join_spec->mutable_nested_spec();
+ SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec();
+ nested_search_spec->set_term_match_type(TermMatchType::PREFIX);
+ nested_search_spec->set_query("");
+ *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec();
+ *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance();
+
+ // Parent ScoringSpec
+ ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
+
+ // Parent ResultSpec
+ ResultSpecProto result_spec;
+ result_spec.set_num_per_page(1);
+
+ SearchResultProto expected_result1;
+ expected_result1.mutable_status()->set_code(StatusProto::OK);
+ SearchResultProto::ResultProto* result_proto1 =
+ expected_result1.mutable_results()->Add();
+ *result_proto1->mutable_document() = person3;
+ *result_proto1->mutable_joined_results()->Add()->mutable_document() = email3;
+
+ SearchResultProto expected_result2;
+ expected_result2.mutable_status()->set_code(StatusProto::OK);
+ SearchResultProto::ResultProto* result_proto2 =
+ expected_result2.mutable_results()->Add();
+ *result_proto2->mutable_document() = person2;
+ *result_proto2->mutable_joined_results()->Add()->mutable_document() = email2;
+
+ SearchResultProto expected_result3;
+ expected_result3.mutable_status()->set_code(StatusProto::OK);
+ SearchResultProto::ResultProto* result_proto3 =
+ expected_result3.mutable_results()->Add();
+ *result_proto3->mutable_document() = person1;
+ *result_proto3->mutable_joined_results()->Add()->mutable_document() = email1;
+
+ SearchResultProto result1 =
+ icing.Search(search_spec, scoring_spec, result_spec);
+ uint64_t next_page_token = result1.next_page_token();
+ EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken));
+ expected_result1.set_next_page_token(next_page_token);
+ EXPECT_THAT(result1,
+ EqualsSearchResultIgnoreStatsAndScores(expected_result1));
+
+ SearchResultProto result2 = icing.GetNextPage(next_page_token);
+ next_page_token = result2.next_page_token();
+ EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken));
+ expected_result2.set_next_page_token(next_page_token);
+ EXPECT_THAT(result2,
+ EqualsSearchResultIgnoreStatsAndScores(expected_result2));
+
+ SearchResultProto result3 = icing.GetNextPage(next_page_token);
+ next_page_token = result3.next_page_token();
+ EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken));
+ EXPECT_THAT(result3,
+ EqualsSearchResultIgnoreStatsAndScores(expected_result3));
+}
+
+TEST_F(IcingSearchEngineTest, InvalidJoins) {
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("Person")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("firstName")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("lastName")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("emailAddress")
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty(
+ PropertyConfigBuilder()
+ .SetName("subjectId")
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ DocumentProto person1 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace", "person1")
+ .SetSchema("Person")
+ .AddStringProperty("firstName", "first1")
+ .AddStringProperty("lastName", "last1")
+ .AddStringProperty("emailAddress", "email1@gmail.com")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ DocumentProto person2 =
+ DocumentBuilder()
+ .SetKey("pkg$db/namespace\\", "person2")
+ .SetSchema("Person")
+ .AddStringProperty("firstName", "first2")
+ .AddStringProperty("lastName", "last2")
+ .AddStringProperty("emailAddress", "email2@gmail.com")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+
+ // "invalid format" does not refer to any document, so it will not be joined
+ // to any document.
+ DocumentProto email1 =
+ DocumentBuilder()
+ .SetKey("namespace", "email1")
+ .SetSchema("Email")
+ .AddStringProperty("subjectId", "invalid format")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+ // This will not be joined because the # in the subjectId is escaped.
+ DocumentProto email2 =
+ DocumentBuilder()
+ .SetKey("namespace", "email2")
+ .SetSchema("Email")
+ .AddStringProperty("subjectId", "pkg$db/namespace\\#person2")
+ .SetCreationTimestampMs(kDefaultCreationTimestampMs)
+ .Build();
+
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk());
+ ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk());
+
+ // Parent SearchSpec
+ SearchSpecProto search_spec;
+ search_spec.set_term_match_type(TermMatchType::PREFIX);
+ search_spec.set_query("first");
+
+ // JoinSpec
+ JoinSpecProto* join_spec = search_spec.mutable_join_spec();
+  // Set max_joined_child_count to 2. It has no effect here: neither email1
+  // nor email2 has a subjectId that resolves to a parent document, so no
+  // children will be joined.
+ join_spec->set_max_joined_child_count(2);
+ join_spec->set_parent_property_expression("this.fullyQualifiedId()");
+ join_spec->set_child_property_expression("subjectId");
+ JoinSpecProto::NestedSpecProto* nested_spec =
+ join_spec->mutable_nested_spec();
+ SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec();
+ nested_search_spec->set_term_match_type(TermMatchType::PREFIX);
+ nested_search_spec->set_query("");
+ *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec();
+ *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance();
+
+ // Parent ScoringSpec
+ ScoringSpecProto scoring_spec = GetDefaultScoringSpec();
+
+ // Parent ResultSpec
+ ResultSpecProto result_spec;
+ result_spec.set_num_per_page(1);
+
+ SearchResultProto expected_result1;
+ expected_result1.mutable_status()->set_code(StatusProto::OK);
+ SearchResultProto::ResultProto* result_proto1 =
+ expected_result1.mutable_results()->Add();
+ *result_proto1->mutable_document() = person2;
+
+ SearchResultProto expected_result2;
+ expected_result2.mutable_status()->set_code(StatusProto::OK);
+ SearchResultProto::ResultProto* result_proto2 =
+ expected_result2.mutable_results()->Add();
+ *result_proto2->mutable_document() = person1;
+
+ SearchResultProto result1 =
+ icing.Search(search_spec, scoring_spec, result_spec);
+ uint64_t next_page_token = result1.next_page_token();
+ EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken));
+ expected_result1.set_next_page_token(next_page_token);
+ EXPECT_THAT(result1,
+ EqualsSearchResultIgnoreStatsAndScores(expected_result1));
+
+ SearchResultProto result2 = icing.GetNextPage(next_page_token);
+ next_page_token = result2.next_page_token();
+ EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken));
+ EXPECT_THAT(result2,
+ EqualsSearchResultIgnoreStatsAndScores(expected_result2));
+}
+
+TEST_F(IcingSearchEngineTest, NumericFilterAdvancedQuerySucceeds) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+
+ // Create the schema and document store
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("transaction")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("price")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("cost")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+
+ DocumentProto document_one = DocumentBuilder()
+ .SetKey("namespace", "1")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("price", 10)
+ .Build();
+ ASSERT_THAT(icing.Put(document_one).status(), ProtoIsOk());
+
+ DocumentProto document_two = DocumentBuilder()
+ .SetKey("namespace", "2")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("price", 25)
+ .Build();
+ ASSERT_THAT(icing.Put(document_two).status(), ProtoIsOk());
+
+ DocumentProto document_three = DocumentBuilder()
+ .SetKey("namespace", "3")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("cost", 2)
+ .Build();
+ ASSERT_THAT(icing.Put(document_three).status(), ProtoIsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_query("price < 20");
+ search_spec.set_search_type(
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY);
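+  // Numeric comparison queries like "price < 20" are only handled by the
+  // advanced query parser (see NumericFilterOldQueryFails below).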
+
+ SearchResultProto results =
+ icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ ASSERT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_one));
+
+ search_spec.set_query("price == 25");
+ results = icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ ASSERT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_two));
+
+ search_spec.set_query("cost > 2");
+ results = icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.results(), IsEmpty());
+
+ search_spec.set_query("cost >= 2");
+ results = icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ ASSERT_THAT(results.results(), SizeIs(1));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_three));
+
+ search_spec.set_query("price <= 25");
+ results = icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ ASSERT_THAT(results.results(), SizeIs(2));
+ EXPECT_THAT(results.results(0).document(), EqualsProto(document_two));
+ EXPECT_THAT(results.results(1).document(), EqualsProto(document_one));
+}
+
+TEST_F(IcingSearchEngineTest, NumericFilterOldQueryFails) {
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache());
+ ASSERT_THAT(icing.Initialize().status(), ProtoIsOk());
+
+ // Create the schema and document store
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("transaction")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("price")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("cost")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk());
+
+ DocumentProto document_one = DocumentBuilder()
+ .SetKey("namespace", "1")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("price", 10)
+ .Build();
+ ASSERT_THAT(icing.Put(document_one).status(), ProtoIsOk());
+
+ DocumentProto document_two = DocumentBuilder()
+ .SetKey("namespace", "2")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("price", 25)
+ .Build();
+ ASSERT_THAT(icing.Put(document_two).status(), ProtoIsOk());
+
+ DocumentProto document_three = DocumentBuilder()
+ .SetKey("namespace", "3")
+ .SetSchema("transaction")
+ .SetCreationTimestampMs(1)
+ .AddInt64Property("cost", 2)
+ .Build();
+ ASSERT_THAT(icing.Put(document_three).status(), ProtoIsOk());
+
+ SearchSpecProto search_spec;
+ search_spec.set_query("price < 20");
+ search_spec.set_search_type(SearchSpecProto::SearchType::ICING_RAW_QUERY);
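+  // The raw query parser has no numeric comparators, so this query should be
+  // rejected with INVALID_ARGUMENT instead of silently matching nothing.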
+
+ SearchResultProto results =
+ icing.Search(search_spec, ScoringSpecProto::default_instance(),
+ ResultSpecProto::default_instance());
+ EXPECT_THAT(results.status(), ProtoStatusIs(StatusProto::INVALID_ARGUMENT));
+}
+
} // namespace
} // namespace lib
} // namespace icing
diff --git a/icing/index/index-processor.cc b/icing/index/index-processor.cc
index cfeda31..9f21c9d 100644
--- a/icing/index/index-processor.cc
+++ b/icing/index/index-processor.cc
@@ -21,18 +21,12 @@
#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
-#include "icing/absl_ports/canonical_errors.h"
-#include "icing/absl_ports/str_cat.h"
#include "icing/index/index.h"
-#include "icing/legacy/core/icing-string-util.h"
+#include "icing/index/integer-section-indexing-handler.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/index/string-section-indexing-handler.h"
#include "icing/proto/logging.pb.h"
-#include "icing/proto/schema.pb.h"
-#include "icing/schema/section-manager.h"
-#include "icing/schema/section.h"
#include "icing/store/document-id.h"
-#include "icing/tokenization/token.h"
-#include "icing/tokenization/tokenizer-factory.h"
-#include "icing/tokenization/tokenizer.h"
#include "icing/transform/normalizer.h"
#include "icing/util/status-macros.h"
#include "icing/util/tokenized-document.h"
@@ -42,117 +36,33 @@ namespace lib {
libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>>
IndexProcessor::Create(const Normalizer* normalizer, Index* index,
+ NumericIndex<int64_t>* integer_index,
const Clock* clock) {
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(index);
+ ICING_RETURN_ERROR_IF_NULL(integer_index);
ICING_RETURN_ERROR_IF_NULL(clock);
+ std::vector<std::unique_ptr<SectionIndexingHandler>> handlers;
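+  // Handlers run in registration order: string sections are indexed first,
+  // then integer sections.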
+ handlers.push_back(
+ std::make_unique<StringSectionIndexingHandler>(clock, normalizer, index));
+ handlers.push_back(
+ std::make_unique<IntegerSectionIndexingHandler>(clock, integer_index));
+
return std::unique_ptr<IndexProcessor>(
- new IndexProcessor(normalizer, index, clock));
+ new IndexProcessor(std::move(handlers), clock));
}
libtextclassifier3::Status IndexProcessor::IndexDocument(
const TokenizedDocument& tokenized_document, DocumentId document_id,
PutDocumentStatsProto* put_document_stats) {
- std::unique_ptr<Timer> index_timer = clock_.GetNewTimer();
-
- if (index_->last_added_document_id() != kInvalidDocumentId &&
- document_id <= index_->last_added_document_id()) {
- return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
- "DocumentId %d must be greater than last added document_id %d",
- document_id, index_->last_added_document_id()));
- }
- index_->set_last_added_document_id(document_id);
- uint32_t num_tokens = 0;
- libtextclassifier3::Status status;
- for (const TokenizedSection& section : tokenized_document.sections()) {
- if (section.metadata.tokenizer ==
- StringIndexingConfig::TokenizerType::NONE) {
- ICING_LOG(WARNING)
- << "Unexpected TokenizerType::NONE found when indexing document.";
- }
- // TODO(b/152934343): pass real namespace ids in
- Index::Editor editor =
- index_->Edit(document_id, section.metadata.id,
- section.metadata.term_match_type, /*namespace_id=*/0);
- for (std::string_view token : section.token_sequence) {
- ++num_tokens;
-
- switch (section.metadata.tokenizer) {
- case StringIndexingConfig::TokenizerType::VERBATIM:
- // data() is safe to use here because a token created from the
- // VERBATIM tokenizer is the entire string value. The character at
- // data() + token.length() is guaranteed to be a null char.
- status = editor.BufferTerm(token.data());
- break;
- case StringIndexingConfig::TokenizerType::NONE:
- [[fallthrough]];
- case StringIndexingConfig::TokenizerType::RFC822:
- [[fallthrough]];
- case StringIndexingConfig::TokenizerType::URL:
- [[fallthrough]];
- case StringIndexingConfig::TokenizerType::PLAIN:
- std::string normalized_term = normalizer_.NormalizeTerm(token);
- status = editor.BufferTerm(normalized_term.c_str());
- }
-
- if (!status.ok()) {
- // We've encountered a failure. Bail out. We'll mark this doc as deleted
- // and signal a failure to the client.
- ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: "
- << status.error_message();
- break;
- }
- }
- if (!status.ok()) {
- break;
- }
- // Add all the seen terms to the index with their term frequency.
- status = editor.IndexAllBufferedTerms();
- if (!status.ok()) {
- ICING_LOG(WARNING) << "Failed to add hits in lite index due to: "
- << status.error_message();
- break;
- }
- }
-
- if (put_document_stats != nullptr) {
- put_document_stats->set_index_latency_ms(
- index_timer->GetElapsedMilliseconds());
- put_document_stats->mutable_tokenization_stats()->set_num_tokens_indexed(
- num_tokens);
- }
-
- // If we're either successful or we've hit resource exhausted, then attempt a
- // merge.
- if ((status.ok() || absl_ports::IsResourceExhausted(status)) &&
- index_->WantsMerge()) {
- ICING_LOG(ERROR) << "Merging the index at docid " << document_id << ".";
-
- std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer();
- libtextclassifier3::Status merge_status = index_->Merge();
-
- if (!merge_status.ok()) {
- ICING_LOG(ERROR) << "Index merging failed. Clearing index.";
- if (!index_->Reset().ok()) {
- return absl_ports::InternalError(IcingStringUtil::StringPrintf(
- "Unable to reset to clear index after merge failure. Merge "
- "failure=%d:%s",
- merge_status.error_code(), merge_status.error_message().c_str()));
- } else {
- return absl_ports::DataLossError(IcingStringUtil::StringPrintf(
- "Forced to reset index after merge failure. Merge failure=%d:%s",
- merge_status.error_code(), merge_status.error_message().c_str()));
- }
- }
-
- if (put_document_stats != nullptr) {
- put_document_stats->set_index_merge_latency_ms(
- merge_timer->GetElapsedMilliseconds());
- }
+ // TODO(b/259744228): set overall index latency.
+ for (auto& section_indexing_handler : section_indexing_handlers_) {
+ ICING_RETURN_IF_ERROR(section_indexing_handler->Handle(
+ tokenized_document, document_id, put_document_stats));
}
- return status;
+ return libtextclassifier3::Status::OK;
}
} // namespace lib
diff --git a/icing/index/index-processor.h b/icing/index/index-processor.h
index b7ffdb5..45954c4 100644
--- a/icing/index/index-processor.h
+++ b/icing/index/index-processor.h
@@ -16,14 +16,15 @@
#define ICING_INDEX_INDEX_PROCESSOR_H_
#include <cstdint>
-#include <string>
+#include <memory>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/index/index.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/index/section-indexing-handler.h"
#include "icing/proto/logging.pb.h"
-#include "icing/schema/section-manager.h"
#include "icing/store/document-id.h"
-#include "icing/tokenization/token.h"
#include "icing/transform/normalizer.h"
#include "icing/util/tokenized-document.h"
@@ -40,7 +41,8 @@ class IndexProcessor {
// An IndexProcessor on success
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> Create(
- const Normalizer* normalizer, Index* index, const Clock* clock);
+ const Normalizer* normalizer, Index* index,
+      NumericIndex<int64_t>* integer_index, const Clock* clock);
// Add tokenized document to the index, associated with document_id. If the
// number of tokens in the document exceeds max_tokens_per_document, then only
@@ -54,23 +56,21 @@ class IndexProcessor {
// populated.
//
// Returns:
- // INVALID_ARGUMENT if document_id is less than the document_id of a
- // previously indexed document or tokenization fails.
- // RESOURCE_EXHAUSTED if the index is full and can't add anymore content.
- // DATA_LOSS if an attempt to merge the index fails and both indices are
- // cleared as a result.
- // NOT_FOUND if there is no definition for the document's schema type.
- // INTERNAL_ERROR if any other errors occur
+ // - OK on success.
+ // - Any SectionIndexingHandler errors.
libtextclassifier3::Status IndexDocument(
const TokenizedDocument& tokenized_document, DocumentId document_id,
PutDocumentStatsProto* put_document_stats = nullptr);
private:
- IndexProcessor(const Normalizer* normalizer, Index* index, const Clock* clock)
- : normalizer_(*normalizer), index_(index), clock_(*clock) {}
+ explicit IndexProcessor(std::vector<std::unique_ptr<SectionIndexingHandler>>&&
+ section_indexing_handlers,
+ const Clock* clock)
+ : section_indexing_handlers_(std::move(section_indexing_handlers)),
+ clock_(*clock) {}
- const Normalizer& normalizer_;
- Index* const index_;
+ std::vector<std::unique_ptr<SectionIndexingHandler>>
+ section_indexing_handlers_;
const Clock& clock_;
};
diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc
index 68c592c..6123f47 100644
--- a/icing/index/index-processor_benchmark.cc
+++ b/icing/index/index-processor_benchmark.cc
@@ -18,6 +18,8 @@
#include "icing/file/filesystem.h"
#include "icing/index/index-processor.h"
#include "icing/index/index.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/schema-util.h"
@@ -55,7 +57,8 @@
// $ adb push blaze-bin/icing/index/index-processor_benchmark
// /data/local/tmp/
//
-// $ adb shell /data/local/tmp/index-processor_benchmark --benchmark_filter=all
+// $ adb shell /data/local/tmp/index-processor_benchmark
+// --benchmark_filter=all
// --adb
// Flag to tell the benchmark that it'll be run on an Android device via adb,
@@ -183,6 +186,8 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ std::unique_ptr<NumericIndex<int64_t>> integer_index =
+ std::make_unique<DummyNumericIndex<int64_t>>();
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -191,7 +196,8 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) {
std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<IndexProcessor> index_processor,
- IndexProcessor::Create(normalizer.get(), index.get(), &clock));
+ IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(),
+ &clock));
DocumentProto input_document = CreateDocumentWithOneProperty(state.range(0));
TokenizedDocument tokenized_document(std::move(
TokenizedDocument::Create(schema_store.get(), language_segmenter.get(),
@@ -237,6 +243,8 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ std::unique_ptr<NumericIndex<int64_t>> integer_index =
+ std::make_unique<DummyNumericIndex<int64_t>>();
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -245,7 +253,8 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) {
std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<IndexProcessor> index_processor,
- IndexProcessor::Create(normalizer.get(), index.get(), &clock));
+ IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(),
+ &clock));
DocumentProto input_document =
CreateDocumentWithTenProperties(state.range(0));
@@ -293,6 +302,8 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ std::unique_ptr<NumericIndex<int64_t>> integer_index =
+ std::make_unique<DummyNumericIndex<int64_t>>();
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -301,7 +312,8 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) {
std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<IndexProcessor> index_processor,
- IndexProcessor::Create(normalizer.get(), index.get(), &clock));
+ IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(),
+ &clock));
DocumentProto input_document =
CreateDocumentWithDiacriticLetters(state.range(0));
@@ -349,6 +361,8 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ std::unique_ptr<NumericIndex<int64_t>> integer_index =
+ std::make_unique<DummyNumericIndex<int64_t>>();
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -357,7 +371,8 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) {
std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock);
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<IndexProcessor> index_processor,
- IndexProcessor::Create(normalizer.get(), index.get(), &clock));
+ IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(),
+ &clock));
DocumentProto input_document = CreateDocumentWithHiragana(state.range(0));
TokenizedDocument tokenized_document(std::move(
diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc
index 3c848d3..b83d33c 100644
--- a/icing/index/index-processor_test.cc
+++ b/icing/index/index-processor_test.cc
@@ -34,6 +34,8 @@
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/index/term-property-id.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/legacy/index/icing-mock-filesystem.h"
@@ -44,7 +46,6 @@
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/schema-util.h"
-#include "icing/schema/section-manager.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
#include "icing/testing/common-matchers.h"
@@ -81,49 +82,57 @@ constexpr std::string_view kIpsumText =
"vehicula posuere vitae, convallis eu lorem. Donec semper augue eu nibh "
"placerat semper.";
-// type and property names of FakeType
+// schema types
constexpr std::string_view kFakeType = "FakeType";
+constexpr std::string_view kNestedType = "NestedType";
+
+// Indexable properties and section ids. Section ids are assigned in the
+// lexicographical order of the indexable property paths.
constexpr std::string_view kExactProperty = "exact";
+constexpr std::string_view kIndexableIntegerProperty = "indexableInteger";
constexpr std::string_view kPrefixedProperty = "prefixed";
-constexpr std::string_view kUnindexedProperty1 = "unindexed1";
-constexpr std::string_view kUnindexedProperty2 = "unindexed2";
constexpr std::string_view kRepeatedProperty = "repeated";
-constexpr std::string_view kSubProperty = "submessage";
-constexpr std::string_view kNestedType = "NestedType";
-constexpr std::string_view kNestedProperty = "nested";
-constexpr std::string_view kExactVerbatimProperty = "verbatimExact";
-constexpr std::string_view kPrefixedVerbatimProperty = "verbatimPrefixed";
constexpr std::string_view kRfc822Property = "rfc822";
+constexpr std::string_view kSubProperty = "submessage"; // submessage.nested
+constexpr std::string_view kNestedProperty = "nested"; // submessage.nested
// TODO (b/246964044): remove ifdef guard when url-tokenizer is ready for export
// to Android.
#ifdef ENABLE_URL_TOKENIZER
-constexpr std::string_view kExactUrlProperty = "urlExact";
-constexpr std::string_view kPrefixedUrlProperty = "urlPrefixed";
+constexpr std::string_view kUrlExactProperty = "urlExact";
+constexpr std::string_view kUrlPrefixedProperty = "urlPrefixed";
#endif // ENABLE_URL_TOKENIZER
-
-constexpr DocumentId kDocumentId0 = 0;
-constexpr DocumentId kDocumentId1 = 1;
+constexpr std::string_view kVerbatimExactProperty = "verbatimExact";
+constexpr std::string_view kVerbatimPrefixedProperty = "verbatimPrefixed";
constexpr SectionId kExactSectionId = 0;
-constexpr SectionId kPrefixedSectionId = 1;
-constexpr SectionId kRepeatedSectionId = 2;
-constexpr SectionId kRfc822SectionId = 3;
-constexpr SectionId kNestedSectionId = 4;
+constexpr SectionId kIndexableIntegerSectionId = 1;
+constexpr SectionId kPrefixedSectionId = 2;
+constexpr SectionId kRepeatedSectionId = 3;
+constexpr SectionId kRfc822SectionId = 4;
+constexpr SectionId kNestedSectionId = 5; // submessage.nested
#ifdef ENABLE_URL_TOKENIZER
-constexpr SectionId kUrlExactSectionId = 5;
-constexpr SectionId kUrlPrefixedSectionId = 6;
-constexpr SectionId kExactVerbatimSectionId = 7;
-constexpr SectionId kPrefixedVerbatimSectionId = 8;
-#else // !ENABLE_URL_TOKENIZER
-constexpr SectionId kExactVerbatimSectionId = 5;
-constexpr SectionId kPrefixedVerbatimSectionId = 6;
+constexpr SectionId kUrlExactSectionId = 6;
+constexpr SectionId kUrlPrefixedSectionId = 7;
+constexpr SectionId kVerbatimExactSectionId = 8;
+constexpr SectionId kVerbatimPrefixedSectionId = 9;
+#else // !ENABLE_URL_TOKENIZER
+constexpr SectionId kVerbatimExactSectionId = 6;
+constexpr SectionId kVerbatimPrefixedSectionId = 7;
#endif // ENABLE_URL_TOKENIZER
+// Other non-indexable properties.
+constexpr std::string_view kUnindexedProperty1 = "unindexed1";
+constexpr std::string_view kUnindexedProperty2 = "unindexed2";
+
+constexpr DocumentId kDocumentId0 = 0;
+constexpr DocumentId kDocumentId1 = 1;
+
using Cardinality = PropertyConfigProto::Cardinality;
using DataType = PropertyConfigProto::DataType;
using ::testing::ElementsAre;
using ::testing::Eq;
using ::testing::IsEmpty;
+using ::testing::SizeIs;
using ::testing::Test;
#ifdef ENABLE_URL_TOKENIZER
@@ -146,6 +155,8 @@ class IndexProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
index_, Index::Create(options, &filesystem_, &icing_filesystem_));
+ integer_index_ = std::make_unique<DummyNumericIndex<int64_t>>();
+
language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US);
ICING_ASSERT_OK_AND_ASSIGN(
lang_segmenter_,
@@ -191,12 +202,12 @@ class IndexProcessorTest : public Test {
TOKENIZER_PLAIN)
.SetCardinality(CARDINALITY_REPEATED))
.AddProperty(PropertyConfigBuilder()
- .SetName(kExactVerbatimProperty)
+ .SetName(kVerbatimExactProperty)
.SetDataTypeString(TERM_MATCH_EXACT,
TOKENIZER_VERBATIM)
.SetCardinality(CARDINALITY_REPEATED))
.AddProperty(PropertyConfigBuilder()
- .SetName(kPrefixedVerbatimProperty)
+ .SetName(kVerbatimPrefixedProperty)
.SetDataTypeString(TERM_MATCH_PREFIX,
TOKENIZER_VERBATIM)
.SetCardinality(CARDINALITY_REPEATED))
@@ -208,15 +219,19 @@ class IndexProcessorTest : public Test {
#ifdef ENABLE_URL_TOKENIZER
.AddProperty(
PropertyConfigBuilder()
- .SetName(kExactUrlProperty)
- .SetDataTypeString(MATCH_EXACT, TOKENIZER_URL)
+ .SetName(kUrlExactProperty)
+ .SetDataTypeString(TERM_MATCH_EXACT, TOKENIZER_URL)
.SetCardinality(CARDINALITY_REPEATED))
.AddProperty(
PropertyConfigBuilder()
- .SetName(kPrefixedUrlProperty)
- .SetDataTypeString(MATCH_PREFIX, TOKENIZER_URL)
+ .SetName(kUrlPrefixedProperty)
+ .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_URL)
.SetCardinality(CARDINALITY_REPEATED))
#endif // ENABLE_URL_TOKENIZER
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kIndexableIntegerProperty)
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_REPEATED))
.AddProperty(
PropertyConfigBuilder()
.SetName(kSubProperty)
@@ -236,7 +251,8 @@ class IndexProcessorTest : public Test {
ICING_ASSERT_OK_AND_ASSIGN(
index_processor_,
- IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_));
+ IndexProcessor::Create(normalizer_.get(), index_.get(),
+ integer_index_.get(), &fake_clock_));
mock_icing_filesystem_ = std::make_unique<IcingMockFilesystem>();
}
@@ -254,6 +270,7 @@ class IndexProcessorTest : public Test {
std::unique_ptr<LanguageSegmenter> lang_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
std::unique_ptr<Index> index_;
+ std::unique_ptr<NumericIndex<int64_t>> integer_index_;
std::unique_ptr<SchemaStore> schema_store_;
std::unique_ptr<IndexProcessor> index_processor_;
};
@@ -282,11 +299,11 @@ std::vector<DocHitInfoTermFrequencyPair> GetHitsWithTermFrequency(
TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) {
EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(),
- &fake_clock_),
+ integer_index_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(IndexProcessor::Create(normalizer_.get(), /*index=*/nullptr,
- &fake_clock_),
+ integer_index_.get(), &fake_clock_),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
@@ -540,7 +557,8 @@ TEST_F(IndexProcessorTest, TooLongTokens) {
ICING_ASSERT_OK_AND_ASSIGN(
index_processor_,
- IndexProcessor::Create(normalizer.get(), index_.get(), &fake_clock_));
+ IndexProcessor::Create(normalizer.get(), index_.get(),
+ integer_index_.get(), &fake_clock_));
DocumentProto document =
DocumentBuilder()
@@ -773,7 +791,8 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) {
ICING_ASSERT_OK_AND_ASSIGN(
index_processor_,
- IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_));
+ IndexProcessor::Create(normalizer_.get(), index_.get(),
+ integer_index_.get(), &fake_clock_));
DocumentId doc_id = 0;
// Have determined experimentally that indexing 3373 documents with this text
// will cause the LiteIndex to fill up. Further indexing will fail unless the
@@ -829,7 +848,8 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) {
ICING_ASSERT_OK_AND_ASSIGN(
index_processor_,
- IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_));
+ IndexProcessor::Create(normalizer_.get(), index_.get(),
+ integer_index_.get(), &fake_clock_));
// 3. Index one document. This should fit in the LiteIndex without requiring a
// merge.
@@ -856,14 +876,14 @@ TEST_F(IndexProcessorTest, ExactVerbatimProperty) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kExactVerbatimProperty),
+ .AddStringProperty(std::string(kVerbatimExactProperty),
"Hello, world!")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 1);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -876,7 +896,7 @@ TEST_F(IndexProcessorTest, ExactVerbatimProperty) {
std::vector<DocHitInfoTermFrequencyPair> hits =
GetHitsWithTermFrequency(std::move(itr));
std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{
- {kExactVerbatimSectionId, 1}};
+ {kVerbatimExactSectionId, 1}};
EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency(
kDocumentId0, expectedMap)));
@@ -887,14 +907,14 @@ TEST_F(IndexProcessorTest, PrefixVerbatimProperty) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kPrefixedVerbatimProperty),
+ .AddStringProperty(std::string(kVerbatimPrefixedProperty),
"Hello, world!")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 1);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -908,7 +928,7 @@ TEST_F(IndexProcessorTest, PrefixVerbatimProperty) {
std::vector<DocHitInfoTermFrequencyPair> hits =
GetHitsWithTermFrequency(std::move(itr));
std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{
- {kPrefixedVerbatimSectionId, 1}};
+ {kVerbatimPrefixedSectionId, 1}};
EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency(
kDocumentId0, expectedMap)));
@@ -919,14 +939,14 @@ TEST_F(IndexProcessorTest, VerbatimPropertyDoesntMatchSubToken) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kPrefixedVerbatimProperty),
+ .AddStringProperty(std::string(kVerbatimPrefixedProperty),
"Hello, world!")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 1);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -955,7 +975,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyExact) {
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1000,7 +1020,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyExactShouldNotReturnPrefix) {
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1028,7 +1048,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyPrefix) {
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1069,7 +1089,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyNoMatch) {
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1091,14 +1111,14 @@ TEST_F(IndexProcessorTest, ExactUrlProperty) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kExactUrlProperty),
+ .AddStringProperty(std::string(kUrlExactProperty),
"http://www.google.com")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1144,14 +1164,14 @@ TEST_F(IndexProcessorTest, ExactUrlPropertyDoesNotMatchPrefix) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kExactUrlProperty),
+ .AddStringProperty(std::string(kUrlExactProperty),
"https://mail.google.com/calendar/render")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 8);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(8));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1182,14 +1202,14 @@ TEST_F(IndexProcessorTest, PrefixUrlProperty) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kPrefixedUrlProperty),
+ .AddStringProperty(std::string(kUrlPrefixedProperty),
"http://www.google.com")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 7);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1229,14 +1249,14 @@ TEST_F(IndexProcessorTest, PrefixUrlPropertyNoMatch) {
DocumentBuilder()
.SetKey("icing", "fake_type/1")
.SetSchema(std::string(kFakeType))
- .AddStringProperty(std::string(kPrefixedUrlProperty),
+ .AddStringProperty(std::string(kUrlPrefixedProperty),
"https://mail.google.com/calendar/render")
.Build();
ICING_ASSERT_OK_AND_ASSIGN(
TokenizedDocument tokenized_document,
TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
document));
- EXPECT_THAT(tokenized_document.num_tokens(), 8);
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(8));
EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
IsOk());
@@ -1270,6 +1290,61 @@ TEST_F(IndexProcessorTest, PrefixUrlPropertyNoMatch) {
}
#endif // ENABLE_URL_TOKENIZER
+TEST_F(IndexProcessorTest, IndexableIntegerProperty) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddInt64Property(std::string(kIndexableIntegerProperty), 1, 2, 3, 4,
+ 5)
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+ // Expected to have 1 integer section.
+ EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(1));
+
+ EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<DocHitInfoIterator> itr,
+ integer_index_->GetIterator(kIndexableIntegerProperty, /*key_lower=*/1,
+ /*key_upper=*/5));
+
+ EXPECT_THAT(
+ GetHits(std::move(itr)),
+ ElementsAre(EqualsDocHitInfo(
+ kDocumentId0, std::vector<SectionId>{kIndexableIntegerSectionId})));
+}
+
+TEST_F(IndexProcessorTest, IndexableIntegerPropertyNoMatch) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddInt64Property(std::string(kIndexableIntegerProperty), 1, 2, 3, 4,
+ 5)
+ .Build();
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+ // Expected to have 1 integer section.
+ EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(1));
+
+ EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0),
+ IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<DocHitInfoIterator> itr,
+ integer_index_->GetIterator(kIndexableIntegerProperty, /*key_lower=*/-1,
+ /*key_upper=*/0));
+
+ EXPECT_THAT(GetHits(std::move(itr)), IsEmpty());
+}
+
} // namespace
} // namespace lib
diff --git a/icing/index/integer-section-indexing-handler.cc b/icing/index/integer-section-indexing-handler.cc
new file mode 100644
index 0000000..a49b9f3
--- /dev/null
+++ b/icing/index/integer-section-indexing-handler.cc
@@ -0,0 +1,70 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/integer-section-indexing-handler.h"
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/schema/section-manager.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/util/logging.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::Status IntegerSectionIndexingHandler::Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ PutDocumentStatsProto* put_document_stats) {
+ // TODO(b/259744228):
+ // 1. Resolve last_added_document_id for index rebuilding before rollout
+ // 2. Set integer indexing latency and other stats
+
+ libtextclassifier3::Status status;
+  // We have to add integer sections to the integer index in reverse order
+  // because sections are sorted by SectionId in ascending order, but BasicHits
+  // must be added in descending order of SectionId (a posting list
+  // requirement).
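+  // (For example, sections with SectionIds {1, 3, 5} are processed as
+  // 5, 3, 1.)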
+ for (auto riter = tokenized_document.integer_sections().rbegin();
+ riter != tokenized_document.integer_sections().rend(); ++riter) {
+ const Section<int64_t>& section = *riter;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor = integer_index_.Edit(
+ section.metadata.path, document_id, section.metadata.id);
+
+ for (int64_t key : section.content) {
+ status = editor->BufferKey(key);
+ if (!status.ok()) {
+ ICING_LOG(WARNING)
+ << "Failed to buffer keys into integer index due to: "
+ << status.error_message();
+ break;
+ }
+ }
+ if (!status.ok()) {
+ break;
+ }
+
+ // Add all the seen keys to the integer index.
+ status = editor->IndexAllBufferedKeys();
+ if (!status.ok()) {
+ ICING_LOG(WARNING) << "Failed to add keys into integer index due to: "
+ << status.error_message();
+ break;
+ }
+ }
+
+ return status;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/integer-section-indexing-handler.h b/icing/index/integer-section-indexing-handler.h
new file mode 100644
index 0000000..dd0e46c
--- /dev/null
+++ b/icing/index/integer-section-indexing-handler.h
@@ -0,0 +1,55 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_
+#define ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/index/section-indexing-handler.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class IntegerSectionIndexingHandler : public SectionIndexingHandler {
+ public:
+ explicit IntegerSectionIndexingHandler(const Clock* clock,
+ NumericIndex<int64_t>* integer_index)
+ : SectionIndexingHandler(clock), integer_index_(*integer_index) {}
+
+ ~IntegerSectionIndexingHandler() override = default;
+
+ // TODO(b/259744228): update this documentation after resolving
+ // last_added_document_id problem.
+  // Handles the integer indexing process: adds hits to the integer index for
+  // all content in tokenized_document.integer_sections().
+  //
+  // Returns:
+ // - OK on success
+ // - Any NumericIndex<int64_t>::Editor errors.
+ libtextclassifier3::Status Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ PutDocumentStatsProto* put_document_stats) override;
+
+ private:
+ NumericIndex<int64_t>& integer_index_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_
diff --git a/icing/index/iterator/doc-hit-info-iterator-test-util.h b/icing/index/iterator/doc-hit-info-iterator-test-util.h
index ed6db23..fe3a4b9 100644
--- a/icing/index/iterator/doc-hit-info-iterator-test-util.h
+++ b/icing/index/iterator/doc-hit-info-iterator-test-util.h
@@ -15,6 +15,7 @@
#ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
#define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
+#include <cinttypes>
#include <string>
#include <utility>
#include <vector>
diff --git a/icing/index/main/doc-hit-info-iterator-term-main.cc b/icing/index/main/doc-hit-info-iterator-term-main.cc
index 098a450..f06124a 100644
--- a/icing/index/main/doc-hit-info-iterator-term-main.cc
+++ b/icing/index/main/doc-hit-info-iterator-term-main.cc
@@ -22,7 +22,7 @@
#include "icing/absl_ports/str_cat.h"
#include "icing/file/posting_list/posting-list-identifier.h"
#include "icing/index/hit/doc-hit-info.h"
-#include "icing/index/main/posting-list-accessor.h"
+#include "icing/index/main/posting-list-hit-accessor.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
diff --git a/icing/index/main/doc-hit-info-iterator-term-main.h b/icing/index/main/doc-hit-info-iterator-term-main.h
index c1b289f..6a21dc3 100644
--- a/icing/index/main/doc-hit-info-iterator-term-main.h
+++ b/icing/index/main/doc-hit-info-iterator-term-main.h
@@ -23,7 +23,7 @@
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/main/main-index.h"
-#include "icing/index/main/posting-list-accessor.h"
+#include "icing/index/main/posting-list-hit-accessor.h"
#include "icing/schema/section.h"
namespace icing {
@@ -91,7 +91,7 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator {
const std::string term_;
// The accessor of the posting list chain for the requested term.
- std::unique_ptr<PostingListAccessor> posting_list_accessor_;
+ std::unique_ptr<PostingListHitAccessor> posting_list_accessor_;
MainIndex* main_index_;
// Stores hits retrieved from the index. This may only be a subset of the hits
diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc
index 1c61bfa..fd1630a 100644
--- a/icing/index/main/main-index.cc
+++ b/icing/index/main/main-index.cc
@@ -160,19 +160,16 @@ IndexStorageInfoProto MainIndex::GetStorageInfo(
return storage_info;
}
-libtextclassifier3::StatusOr<std::unique_ptr<PostingListAccessor>>
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
MainIndex::GetAccessorForExactTerm(const std::string& term) {
PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid;
if (!main_lexicon_->Find(term.c_str(), &posting_list_id)) {
return absl_ports::NotFoundError(IcingStringUtil::StringPrintf(
"Term %s is not present in main lexicon.", term.c_str()));
}
- ICING_ASSIGN_OR_RETURN(
- PostingListAccessor accessor,
- PostingListAccessor::CreateFromExisting(
- flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
- posting_list_id));
- return std::make_unique<PostingListAccessor>(std::move(accessor));
+ return PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
+ posting_list_id);
}
libtextclassifier3::StatusOr<MainIndex::GetPrefixAccessorResult>
@@ -202,13 +199,11 @@ MainIndex::GetAccessorForPrefixTerm(const std::string& prefix) {
PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid;
memcpy(&posting_list_id, main_itr.GetValue(), sizeof(posting_list_id));
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::CreateFromExisting(
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
posting_list_id));
- GetPrefixAccessorResult result = {
- std::make_unique<PostingListAccessor>(std::move(pl_accessor)), exact};
- return result;
+ return GetPrefixAccessorResult(std::move(pl_accessor), exact);
}
// TODO(tjbarron): Implement a method PropertyReadersAll.HasAnyProperty().
@@ -245,12 +240,12 @@ MainIndex::FindTermsByPrefix(
PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid;
memcpy(&posting_list_id, term_iterator.GetValue(), sizeof(posting_list_id));
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::CreateFromExisting(
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
posting_list_id));
ICING_ASSIGN_OR_RETURN(std::vector<Hit> hits,
- pl_accessor.GetNextHitsBatch());
+ pl_accessor->GetNextHitsBatch());
while (!hits.empty()) {
for (const Hit& hit : hits) {
// Check whether this Hit is desired.
@@ -297,7 +292,7 @@ MainIndex::FindTermsByPrefix(
// The term is desired and no need to be scored.
break;
}
- ICING_ASSIGN_OR_RETURN(hits, pl_accessor.GetNextHitsBatch());
+ ICING_ASSIGN_OR_RETURN(hits, pl_accessor->GetNextHitsBatch());
}
if (score > 0) {
term_metadata_list.push_back(TermMetadata(term_iterator.GetKey(), score));
@@ -559,14 +554,14 @@ libtextclassifier3::Status MainIndex::AddHits(
memcpy(&backfill_posting_list_id,
main_lexicon_->GetValueAtIndex(other_tvi_main_tvi_pair.second),
sizeof(backfill_posting_list_id));
- ICING_ASSIGN_OR_RETURN(
- PostingListAccessor hit_accum,
- PostingListAccessor::Create(flash_index_storage_.get(),
- posting_list_used_hit_serializer_.get()));
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<PostingListHitAccessor> hit_accum,
+ PostingListHitAccessor::Create(
+ flash_index_storage_.get(),
+ posting_list_used_hit_serializer_.get()));
ICING_RETURN_IF_ERROR(
- AddPrefixBackfillHits(backfill_posting_list_id, &hit_accum));
+ AddPrefixBackfillHits(backfill_posting_list_id, hit_accum.get()));
PostingListAccessor::FinalizeResult result =
- PostingListAccessor::Finalize(std::move(hit_accum));
+ std::move(*hit_accum).Finalize();
if (result.id.is_valid()) {
main_lexicon_->SetValueAtIndex(other_tvi_main_tvi_pair.first, &result.id);
}
@@ -578,12 +573,12 @@ libtextclassifier3::Status MainIndex::AddHits(
libtextclassifier3::Status MainIndex::AddHitsForTerm(
uint32_t tvi, PostingListIdentifier backfill_posting_list_id,
const TermIdHitPair* hit_elements, size_t len) {
- // 1. Create a PostingListAccessor - either from the pre-existing block, if
+ // 1. Create a PostingListHitAccessor - either from the pre-existing block, if
// one exists, or from scratch.
PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid;
memcpy(&posting_list_id, main_lexicon_->GetValueAtIndex(tvi),
sizeof(posting_list_id));
- std::unique_ptr<PostingListAccessor> pl_accessor;
+ std::unique_ptr<PostingListHitAccessor> pl_accessor;
if (posting_list_id.is_valid()) {
if (posting_list_id.block_index() >= flash_index_storage_->num_blocks()) {
ICING_LOG(ERROR) << "Index dropped hits. Invalid block index "
@@ -597,18 +592,16 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm(
"Valid posting list has an invalid block index!");
}
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor tmp,
- PostingListAccessor::CreateFromExisting(
+ pl_accessor,
+ PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
posting_list_id));
- pl_accessor = std::make_unique<PostingListAccessor>(std::move(tmp));
} else {
// New posting list.
- ICING_ASSIGN_OR_RETURN(
- PostingListAccessor tmp,
- PostingListAccessor::Create(flash_index_storage_.get(),
- posting_list_used_hit_serializer_.get()));
- pl_accessor = std::make_unique<PostingListAccessor>(std::move(tmp));
+ ICING_ASSIGN_OR_RETURN(pl_accessor,
+ PostingListHitAccessor::Create(
+ flash_index_storage_.get(),
+ posting_list_used_hit_serializer_.get()));
}
// 2. Backfill any hits if necessary.
@@ -625,7 +618,7 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm(
// 4. Finalize this posting list and put its identifier in the lexicon.
PostingListAccessor::FinalizeResult result =
- PostingListAccessor::Finalize(std::move(*pl_accessor));
+ std::move(*pl_accessor).Finalize();
if (result.id.is_valid()) {
main_lexicon_->SetValueAtIndex(tvi, &result.id);
}
@@ -634,18 +627,18 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm(
libtextclassifier3::Status MainIndex::AddPrefixBackfillHits(
PostingListIdentifier backfill_posting_list_id,
- PostingListAccessor* hit_accum) {
+ PostingListHitAccessor* hit_accum) {
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor backfill_accessor,
- PostingListAccessor::CreateFromExisting(
+ std::unique_ptr<PostingListHitAccessor> backfill_accessor,
+ PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
backfill_posting_list_id));
std::vector<Hit> backfill_hits;
ICING_ASSIGN_OR_RETURN(std::vector<Hit> tmp,
- backfill_accessor.GetNextHitsBatch());
+ backfill_accessor->GetNextHitsBatch());
while (!tmp.empty()) {
std::copy(tmp.begin(), tmp.end(), std::back_inserter(backfill_hits));
- ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor.GetNextHitsBatch());
+ ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor->GetNextHitsBatch());
}
Hit last_added_hit;
@@ -738,7 +731,7 @@ libtextclassifier3::Status MainIndex::Optimize(
libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits(
const std::vector<DocumentId>& document_id_old_to_new, const char* term,
- PostingListAccessor& old_pl_accessor, MainIndex* new_index) {
+ PostingListHitAccessor& old_pl_accessor, MainIndex* new_index) {
std::vector<Hit> new_hits;
bool has_no_exact_hits = true;
bool has_hits_in_prefix_section = false;
@@ -776,15 +769,14 @@ libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits(
}
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor hit_accum,
- PostingListAccessor::Create(
+ std::unique_ptr<PostingListHitAccessor> hit_accum,
+ PostingListHitAccessor::Create(
new_index->flash_index_storage_.get(),
new_index->posting_list_used_hit_serializer_.get()));
for (auto itr = new_hits.rbegin(); itr != new_hits.rend(); ++itr) {
- ICING_RETURN_IF_ERROR(hit_accum.PrependHit(*itr));
+ ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*itr));
}
- PostingListAccessor::FinalizeResult result =
- PostingListAccessor::Finalize(std::move(hit_accum));
+ PostingListAccessor::FinalizeResult result = std::move(*hit_accum).Finalize();
if (!result.id.is_valid()) {
return absl_ports::InternalError(
absl_ports::StrCat("Failed to add translated hits for term: ", term));
@@ -826,14 +818,14 @@ libtextclassifier3::Status MainIndex::TransferIndex(
continue;
}
ICING_ASSIGN_OR_RETURN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::CreateFromExisting(
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::CreateFromExisting(
flash_index_storage_.get(), posting_list_used_hit_serializer_.get(),
posting_list_id));
ICING_ASSIGN_OR_RETURN(
DocumentId curr_largest_document_id,
TransferAndAddHits(document_id_old_to_new, term_itr.GetKey(),
- pl_accessor, new_index));
+ *pl_accessor, new_index));
if (curr_largest_document_id == kInvalidDocumentId) {
continue;
}
diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h
index e257a77..70ae6f6 100644
--- a/icing/index/main/main-index.h
+++ b/icing/index/main/main-index.h
@@ -22,7 +22,7 @@
#include "icing/file/filesystem.h"
#include "icing/file/posting_list/flash-index-storage.h"
#include "icing/index/lite/term-id-hit-pair.h"
-#include "icing/index/main/posting-list-accessor.h"
+#include "icing/index/main/posting-list-hit-accessor.h"
#include "icing/index/main/posting-list-used-hit-serializer.h"
#include "icing/index/term-id-codec.h"
#include "icing/index/term-metadata.h"
@@ -48,27 +48,31 @@ class MainIndex {
const std::string& index_directory, const Filesystem* filesystem,
const IcingFilesystem* icing_filesystem);
- // Get a PostingListAccessor that holds the posting list chain for 'term'.
+ // Get a PostingListHitAccessor that holds the posting list chain for 'term'.
//
// RETURNS:
- // - On success, a valid PostingListAccessor
+ // - On success, a valid PostingListHitAccessor
// - NOT_FOUND if term is not present in the main index.
- libtextclassifier3::StatusOr<std::unique_ptr<PostingListAccessor>>
+ libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
GetAccessorForExactTerm(const std::string& term);
- // Get a PostingListAccessor for 'prefix'.
+ // Get a PostingListHitAccessor for 'prefix'.
//
// RETURNS:
- // - On success, a result containing a valid PostingListAccessor.
+ // - On success, a result containing a valid PostingListHitAccessor.
// - NOT_FOUND if neither 'prefix' nor any terms for which 'prefix' is a
// prefix are present in the main index.
struct GetPrefixAccessorResult {
- // A PostingListAccessor that holds the posting list chain for the term
+ // A PostingListHitAccessor that holds the posting list chain for the term
// that best represents 'prefix' in the main index.
- std::unique_ptr<PostingListAccessor> accessor;
+ std::unique_ptr<PostingListHitAccessor> accessor;
// True if the returned posting list chain is for 'prefix' or false if the
// returned posting list chain is for a term for which 'prefix' is a prefix.
bool exact;
+
+ explicit GetPrefixAccessorResult(
+ std::unique_ptr<PostingListHitAccessor> accessor_in, bool exact_in)
+ : accessor(std::move(accessor_in)), exact(exact_in) {}
};
libtextclassifier3::StatusOr<GetPrefixAccessorResult>
GetAccessorForPrefixTerm(const std::string& prefix);
@@ -302,7 +306,7 @@ class MainIndex {
// posting list.
libtextclassifier3::Status AddPrefixBackfillHits(
PostingListIdentifier backfill_posting_list_id,
- PostingListAccessor* hit_accum);
+ PostingListHitAccessor* hit_accum);
// Transfer hits from old_pl_accessor to new_index for term.
//
@@ -311,7 +315,7 @@ class MainIndex {
// INTERNAL_ERROR on IO error
static libtextclassifier3::StatusOr<DocumentId> TransferAndAddHits(
const std::vector<DocumentId>& document_id_old_to_new, const char* term,
- PostingListAccessor& old_pl_accessor, MainIndex* new_index);
+ PostingListHitAccessor& old_pl_accessor, MainIndex* new_index);
// Transfer hits from the current main index to new_index.
//
diff --git a/icing/index/main/posting-list-accessor.cc b/icing/index/main/posting-list-accessor.cc
deleted file mode 100644
index 06ab0a1..0000000
--- a/icing/index/main/posting-list-accessor.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright (C) 2019 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "icing/index/main/posting-list-accessor.h"
-
-#include <cstdint>
-#include <memory>
-#include <vector>
-
-#include "icing/absl_ports/canonical_errors.h"
-#include "icing/file/posting_list/flash-index-storage.h"
-#include "icing/file/posting_list/index-block.h"
-#include "icing/file/posting_list/posting-list-identifier.h"
-#include "icing/file/posting_list/posting-list-used.h"
-#include "icing/index/main/posting-list-used-hit-serializer.h"
-#include "icing/util/status-macros.h"
-
-namespace icing {
-namespace lib {
-
-libtextclassifier3::StatusOr<PostingListAccessor> PostingListAccessor::Create(
- FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer) {
- uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
- storage->block_size(), serializer->GetDataTypeBytes());
- std::unique_ptr<uint8_t[]> posting_list_buffer_array =
- std::make_unique<uint8_t[]>(max_posting_list_bytes);
- ICING_ASSIGN_OR_RETURN(
- PostingListUsed posting_list_buffer,
- PostingListUsed::CreateFromUnitializedRegion(
- serializer, posting_list_buffer_array.get(), max_posting_list_bytes));
- return PostingListAccessor(storage, serializer,
- std::move(posting_list_buffer_array),
- std::move(posting_list_buffer));
-}
-
-libtextclassifier3::StatusOr<PostingListAccessor>
-PostingListAccessor::CreateFromExisting(
- FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer,
- PostingListIdentifier existing_posting_list_id) {
- // Our posting_list_buffer_ will start as empty.
- ICING_ASSIGN_OR_RETURN(PostingListAccessor pl_accessor,
- Create(storage, serializer));
- ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
- storage->GetPostingList(existing_posting_list_id));
- pl_accessor.preexisting_posting_list_ =
- std::make_unique<PostingListHolder>(std::move(holder));
- return pl_accessor;
-}
-
-// Returns the next batch of hits for the provided posting list.
-libtextclassifier3::StatusOr<std::vector<Hit>>
-PostingListAccessor::GetNextHitsBatch() {
- if (preexisting_posting_list_ == nullptr) {
- if (has_reached_posting_list_chain_end_) {
- return std::vector<Hit>();
- }
- return absl_ports::FailedPreconditionError(
- "Cannot retrieve hits from a PostingListAccessor that was not created "
- "from a preexisting posting list.");
- }
- ICING_ASSIGN_OR_RETURN(
- std::vector<Hit> batch,
- serializer_->GetHits(&preexisting_posting_list_->posting_list));
- uint32_t next_block_index;
- // Posting lists will only be chained when they are max-sized, in which case
- // block.next_block_index() will point to the next block for the next posting
- // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be
- // used to point to the next free list block, which is not relevant here.
- if (preexisting_posting_list_->block.max_num_posting_lists() == 1) {
- next_block_index = preexisting_posting_list_->block.next_block_index();
- } else {
- next_block_index = kInvalidBlockIndex;
- }
- if (next_block_index != kInvalidBlockIndex) {
- PostingListIdentifier next_posting_list_id(
- next_block_index, /*posting_list_index=*/0,
- preexisting_posting_list_->block.posting_list_index_bits());
- ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
- storage_->GetPostingList(next_posting_list_id));
- preexisting_posting_list_ =
- std::make_unique<PostingListHolder>(std::move(holder));
- } else {
- has_reached_posting_list_chain_end_ = true;
- preexisting_posting_list_.reset();
- }
- return batch;
-}
-
-libtextclassifier3::Status PostingListAccessor::PrependHit(const Hit &hit) {
- PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr)
- ? preexisting_posting_list_->posting_list
- : posting_list_buffer_;
- libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit);
- if (!absl_ports::IsResourceExhausted(status)) {
- return status;
- }
- // There is no more room to add hits to this current posting list! Therefore,
- // we need to either move those hits to a larger posting list or flush this
- // posting list and create another max-sized posting list in the chain.
- if (preexisting_posting_list_ != nullptr) {
- FlushPreexistingPostingList();
- } else {
- ICING_RETURN_IF_ERROR(FlushInMemoryPostingList());
- }
-
- // Re-add hit. Should always fit since we just cleared posting_list_buffer_.
- // It's fine to explicitly reference posting_list_buffer_ here because there's
- // no way of reaching this line while preexisting_posting_list_ is still in
- // use.
- return serializer_->PrependHit(&posting_list_buffer_, hit);
-}
-
-void PostingListAccessor::FlushPreexistingPostingList() {
- if (preexisting_posting_list_->block.max_num_posting_lists() == 1) {
- // If this is a max-sized posting list, then just keep track of the id for
- // chaining. It'll be flushed to disk when preexisting_posting_list_ is
- // destructed.
- prev_block_identifier_ = preexisting_posting_list_->id;
- } else {
- // If this is NOT a max-sized posting list, then our hits have outgrown this
- // particular posting list. Move the hits into the in-memory posting list
- // and free this posting list.
- //
- // Move will always succeed since posting_list_buffer_ is max_pl_bytes.
- serializer_->MoveFrom(/*dst=*/&posting_list_buffer_,
- /*src=*/&preexisting_posting_list_->posting_list);
-
- // Now that all the contents of this posting list have been copied, there's
- // no more use for it. Make it available to be used for another posting
- // list.
- storage_->FreePostingList(std::move(*preexisting_posting_list_));
- }
- preexisting_posting_list_.reset();
-}
-
-libtextclassifier3::Status PostingListAccessor::FlushInMemoryPostingList() {
- // We exceeded max_pl_bytes(). Need to flush posting_list_buffer_ and update
- // the chain.
- uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
- storage_->block_size(), serializer_->GetDataTypeBytes());
- ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
- storage_->AllocatePostingList(max_posting_list_bytes));
- holder.block.set_next_block_index(prev_block_identifier_.block_index());
- prev_block_identifier_ = holder.id;
- return serializer_->MoveFrom(/*dst=*/&holder.posting_list,
- /*src=*/&posting_list_buffer_);
-}
-
-PostingListAccessor::FinalizeResult PostingListAccessor::Finalize(
- PostingListAccessor accessor) {
- if (accessor.preexisting_posting_list_ != nullptr) {
- // Our hits are already in an existing posting list. Nothing else to do, but
- // return its id.
- FinalizeResult result = {libtextclassifier3::Status::OK,
- accessor.preexisting_posting_list_->id};
- return result;
- }
- if (accessor.serializer_->GetBytesUsed(&accessor.posting_list_buffer_) <= 0) {
- FinalizeResult result = {absl_ports::InvalidArgumentError(
- "Can't finalize an empty PostingListAccessor. "
- "There's nothing to Finalize!"),
- PostingListIdentifier::kInvalid};
- return result;
- }
- uint32_t posting_list_bytes =
- accessor.serializer_->GetMinPostingListSizeToFit(
- &accessor.posting_list_buffer_);
- if (accessor.prev_block_identifier_.is_valid()) {
- posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
- accessor.storage_->block_size(),
- accessor.serializer_->GetDataTypeBytes());
- }
- auto holder_or = accessor.storage_->AllocatePostingList(posting_list_bytes);
- if (!holder_or.ok()) {
- FinalizeResult result = {holder_or.status(),
- accessor.prev_block_identifier_};
- return result;
- }
- PostingListHolder holder = std::move(holder_or).ValueOrDie();
- if (accessor.prev_block_identifier_.is_valid()) {
- holder.block.set_next_block_index(
- accessor.prev_block_identifier_.block_index());
- }
-
- // Move to allocated area. This should never actually return an error. We know
- // that editor.posting_list() is valid because it wouldn't have successfully
- // returned by AllocatePostingList if it wasn't. We know posting_list_buffer_
- // is valid because we created it in-memory. And finally, we know that the
- // hits from posting_list_buffer_ will fit in editor.posting_list() because we
-  // requested it be at least posting_list_bytes large.
- auto status =
- accessor.serializer_->MoveFrom(/*dst=*/&holder.posting_list,
- /*src=*/&accessor.posting_list_buffer_);
- if (!status.ok()) {
- FinalizeResult result = {std::move(status),
- accessor.prev_block_identifier_};
- return result;
- }
- FinalizeResult result = {libtextclassifier3::Status::OK, holder.id};
- return result;
-}
-
-} // namespace lib
-} // namespace icing
diff --git a/icing/index/main/posting-list-hit-accessor.cc b/icing/index/main/posting-list-hit-accessor.cc
new file mode 100644
index 0000000..30b2410
--- /dev/null
+++ b/icing/index/main/posting-list-hit-accessor.cc
@@ -0,0 +1,126 @@
+// Copyright (C) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/main/posting-list-hit-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/index-block.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/main/posting-list-used-hit-serializer.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
+PostingListHitAccessor::Create(FlashIndexStorage *storage,
+ PostingListUsedHitSerializer *serializer) {
+ uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
+ storage->block_size(), serializer->GetDataTypeBytes());
+ std::unique_ptr<uint8_t[]> posting_list_buffer_array =
+ std::make_unique<uint8_t[]>(max_posting_list_bytes);
+ ICING_ASSIGN_OR_RETURN(
+ PostingListUsed posting_list_buffer,
+ PostingListUsed::CreateFromUnitializedRegion(
+ serializer, posting_list_buffer_array.get(), max_posting_list_bytes));
+ return std::unique_ptr<PostingListHitAccessor>(new PostingListHitAccessor(
+ storage, serializer, std::move(posting_list_buffer_array),
+ std::move(posting_list_buffer)));
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
+PostingListHitAccessor::CreateFromExisting(
+ FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer,
+ PostingListIdentifier existing_posting_list_id) {
+ // Our posting_list_buffer_ will start as empty.
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ Create(storage, serializer));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage->GetPostingList(existing_posting_list_id));
+ pl_accessor->preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ return pl_accessor;
+}
+
+// Returns the next batch of hits for the provided posting list.
+libtextclassifier3::StatusOr<std::vector<Hit>>
+PostingListHitAccessor::GetNextHitsBatch() {
+ if (preexisting_posting_list_ == nullptr) {
+ if (has_reached_posting_list_chain_end_) {
+ return std::vector<Hit>();
+ }
+ return absl_ports::FailedPreconditionError(
+ "Cannot retrieve hits from a PostingListHitAccessor that was not "
+ "created from a preexisting posting list.");
+ }
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<Hit> batch,
+ serializer_->GetHits(&preexisting_posting_list_->posting_list));
+ uint32_t next_block_index;
+ // Posting lists will only be chained when they are max-sized, in which case
+ // block.next_block_index() will point to the next block for the next posting
+ // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be
+ // used to point to the next free list block, which is not relevant here.
+ if (preexisting_posting_list_->block.max_num_posting_lists() == 1) {
+ next_block_index = preexisting_posting_list_->block.next_block_index();
+ } else {
+ next_block_index = kInvalidBlockIndex;
+ }
+ if (next_block_index != kInvalidBlockIndex) {
+ PostingListIdentifier next_posting_list_id(
+ next_block_index, /*posting_list_index=*/0,
+ preexisting_posting_list_->block.posting_list_index_bits());
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage_->GetPostingList(next_posting_list_id));
+ preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ } else {
+ has_reached_posting_list_chain_end_ = true;
+ preexisting_posting_list_.reset();
+ }
+ return batch;
+}
+
+libtextclassifier3::Status PostingListHitAccessor::PrependHit(const Hit &hit) {
+ PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr)
+ ? preexisting_posting_list_->posting_list
+ : posting_list_buffer_;
+ libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit);
+ if (!absl_ports::IsResourceExhausted(status)) {
+ return status;
+ }
+  // There is no more room to add hits to the current posting list. Therefore,
+  // we need to either move those hits to a larger posting list or flush this
+  // posting list and create another max-sized posting list in the chain.
+ if (preexisting_posting_list_ != nullptr) {
+ FlushPreexistingPostingList();
+ } else {
+ ICING_RETURN_IF_ERROR(FlushInMemoryPostingList());
+ }
+
+ // Re-add hit. Should always fit since we just cleared posting_list_buffer_.
+ // It's fine to explicitly reference posting_list_buffer_ here because there's
+ // no way of reaching this line while preexisting_posting_list_ is still in
+ // use.
+ return serializer_->PrependHit(&posting_list_buffer_, hit);
+}
+
+} // namespace lib
+} // namespace icing
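To make the migration above concrete: callers now receive the accessor behind a std::unique_ptr and finalize via the rvalue member function inherited from PostingListAccessor, replacing the old static Finalize(std::move(accessor)). A minimal write-path sketch, assuming it sits inside a function returning libtextclassifier3::Status, and that flash_index_storage, serializer, and hits (a vector sorted in descending DocumentId order, as GetNextHitsBatch returns) are hypothetical names already in scope:

// Sketch only; the local names here are assumptions, not part of this patch.
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<PostingListHitAccessor> accessor,
    PostingListHitAccessor::Create(flash_index_storage.get(),
                                   serializer.get()));
// Prepend in reverse (ascending DocumentId order), matching the
// TransferAndAddHits hunk above; PrependHit rejects out-of-order hits with
// INVALID_ARGUMENT.
for (auto itr = hits.rbegin(); itr != hits.rend(); ++itr) {
  ICING_RETURN_IF_ERROR(accessor->PrependHit(*itr));
}
// Finalize is now an rvalue member on the PostingListAccessor base class.
PostingListAccessor::FinalizeResult result = std::move(*accessor).Finalize();
ICING_RETURN_IF_ERROR(result.status);
PostingListIdentifier head_id = result.id;  // id of the head of the chain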
diff --git a/icing/index/main/posting-list-hit-accessor.h b/icing/index/main/posting-list-hit-accessor.h
new file mode 100644
index 0000000..953f2bd
--- /dev/null
+++ b/icing/index/main/posting-list-hit-accessor.h
@@ -0,0 +1,103 @@
+// Copyright (C) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_
+#define ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/main/posting-list-used-hit-serializer.h"
+
+namespace icing {
+namespace lib {
+
+// This class is used to provide a simple abstraction for adding hits to posting
+// lists. PostingListHitAccessor handles 1) selection of properly-sized posting
+// lists for the accumulated hits during Finalize() and 2) chaining of max-sized
+// posting lists.
+class PostingListHitAccessor : public PostingListAccessor {
+ public:
+ // Creates an empty PostingListHitAccessor.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of PostingListHitAccessor
+ // - INVALID_ARGUMENT error if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
+ Create(FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer);
+
+ // Create a PostingListHitAccessor with an existing posting list identified by
+ // existing_posting_list_id.
+ //
+ // The PostingListHitAccessor will add hits to this posting list until it is
+ // necessary either to 1) chain the posting list (if it is max-sized) or 2)
+ // move its hits to a larger posting list.
+ //
+ // RETURNS:
+ // - On success, a valid unique_ptr instance of PostingListHitAccessor
+ // - INVALID_ARGUMENT if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>>
+ CreateFromExisting(FlashIndexStorage* storage,
+ PostingListUsedHitSerializer* serializer,
+ PostingListIdentifier existing_posting_list_id);
+
+ PostingListUsedSerializer* GetSerializer() override { return serializer_; }
+
+  // Retrieves the next batch of hits for the posting list chain.
+  //
+  // RETURNS:
+  //   - On success, a vector of hits in the posting list chain
+  //   - FAILED_PRECONDITION if called on an instance of PostingListHitAccessor
+  //     that was created via PostingListHitAccessor::Create
+  //   - INTERNAL if unable to read the next posting list in the chain or if
+  //     the posting list has been corrupted somehow.
+ libtextclassifier3::StatusOr<std::vector<Hit>> GetNextHitsBatch();
+
+ // Prepend one hit. This may result in flushing the posting list to disk (if
+ // the PostingListHitAccessor holds a max-sized posting list that is full) or
+ // freeing a pre-existing posting list if it is too small to fit all hits
+ // necessary.
+ //
+ // RETURNS:
+ // - OK, on success
+ // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
+ // previously added hit.
+ // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new
+ // posting list.
+ libtextclassifier3::Status PrependHit(const Hit& hit);
+
+ private:
+ explicit PostingListHitAccessor(
+ FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer,
+ std::unique_ptr<uint8_t[]> posting_list_buffer_array,
+ PostingListUsed posting_list_buffer)
+ : PostingListAccessor(storage, std::move(posting_list_buffer_array),
+ std::move(posting_list_buffer)),
+ serializer_(serializer) {}
+
+ PostingListUsedHitSerializer* serializer_; // Does not own.
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_
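The read side is symmetric. A sketch of draining an existing chain through the new unique_ptr-returning factory, mirroring the MainIndex::AddPrefixBackfillHits hunk earlier in this patch; posting_list_id, flash_index_storage, and serializer are assumed to be in scope:

// Sketch only: walk a posting list chain batch by batch.
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<PostingListHitAccessor> accessor,
    PostingListHitAccessor::CreateFromExisting(
        flash_index_storage.get(), serializer.get(), posting_list_id));
std::vector<Hit> all_hits;
ICING_ASSIGN_OR_RETURN(std::vector<Hit> batch, accessor->GetNextHitsBatch());
while (!batch.empty()) {
  // Each batch holds one posting list's hits; the accessor follows
  // next_block_index() to the next max-sized posting list in the chain and
  // returns an empty vector once the chain end is reached.
  std::copy(batch.begin(), batch.end(), std::back_inserter(all_hits));
  ICING_ASSIGN_OR_RETURN(batch, accessor->GetNextHitsBatch());
}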
diff --git a/icing/index/main/posting-list-accessor_test.cc b/icing/index/main/posting-list-hit-accessor_test.cc
index 3145420..fcdd580 100644
--- a/icing/index/main/posting-list-accessor_test.cc
+++ b/icing/index/main/posting-list-hit-accessor_test.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "icing/index/main/posting-list-accessor.h"
+#include "icing/index/main/posting-list-hit-accessor.h"
#include <cstdint>
@@ -40,7 +40,7 @@ using ::testing::Eq;
using ::testing::Lt;
using ::testing::SizeIs;
-class PostingListAccessorTest : public ::testing::Test {
+class PostingListHitAccessorTest : public ::testing::Test {
protected:
void SetUp() override {
test_dir_ = GetTestTempDir() + "/test_dir";
@@ -71,19 +71,19 @@ class PostingListAccessorTest : public ::testing::Test {
std::unique_ptr<FlashIndexStorage> flash_index_storage_;
};
-TEST_F(PostingListAccessorTest, HitsAddAndRetrieveProperly) {
+TEST_F(PostingListHitAccessorTest, HitsAddAndRetrieveProperly) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
// Add some hits! Any hits!
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit& hit : hits1) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result.status);
EXPECT_THAT(result.id.block_index(), Eq(1));
EXPECT_THAT(result.id.posting_list_index(), Eq(0));
@@ -96,16 +96,16 @@ TEST_F(PostingListAccessorTest, HitsAddAndRetrieveProperly) {
EXPECT_THAT(pl_holder.block.next_block_index(), Eq(kInvalidBlockIndex));
}
-TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) {
+TEST_F(PostingListHitAccessorTest, PreexistingPLKeepOnSameBlock) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
// Add a single hit. This will fit in a min-sized posting list.
Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency);
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit1));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result1.status);
// Should have been allocated to the first block.
EXPECT_THAT(result1.id.block_index(), Eq(1));
@@ -116,12 +116,12 @@ TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) {
// reallocated.
ICING_ASSERT_OK_AND_ASSIGN(
pl_accessor,
- PostingListAccessor::CreateFromExisting(flash_index_storage_.get(),
- serializer_.get(), result1.id));
+ PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/1);
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit2));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit2));
PostingListAccessor::FinalizeResult result2 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result2.status);
// Should have been allocated to the same posting list as the first hit.
EXPECT_THAT(result2.id, Eq(result1.id));
@@ -134,11 +134,11 @@ TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) {
IsOkAndHolds(ElementsAre(hit2, hit1)));
}
-TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) {
+TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
// The smallest posting list size is 15 bytes. The first four hits will be
// compressed to one byte each and will be able to fit in the 5 byte padded
// region. The last hit will fit in one of the special hits. The posting list
@@ -146,10 +146,10 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) {
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1);
for (const Hit& hit : hits1) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result1.status);
// Should have been allocated to the first block.
EXPECT_THAT(result1.id.block_index(), Eq(1));
@@ -158,8 +158,8 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) {
// Now let's add some more hits!
ICING_ASSERT_OK_AND_ASSIGN(
pl_accessor,
- PostingListAccessor::CreateFromExisting(flash_index_storage_.get(),
- serializer_.get(), result1.id));
+ PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
// The current posting list can fit at most 2 more hits. Adding 12 more hits
// should result in these hits being moved to a larger posting list.
std::vector<Hit> hits2 = CreateHits(
@@ -167,10 +167,10 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) {
/*desired_byte_length=*/1);
for (const Hit& hit : hits2) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result2 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result2.status);
// Should have been allocated to the second (new) block because the posting
// list should have grown beyond the size that the first block maintains.
@@ -188,19 +188,19 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) {
IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
}
-TEST_F(PostingListAccessorTest, MultiBlockChainsBlocksProperly) {
+TEST_F(PostingListHitAccessorTest, MultiBlockChainsBlocksProperly) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
// Add some hits! Any hits!
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5000, /*desired_byte_length=*/1);
for (const Hit& hit : hits1) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result1.status);
PostingListIdentifier second_block_id = result1.id;
// Should have been allocated to the second block, which holds a max-sized
@@ -235,19 +235,19 @@ TEST_F(PostingListAccessorTest, MultiBlockChainsBlocksProperly) {
IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend())));
}
-TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) {
+TEST_F(PostingListHitAccessorTest, PreexistingMultiBlockReusesBlocksProperly) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
// Add some hits! Any hits!
std::vector<Hit> hits1 =
CreateHits(/*num_hits=*/5000, /*desired_byte_length=*/1);
for (const Hit& hit : hits1) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result1.status);
PostingListIdentifier first_add_id = result1.id;
EXPECT_THAT(first_add_id, Eq(PostingListIdentifier(
@@ -258,17 +258,17 @@ TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) {
// second block.
ICING_ASSERT_OK_AND_ASSIGN(
pl_accessor,
- PostingListAccessor::CreateFromExisting(flash_index_storage_.get(),
- serializer_.get(), first_add_id));
+ PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), first_add_id));
std::vector<Hit> hits2 = CreateHits(
/*start_docid=*/hits1.back().document_id() + 1, /*num_hits=*/50,
/*desired_byte_length=*/1);
for (const Hit& hit : hits2) {
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit));
}
PostingListAccessor::FinalizeResult result2 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_EXPECT_OK(result2.status);
PostingListIdentifier second_add_id = result2.id;
EXPECT_THAT(second_add_id, Eq(first_add_id));
@@ -302,61 +302,61 @@ TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) {
IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend())));
}
-TEST_F(PostingListAccessorTest, InvalidHitReturnsInvalidArgument) {
+TEST_F(PostingListHitAccessorTest, InvalidHitReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
Hit invalid_hit;
- EXPECT_THAT(pl_accessor.PrependHit(invalid_hit),
+ EXPECT_THAT(pl_accessor->PrependHit(invalid_hit),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
-TEST_F(PostingListAccessorTest, HitsNotDecreasingReturnsInvalidArgument) {
+TEST_F(PostingListHitAccessorTest, HitsNotDecreasingReturnsInvalidArgument) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency);
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit1));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency);
- EXPECT_THAT(pl_accessor.PrependHit(hit2),
+ EXPECT_THAT(pl_accessor->PrependHit(hit2),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency);
- EXPECT_THAT(pl_accessor.PrependHit(hit3),
+ EXPECT_THAT(pl_accessor->PrependHit(hit3),
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
-TEST_F(PostingListAccessorTest, NewPostingListNoHitsAdded) {
+TEST_F(PostingListHitAccessorTest, NewPostingListNoHitsAdded) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
EXPECT_THAT(result1.status,
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
}
-TEST_F(PostingListAccessorTest, PreexistingPostingListNoHitsAdded) {
+TEST_F(PostingListHitAccessorTest, PreexistingPostingListNoHitsAdded) {
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor,
- PostingListAccessor::Create(flash_index_storage_.get(),
- serializer_.get()));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor,
+ PostingListHitAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency);
- ICING_ASSERT_OK(pl_accessor.PrependHit(hit1));
+ ICING_ASSERT_OK(pl_accessor->PrependHit(hit1));
PostingListAccessor::FinalizeResult result1 =
- PostingListAccessor::Finalize(std::move(pl_accessor));
+ std::move(*pl_accessor).Finalize();
ICING_ASSERT_OK(result1.status);
ICING_ASSERT_OK_AND_ASSIGN(
- PostingListAccessor pl_accessor2,
- PostingListAccessor::CreateFromExisting(flash_index_storage_.get(),
- serializer_.get(), result1.id));
+ std::unique_ptr<PostingListHitAccessor> pl_accessor2,
+ PostingListHitAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
PostingListAccessor::FinalizeResult result2 =
- PostingListAccessor::Finalize(std::move(pl_accessor2));
+ std::move(*pl_accessor2).Finalize();
ICING_ASSERT_OK(result2.status);
}
diff --git a/icing/index/main/posting-list-used-hit-serializer.cc b/icing/index/main/posting-list-used-hit-serializer.cc
index d45a428..a163188 100644
--- a/icing/index/main/posting-list-used-hit-serializer.cc
+++ b/icing/index/main/posting-list-used-hit-serializer.cc
@@ -20,7 +20,6 @@
#include <vector>
#include "icing/absl_ports/canonical_errors.h"
-#include "icing/file/posting_list/posting-list-common.h"
#include "icing/file/posting_list/posting-list-used.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/legacy/index/icing-bit-util.h"
diff --git a/icing/index/main/posting-list-used-hit-serializer.h b/icing/index/main/posting-list-used-hit-serializer.h
index 70e3e6c..1a3cbc2 100644
--- a/icing/index/main/posting-list-used-hit-serializer.h
+++ b/icing/index/main/posting-list-used-hit-serializer.h
@@ -31,7 +31,7 @@ namespace lib {
// comments in posting-list-used-hit-serializer.cc.
class PostingListUsedHitSerializer : public PostingListUsedSerializer {
public:
- static constexpr uint32_t kSpecialHitsSize = sizeof(Hit) * kNumSpecialData;
+ static constexpr uint32_t kSpecialHitsSize = kNumSpecialData * sizeof(Hit);
uint32_t GetDataTypeBytes() const override { return sizeof(Hit); }
@@ -44,23 +44,14 @@ class PostingListUsedHitSerializer : public PostingListUsedSerializer {
return kMinPostingListSize;
}
- // Min size of posting list that can fit these used bytes (see MoveFrom).
uint32_t GetMinPostingListSizeToFit(
const PostingListUsed* posting_list_used) const override;
- // Returns bytes used by actual hits.
uint32_t GetBytesUsed(
const PostingListUsed* posting_list_used) const override;
void Clear(PostingListUsed* posting_list_used) const override;
- // Moves contents from posting list 'src' to 'dst'. Clears 'src'.
- //
- // RETURNS:
- // - OK on success
- // - INVALID_ARGUMENT if 'src' is not valid or 'src' is too large to fit in
- // 'dst'.
- // - FAILED_PRECONDITION if 'dst' posting list is in a corrupted state.
libtextclassifier3::Status MoveFrom(PostingListUsed* dst,
PostingListUsed* src) const override;
diff --git a/icing/index/main/posting-list-used-hit-serializer_test.cc b/icing/index/main/posting-list-used-hit-serializer_test.cc
index b87adc9..9ecb7ec 100644
--- a/icing/index/main/posting-list-used-hit-serializer_test.cc
+++ b/icing/index/main/posting-list-used-hit-serializer_test.cc
@@ -579,7 +579,7 @@ TEST(PostingListUsedHitSerializerTest,
ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit));
}
- EXPECT_THAT(serializer.MoveFrom(&pl_used1, /*other=*/nullptr),
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(serializer.GetHits(&pl_used1),
IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend())));
@@ -625,7 +625,7 @@ TEST(PostingListUsedHitSerializerTest,
}
TEST(PostingListUsedHitSerializerTest,
- MoveToInvalidPostingListReturnsInvalidArgument) {
+ MoveToInvalidPostingListReturnsFailedPrecondition) {
PostingListUsedHitSerializer serializer;
int size = 3 * serializer.GetMinPostingListSize();
@@ -657,7 +657,7 @@ TEST(PostingListUsedHitSerializerTest,
*first_hit = invalid_hit;
++first_hit;
*first_hit = invalid_hit;
- EXPECT_THAT(serializer.MoveFrom(&pl_used2, &pl_used1),
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(serializer.GetHits(&pl_used1),
IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend())));
diff --git a/icing/index/numeric/doc-hit-info-iterator-numeric.h b/icing/index/numeric/doc-hit-info-iterator-numeric.h
new file mode 100644
index 0000000..1bfd193
--- /dev/null
+++ b/icing/index/numeric/doc-hit-info-iterator-numeric.h
@@ -0,0 +1,63 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_
+#define ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+template <typename T>
+class DocHitInfoIteratorNumeric : public DocHitInfoIterator {
+ public:
+ explicit DocHitInfoIteratorNumeric(
+ std::unique_ptr<typename NumericIndex<T>::Iterator> numeric_index_iter)
+ : numeric_index_iter_(std::move(numeric_index_iter)) {}
+
+ libtextclassifier3::Status Advance() override {
+ ICING_RETURN_IF_ERROR(numeric_index_iter_->Advance());
+
+ doc_hit_info_ = numeric_index_iter_->GetDocHitInfo();
+ return libtextclassifier3::Status::OK;
+ }
+
+ int32_t GetNumBlocksInspected() const override { return 0; }
+
+ int32_t GetNumLeafAdvanceCalls() const override { return 0; }
+
+ std::string ToString() const override { return "test"; }
+
+ void PopulateMatchedTermsStats(
+ std::vector<TermMatchInfo>* matched_terms_stats,
+ SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
+    // For the numeric hit iterator, this is a no-op since there is no term.
+ }
+
+ private:
+ std::unique_ptr<typename NumericIndex<T>::Iterator> numeric_index_iter_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_
diff --git a/icing/index/numeric/dummy-numeric-index.h b/icing/index/numeric/dummy-numeric-index.h
new file mode 100644
index 0000000..a1d20f8
--- /dev/null
+++ b/icing/index/numeric/dummy-numeric-index.h
@@ -0,0 +1,239 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_
+#define ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <queue>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/hit/hit.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/doc-hit-info-iterator-numeric.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+template <typename T>
+class DummyNumericIndex : public NumericIndex<T> {
+ public:
+ ~DummyNumericIndex() override = default;
+
+ std::unique_ptr<typename NumericIndex<T>::Editor> Edit(
+ std::string_view property_name, DocumentId document_id,
+ SectionId section_id) override {
+ return std::make_unique<Editor>(property_name, document_id, section_id,
+ storage_);
+ }
+
+ libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> GetIterator(
+ std::string_view property_name, T key_lower, T key_upper) const override;
+
+ libtextclassifier3::Status Reset() override {
+ storage_.clear();
+ return libtextclassifier3::Status::OK;
+ }
+
+ libtextclassifier3::Status PersistToDisk() override {
+ return libtextclassifier3::Status::OK;
+ }
+
+ private:
+ class Editor : public NumericIndex<T>::Editor {
+ public:
+ explicit Editor(
+ std::string_view property_name, DocumentId document_id,
+ SectionId section_id,
+ std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>>&
+ storage)
+ : NumericIndex<T>::Editor(property_name, document_id, section_id),
+ storage_(storage) {}
+
+ ~Editor() override = default;
+
+ libtextclassifier3::Status BufferKey(T key) override {
+ seen_keys_.insert(key);
+ return libtextclassifier3::Status::OK;
+ }
+
+ libtextclassifier3::Status IndexAllBufferedKeys() override;
+
+ private:
+ std::unordered_set<T> seen_keys_;
+ std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>>&
+ storage_;
+ };
+
+ class Iterator : public NumericIndex<T>::Iterator {
+ public:
+    // We group the BasicHits (sorted by document_id) of a key into a Bucket
+    // (stored as std::vector) and store key -> vector in a std::map. When
+    // doing a range query, we may access vectors from multiple keys and want
+    // to return BasicHits to callers sorted by document_id. Therefore, this
+    // problem is effectively "merge K sorted vectors".
+    // To implement this algorithm via priority_queue, we create this wrapper
+    // class to store the map and vector iterators.
+ class BucketInfo {
+ public:
+ explicit BucketInfo(
+ typename std::map<T, std::vector<BasicHit>>::const_iterator
+ bucket_iter)
+ : bucket_iter_(bucket_iter),
+ vec_iter_(bucket_iter_->second.rbegin()) {}
+
+ bool Advance() { return ++vec_iter_ != bucket_iter_->second.rend(); }
+
+ const BasicHit& GetCurrentBasicHit() const { return *vec_iter_; }
+
+ bool operator<(const BucketInfo& other) const {
+ // std::priority_queue is a max heap and we should return BasicHits in
+ // DocumentId descending order.
+ // - BucketInfo::operator< should have the same order as DocumentId.
+ // - BasicHit encodes inverted document id and its operator< compares
+ // the encoded raw value directly.
+ // - Therefore, BucketInfo::operator< should compare BasicHit reversely.
+ // - This will make priority_queue return buckets in DocumentId
+ // descending and SectionId ascending order.
+ // - Whatever direction we sort SectionId by (or pop by priority_queue)
+ // doesn't matter because all hits for the same DocumentId will be
+ // merged into a single DocHitInfo.
+ return other.GetCurrentBasicHit() < GetCurrentBasicHit();
+ }
+
+ private:
+ typename std::map<T, std::vector<BasicHit>>::const_iterator bucket_iter_;
+ std::vector<BasicHit>::const_reverse_iterator vec_iter_;
+ };
+
+ explicit Iterator(T key_lower, T key_upper,
+ std::vector<BucketInfo>&& bucket_info_vec)
+ : NumericIndex<T>::Iterator(key_lower, key_upper),
+ pq_(std::less<BucketInfo>(), std::move(bucket_info_vec)) {}
+
+ ~Iterator() override = default;
+
+ libtextclassifier3::Status Advance() override;
+
+ DocHitInfo GetDocHitInfo() const override { return doc_hit_info_; }
+
+ private:
+ std::priority_queue<BucketInfo> pq_;
+ DocHitInfo doc_hit_info_;
+ };
+
+ std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>> storage_;
+};
+
+template <typename T>
+libtextclassifier3::Status
+DummyNumericIndex<T>::Editor::IndexAllBufferedKeys() {
+ auto property_map_iter = storage_.find(this->property_name_);
+ if (property_map_iter == storage_.end()) {
+ const auto& [inserted_iter, insert_result] =
+ storage_.insert({this->property_name_, {}});
+ if (!insert_result) {
+ return absl_ports::InternalError(
+ absl_ports::StrCat("Failed to create a new map for property \"",
+ this->property_name_, "\""));
+ }
+ property_map_iter = inserted_iter;
+ }
+
+ for (const T& key : seen_keys_) {
+ auto key_map_iter = property_map_iter->second.find(key);
+ if (key_map_iter == property_map_iter->second.end()) {
+ const auto& [inserted_iter, insert_result] =
+ property_map_iter->second.insert({key, {}});
+ if (!insert_result) {
+ return absl_ports::InternalError("Failed to create a new map for key");
+ }
+ key_map_iter = inserted_iter;
+ }
+ key_map_iter->second.push_back(
+ BasicHit(this->section_id_, this->document_id_));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+template <typename T>
+libtextclassifier3::Status DummyNumericIndex<T>::Iterator::Advance() {
+ if (pq_.empty()) {
+ return absl_ports::OutOfRangeError("End of iterator");
+ }
+
+ DocumentId document_id = pq_.top().GetCurrentBasicHit().document_id();
+ doc_hit_info_ = DocHitInfo(document_id);
+  // Merge sections with the same document_id into a single DocHitInfo.
+ while (!pq_.empty() &&
+ pq_.top().GetCurrentBasicHit().document_id() == document_id) {
+ doc_hit_info_.UpdateSection(pq_.top().GetCurrentBasicHit().section_id());
+
+ BucketInfo info = pq_.top();
+ pq_.pop();
+
+ if (info.Advance()) {
+ pq_.push(std::move(info));
+ }
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+template <typename T>
+libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>>
+DummyNumericIndex<T>::GetIterator(std::string_view property_name, T key_lower,
+ T key_upper) const {
+ if (key_lower > key_upper) {
+ return absl_ports::InvalidArgumentError(
+ "key_lower should not be greater than key_upper");
+ }
+
+ auto property_map_iter = storage_.find(std::string(property_name));
+ if (property_map_iter == storage_.end()) {
+ return absl_ports::NotFoundError(
+ absl_ports::StrCat("Property \"", property_name, "\" not found"));
+ }
+
+ std::vector<typename Iterator::BucketInfo> bucket_info_vec;
+ for (auto key_map_iter = property_map_iter->second.lower_bound(key_lower);
+ key_map_iter != property_map_iter->second.cend() &&
+ key_map_iter->first <= key_upper;
+ ++key_map_iter) {
+ bucket_info_vec.push_back(typename Iterator::BucketInfo(key_map_iter));
+ }
+
+ return std::make_unique<DocHitInfoIteratorNumeric<T>>(
+ std::make_unique<Iterator>(key_lower, key_upper,
+ std::move(bucket_info_vec)));
+}
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_
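DummyNumericIndex<T>::Iterator::Advance above is a k-way merge driven by std::priority_queue. Note that BucketInfo flips its comparison only because BasicHit's raw encoding inverts DocumentId; with plain values the natural comparison suffices. A self-contained sketch of the same merge, with int standing in for BasicHit (all names hypothetical):

#include <cstddef>
#include <queue>
#include <vector>

// Merge K descending-sorted vectors, always popping the globally largest
// element, just as Advance() pops the hit with the largest DocumentId.
struct Cursor {
  const std::vector<int>* bucket;
  std::size_t pos = 0;
  int Current() const { return (*bucket)[pos]; }
  bool Advance() { return ++pos < bucket->size(); }
  // Natural order here; BucketInfo inverts it to undo BasicHit's encoding.
  bool operator<(const Cursor& other) const {
    return Current() < other.Current();
  }
};

std::vector<int> MergeDescending(const std::vector<std::vector<int>>& buckets) {
  std::priority_queue<Cursor> pq;  // max-heap: top() has the largest Current()
  for (const auto& b : buckets) {
    if (!b.empty()) pq.push(Cursor{&b});
  }
  std::vector<int> merged;
  while (!pq.empty()) {
    Cursor top = pq.top();
    pq.pop();
    merged.push_back(top.Current());
    if (top.Advance()) pq.push(top);  // re-insert, like Advance() does
  }
  return merged;  // e.g. {{9, 4, 1}, {8, 3}} -> {9, 8, 4, 3, 1}
}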
diff --git a/icing/index/numeric/integer-index-data.h b/icing/index/numeric/integer-index-data.h
new file mode 100644
index 0000000..92653fa
--- /dev/null
+++ b/icing/index/numeric/integer-index-data.h
@@ -0,0 +1,59 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_
+#define ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_
+
+#include <cstdint>
+
+#include "icing/index/hit/hit.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+// Data wrapper to store BasicHit and key for integer index.
+class IntegerIndexData {
+ public:
+ explicit IntegerIndexData(SectionId section_id, DocumentId document_id,
+ int64_t key)
+ : basic_hit_(section_id, document_id), key_(key) {}
+
+ explicit IntegerIndexData() : basic_hit_(), key_(0) {}
+
+ const BasicHit& basic_hit() const { return basic_hit_; }
+
+ int64_t key() const { return key_; }
+
+ bool is_valid() const { return basic_hit_.is_valid(); }
+
+ bool operator<(const IntegerIndexData& other) const {
+ return basic_hit_ < other.basic_hit_;
+ }
+
+ bool operator==(const IntegerIndexData& other) const {
+ return basic_hit_ == other.basic_hit_ && key_ == other.key_;
+ }
+
+ private:
+ BasicHit basic_hit_;
+ int64_t key_;
+} __attribute__((packed));
+static_assert(sizeof(IntegerIndexData) == 12, "");
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_
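A quick size check on the packed layout: the 4-byte BasicHit size below is inferred from the static_assert above (12 bytes total minus the 8-byte key), not stated anywhere in this patch:

// Sketch: __attribute__((packed)) removes padding, so the members' sizes
// must sum to the asserted total.
static_assert(sizeof(int64_t) == 8, "key_ occupies 8 bytes");
static_assert(sizeof(BasicHit) == 4,
              "inferred: 12-byte packed struct minus 8-byte key");
static_assert(sizeof(IntegerIndexData) == sizeof(BasicHit) + sizeof(int64_t),
              "packed: no padding between basic_hit_ and key_");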
diff --git a/icing/index/numeric/numeric-index.h b/icing/index/numeric/numeric-index.h
new file mode 100644
index 0000000..6798f8d
--- /dev/null
+++ b/icing/index/numeric/numeric-index.h
@@ -0,0 +1,146 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_
+#define ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_
+
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+
+namespace icing {
+namespace lib {
+
+template <typename T>
+class NumericIndex {
+ public:
+ using value_type = T;
+
+  // Editor class for batch adding new records into the numeric index for a
+  // given property, DocumentId and SectionId. The caller should use BufferKey
+  // to buffer a key (calling it several times for multiple keys) and finally
+  // call IndexAllBufferedKeys to batch add all buffered keys (with DocumentId
+  // + SectionId info, i.e. BasicHit) into the numeric index.
+  //
+  // For example, suppose there are values = [5, 1, 10, -100] in DocumentId =
+  // 5, SectionId = 1 (property "timestamp").
+  // Then the client should call BufferKey(5), BufferKey(1), BufferKey(10),
+  // BufferKey(-100) first, and finally call IndexAllBufferedKeys once to
+  // batch add these records into the numeric index.
+ class Editor {
+ public:
+ explicit Editor(std::string_view property_name, DocumentId document_id,
+ SectionId section_id)
+ : property_name_(property_name),
+ document_id_(document_id),
+ section_id_(section_id) {}
+
+ virtual ~Editor() = default;
+
+ // Buffers a new key.
+ //
+ // Returns:
+ // - OK on success
+ // - Any other errors, depending on the actual implementation
+ virtual libtextclassifier3::Status BufferKey(T key) = 0;
+
+ // Adds all buffered keys into numeric index.
+ //
+ // Returns:
+ // - OK on success
+ // - Any other errors, depending on the actual implementation
+ virtual libtextclassifier3::Status IndexAllBufferedKeys() = 0;
+
+ protected:
+ std::string property_name_;
+ DocumentId document_id_;
+ SectionId section_id_;
+ };
+
+  // Iterator class for a numeric index range query [key_lower, key_upper]
+  // (inclusive on both sides) on a given property (see GetIterator). There
+  // are some basic requirements for implementations:
+  // - Iterates through all relevant doc hits.
+  // - Merges multiple SectionIds of doc hits with the same DocumentId into a
+  //   single SectionIdMask and constructs DocHitInfo.
+  // - Returns DocHitInfo in descending DocumentId order.
+ //
+ // For example, relevant doc hits (DocumentId, SectionId) are [(2, 0), (4, 3),
+ // (2, 1), (6, 2), (4, 2)]. Advance() and GetDocHitInfo() should return
+ // DocHitInfo(6, SectionIdMask(2)), DocHitInfo(4, SectionIdMask(2, 3)) and
+ // DocHitInfo(2, SectionIdMask(0, 1)).
+ class Iterator {
+ public:
+ explicit Iterator(T key_lower, T key_upper)
+ : key_lower_(key_lower), key_upper_(key_upper) {}
+
+ virtual ~Iterator() = default;
+
+ virtual libtextclassifier3::Status Advance() = 0;
+
+ virtual DocHitInfo GetDocHitInfo() const = 0;
+
+ protected:
+ T key_lower_;
+ T key_upper_;
+ };
+
+ virtual ~NumericIndex() = default;
+
+ // Returns an Editor instance for adding new records into numeric index for a
+ // given property, DocumentId and SectionId. See Editor for more details.
+ virtual std::unique_ptr<Editor> Edit(std::string_view property_name,
+ DocumentId document_id,
+ SectionId section_id) = 0;
+
+  // Returns a DocHitInfoIteratorNumeric (via the DocHitInfoIterator interface)
+  // for iterating through all docs which have the specified (numeric) property
+  // contents in range [key_lower, key_upper].
+ //
+ // In general, different numeric index implementations require different data
+ // iterator implementations, so class Iterator is an abstraction of the data
+ // iterator and DocHitInfoIteratorNumeric can work with any implementation of
+ // it. See Iterator and DocHitInfoIteratorNumeric for more details.
+ //
+ // Returns:
+ // - std::unique_ptr<DocHitInfoIterator> on success
+ // - NOT_FOUND_ERROR if there is no numeric index for property_name
+ // - INVALID_ARGUMENT_ERROR if key_lower > key_upper
+ // - Any other errors, depending on the actual implementation
+ virtual libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>>
+ GetIterator(std::string_view property_name, T key_lower,
+ T key_upper) const = 0;
+
+ // Clears all files created by the index. Returns OK if all files were
+ // cleared.
+ virtual libtextclassifier3::Status Reset() = 0;
+
+ // Syncs all the data and metadata changes to disk.
+ //
+ // Returns:
+ // OK on success
+ // INTERNAL_ERROR on I/O errors
+ virtual libtextclassifier3::Status PersistToDisk() = 0;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_
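Tying the Editor and Iterator contracts together, a sketch of indexing and querying through this interface, reusing the "timestamp" example from the Editor comment; DummyNumericIndex<int64_t> comes from this patch, while the surrounding function (returning libtextclassifier3::Status) is assumed:

// Sketch: index values [5, 1, 10, -100] for DocumentId 5, SectionId 1.
DummyNumericIndex<int64_t> index;
std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
    index.Edit("timestamp", /*document_id=*/5, /*section_id=*/1);
for (int64_t key : {5, 1, 10, -100}) {
  ICING_RETURN_IF_ERROR(editor->BufferKey(key));
}
ICING_RETURN_IF_ERROR(editor->IndexAllBufferedKeys());

// Range query [0, 10] (inclusive): matches the keys 5, 1 and 10 above.
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<DocHitInfoIterator> iter,
    index.GetIterator("timestamp", /*key_lower=*/0, /*key_upper=*/10));
while (iter->Advance().ok()) {
  // Hits for the same DocumentId arrive merged into one DocHitInfo, in
  // descending DocumentId order (here a single doc: DocumentId 5).
  const DocHitInfo& info = iter->doc_hit_info();
  (void)info;
}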
diff --git a/icing/index/numeric/numeric-index_test.cc b/icing/index/numeric/numeric-index_test.cc
new file mode 100644
index 0000000..38769f6
--- /dev/null
+++ b/icing/index/numeric/numeric-index_test.cc
@@ -0,0 +1,361 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/numeric/numeric-index.h"
+
+#include <limits>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::NotNull;
+
+constexpr static std::string_view kDefaultTestPropertyName = "test";
+
+constexpr SectionId kDefaultSectionId = 0;
+
+template <typename T>
+class NumericIndexTest : public ::testing::Test {
+ protected:
+ using INDEX_IMPL_TYPE = T;
+
+ void SetUp() override {
+    if constexpr (std::is_same_v<
+                      INDEX_IMPL_TYPE,
+                      DummyNumericIndex<typename INDEX_IMPL_TYPE::value_type>>) {
+ numeric_index_ = std::make_unique<
+ DummyNumericIndex<typename INDEX_IMPL_TYPE::value_type>>();
+ }
+
+ ASSERT_THAT(numeric_index_, NotNull());
+ }
+
+ void Index(std::string_view property_name, DocumentId document_id,
+ SectionId section_id,
+ std::vector<typename INDEX_IMPL_TYPE::value_type> keys) {
+    std::unique_ptr<typename NumericIndex<
+        typename INDEX_IMPL_TYPE::value_type>::Editor> editor =
+ this->numeric_index_->Edit(property_name, document_id, section_id);
+
+ for (const auto& key : keys) {
+ ICING_EXPECT_OK(editor->BufferKey(key));
+ }
+ ICING_EXPECT_OK(editor->IndexAllBufferedKeys());
+ }
+
+ libtextclassifier3::StatusOr<std::vector<DocHitInfo>> Query(
+ std::string_view property_name,
+ typename INDEX_IMPL_TYPE::value_type key_lower,
+ typename INDEX_IMPL_TYPE::value_type key_upper) {
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<DocHitInfoIterator> iter,
+ this->numeric_index_->GetIterator(property_name, key_lower, key_upper));
+
+ std::vector<DocHitInfo> result;
+ while (iter->Advance().ok()) {
+ result.push_back(iter->doc_hit_info());
+ }
+ return result;
+ }
+
+ std::unique_ptr<NumericIndex<typename INDEX_IMPL_TYPE::value_type>>
+ numeric_index_;
+};
+
+using TestTypes = ::testing::Types<DummyNumericIndex<int64_t>>;
+TYPED_TEST_SUITE(NumericIndexTest, TestTypes);
+
+TYPED_TEST(NumericIndexTest, SingleKeyExactQuery) {
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{3});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2});
+
+ int64_t query_key = 2;
+ std::vector<SectionId> expected_sections{kDefaultSectionId};
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/query_key,
+ /*key_upper=*/query_key),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/5, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/2, expected_sections))));
+}
+
+TYPED_TEST(NumericIndexTest, SingleKeyRangeQuery) {
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{3});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2});
+
+ std::vector<SectionId> expected_sections{kDefaultSectionId};
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1,
+ /*key_upper=*/3),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/5, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/2, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/1, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/0, expected_sections))));
+}
+
+TYPED_TEST(NumericIndexTest, EmptyResult) {
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{3});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2});
+
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/100,
+ /*key_upper=*/200),
+ IsOkAndHolds(IsEmpty()));
+}
+
+TYPED_TEST(NumericIndexTest, MultipleKeysShouldMergeAndDedupeDocHitInfo) {
+  // Construct several documents with multiple keys under the same section.
+  // Range query [1, 3] will find hits with the same (DocumentId, SectionId)
+  // multiple times. For example, (2, kDefaultSectionId) will be found twice
+  // (once for key = 1 and once for key = 3).
+  // Test if the iterator dedupes correctly.
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{-1000, 0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{-100, 0, 1, 2, 3, 4, 5});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{3, 1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{4, 1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{1, 6});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2, 100});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/6, kDefaultSectionId,
+ /*keys=*/{1000, 2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/7, kDefaultSectionId,
+ /*keys=*/{4, -1000});
+
+ std::vector<SectionId> expected_sections{kDefaultSectionId};
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1,
+ /*key_upper=*/3),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/6, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/5, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/4, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/3, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/2, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/1, expected_sections))));
+}
+
+TYPED_TEST(NumericIndexTest, EdgeNumericValues) {
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{-100});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{-80});
+ this->Index(
+ kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::max()});
+ this->Index(
+ kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::min()});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{200});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/6, kDefaultSectionId,
+ /*keys=*/{100});
+ this->Index(
+ kDefaultTestPropertyName, /*document_id=*/7, kDefaultSectionId,
+ /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::max()});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/8, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(
+ kDefaultTestPropertyName, /*document_id=*/9, kDefaultSectionId,
+ /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::min()});
+
+ std::vector<SectionId> expected_sections{kDefaultSectionId};
+
+ // Negative key
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/-100,
+ /*key_upper=*/-70),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/2, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/1, expected_sections))));
+
+ // value_type max key
+ EXPECT_THAT(
+ this->Query(kDefaultTestPropertyName, /*key_lower=*/
+ std::numeric_limits<typename TypeParam::value_type>::max(),
+ /*key_upper=*/
+ std::numeric_limits<typename TypeParam::value_type>::max()),
+ IsOkAndHolds(
+ ElementsAre(EqualsDocHitInfo(/*document_id=*/7, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/3, expected_sections))));
+
+ // value_type min key
+ EXPECT_THAT(
+ this->Query(kDefaultTestPropertyName, /*key_lower=*/
+ std::numeric_limits<typename TypeParam::value_type>::min(),
+ /*key_upper=*/
+ std::numeric_limits<typename TypeParam::value_type>::min()),
+ IsOkAndHolds(
+ ElementsAre(EqualsDocHitInfo(/*document_id=*/9, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/4, expected_sections))));
+
+ // Key = 0
+ EXPECT_THAT(
+ this->Query(kDefaultTestPropertyName, /*key_lower=*/0, /*key_upper=*/0),
+ IsOkAndHolds(
+ ElementsAre(EqualsDocHitInfo(/*document_id=*/8, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/0, expected_sections))));
+
+ // All keys from value_type min to value_type max
+ EXPECT_THAT(
+ this->Query(kDefaultTestPropertyName, /*key_lower=*/
+ std::numeric_limits<typename TypeParam::value_type>::min(),
+ /*key_upper=*/
+ std::numeric_limits<typename TypeParam::value_type>::max()),
+ IsOkAndHolds(
+ ElementsAre(EqualsDocHitInfo(/*document_id=*/9, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/8, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/7, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/6, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/5, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/4, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/3, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/2, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/1, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/0, expected_sections))));
+}
+
+TYPED_TEST(NumericIndexTest,
+ MultipleSectionsShouldMergeSectionsAndDedupeDocHitInfo) {
+  // Construct several documents with multiple numeric sections.
+  // Range query [1, 3] will find hits with the same DocumentId but multiple
+  // different SectionIds. For example, there will be 2 hits (1, 0), (1, 1) for
+  // DocumentId=1.
+  // Test if the iterator merges multiple sections into a single SectionIdMask
+  // correctly.
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/0,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/1,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/2,
+ /*keys=*/{-1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/0,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/1,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/2,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/3,
+ /*keys=*/{3});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/4,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/5,
+ /*keys=*/{5});
+
+ EXPECT_THAT(
+ this->Query(kDefaultTestPropertyName, /*key_lower=*/1,
+ /*key_upper=*/3),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/2, std::vector<SectionId>{3, 4}),
+ EqualsDocHitInfo(/*document_id=*/1, std::vector<SectionId>{0, 1}),
+ EqualsDocHitInfo(/*document_id=*/0, std::vector<SectionId>{1}))));
+}
+
+TYPED_TEST(NumericIndexTest, NonRelevantPropertyShouldNotBeIncluded) {
+ constexpr std::string_view kNonRelevantProperty = "non_relevant_property";
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{3});
+ this->Index(kNonRelevantProperty, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kNonRelevantProperty, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2});
+
+ std::vector<SectionId> expected_sections{kDefaultSectionId};
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1,
+ /*key_upper=*/3),
+ IsOkAndHolds(ElementsAre(
+ EqualsDocHitInfo(/*document_id=*/5, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/1, expected_sections),
+ EqualsDocHitInfo(/*document_id=*/0, expected_sections))));
+}
+
+TYPED_TEST(NumericIndexTest,
+ RangeQueryKeyLowerGreaterThanKeyUpperShouldReturnError) {
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId,
+ /*keys=*/{1});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId,
+ /*keys=*/{3});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId,
+ /*keys=*/{2});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId,
+ /*keys=*/{0});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId,
+ /*keys=*/{4});
+ this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId,
+ /*keys=*/{2});
+
+ EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/3,
+ /*key_upper=*/1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor.cc b/icing/index/numeric/posting-list-integer-index-data-accessor.cc
new file mode 100644
index 0000000..73b48e2
--- /dev/null
+++ b/icing/index/numeric/posting-list-integer-index-data-accessor.cc
@@ -0,0 +1,136 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/numeric/posting-list-integer-index-data-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/index-block.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/numeric/integer-index-data.h"
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+/* static */ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListIntegerIndexDataAccessor>>
+PostingListIntegerIndexDataAccessor::Create(
+ FlashIndexStorage* storage,
+ PostingListUsedIntegerIndexDataSerializer* serializer) {
+ uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes(
+ storage->block_size(), serializer->GetDataTypeBytes());
+ std::unique_ptr<uint8_t[]> posting_list_buffer_array =
+ std::make_unique<uint8_t[]>(max_posting_list_bytes);
+ ICING_ASSIGN_OR_RETURN(
+ PostingListUsed posting_list_buffer,
+ PostingListUsed::CreateFromUnitializedRegion(
+ serializer, posting_list_buffer_array.get(), max_posting_list_bytes));
+ return std::unique_ptr<PostingListIntegerIndexDataAccessor>(
+ new PostingListIntegerIndexDataAccessor(
+ storage, std::move(posting_list_buffer_array),
+ std::move(posting_list_buffer), serializer));
+}
+
+/* static */ libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListIntegerIndexDataAccessor>>
+PostingListIntegerIndexDataAccessor::CreateFromExisting(
+ FlashIndexStorage* storage,
+ PostingListUsedIntegerIndexDataSerializer* serializer,
+ PostingListIdentifier existing_posting_list_id) {
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ Create(storage, serializer));
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage->GetPostingList(existing_posting_list_id));
+ pl_accessor->preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ return pl_accessor;
+}
+
+// Returns the next batch of integer index data for the provided posting list.
+libtextclassifier3::StatusOr<std::vector<IntegerIndexData>>
+PostingListIntegerIndexDataAccessor::GetNextDataBatch() {
+ if (preexisting_posting_list_ == nullptr) {
+ if (has_reached_posting_list_chain_end_) {
+ return std::vector<IntegerIndexData>();
+ }
+ return absl_ports::FailedPreconditionError(
+ "Cannot retrieve data from a PostingListIntegerIndexDataAccessor that "
+ "was not created from a preexisting posting list.");
+ }
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<IntegerIndexData> batch,
+ serializer_->GetData(&preexisting_posting_list_->posting_list));
+ uint32_t next_block_index;
+ // Posting lists will only be chained when they are max-sized, in which case
+ // block.next_block_index() will point to the next block for the next posting
+ // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be
+ // used to point to the next free list block, which is not relevant here.
+ if (preexisting_posting_list_->block.max_num_posting_lists() == 1) {
+ next_block_index = preexisting_posting_list_->block.next_block_index();
+ } else {
+ next_block_index = kInvalidBlockIndex;
+ }
+ if (next_block_index != kInvalidBlockIndex) {
+ PostingListIdentifier next_posting_list_id(
+ next_block_index, /*posting_list_index=*/0,
+ preexisting_posting_list_->block.posting_list_index_bits());
+ ICING_ASSIGN_OR_RETURN(PostingListHolder holder,
+ storage_->GetPostingList(next_posting_list_id));
+ preexisting_posting_list_ =
+ std::make_unique<PostingListHolder>(std::move(holder));
+ } else {
+ has_reached_posting_list_chain_end_ = true;
+ preexisting_posting_list_.reset();
+ }
+ return batch;
+}
+
+libtextclassifier3::Status PostingListIntegerIndexDataAccessor::PrependData(
+ const IntegerIndexData& data) {
+ PostingListUsed& active_pl = (preexisting_posting_list_ != nullptr)
+ ? preexisting_posting_list_->posting_list
+ : posting_list_buffer_;
+ libtextclassifier3::Status status =
+ serializer_->PrependData(&active_pl, data);
+ if (!absl_ports::IsResourceExhausted(status)) {
+ return status;
+ }
+ // There is no more room to add data to this current posting list! Therefore,
+ // we need to either move those data to a larger posting list or flush this
+ // posting list and create another max-sized posting list in the chain.
+ if (preexisting_posting_list_ != nullptr) {
+ FlushPreexistingPostingList();
+ } else {
+ ICING_RETURN_IF_ERROR(FlushInMemoryPostingList());
+ }
+
+ // Re-add data. Should always fit since we just cleared posting_list_buffer_.
+ // It's fine to explicitly reference posting_list_buffer_ here because there's
+ // no way of reaching this line while preexisting_posting_list_ is still in
+ // use.
+ return serializer_->PrependData(&posting_list_buffer_, data);
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor.h b/icing/index/numeric/posting-list-integer-index-data-accessor.h
new file mode 100644
index 0000000..7835bf9
--- /dev/null
+++ b/icing/index/numeric/posting-list-integer-index-data-accessor.h
@@ -0,0 +1,108 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_
+#define ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_
+
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-accessor.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/numeric/integer-index-data.h"
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+
+namespace icing {
+namespace lib {
+
+// TODO(b/259743562): Refactor PostingListAccessor derived classes
+
+// This class is used to provide a simple abstraction for adding integer index
+// data to posting lists. PostingListIntegerIndexDataAccessor handles:
+// 1) selection of properly-sized posting lists for the accumulated integer
+// index data during Finalize()
+// 2) chaining of max-sized posting lists.
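+//
+// A minimal end-to-end sketch (mirroring the accompanying unit tests;
+// `storage` and `serializer` are assumed to be valid, caller-owned
+// instances):
+//
+//   ICING_ASSIGN_OR_RETURN(
+//       std::unique_ptr<PostingListIntegerIndexDataAccessor> accessor,
+//       PostingListIntegerIndexDataAccessor::Create(storage, serializer));
+//   ICING_RETURN_IF_ERROR(accessor->PrependData(
+//       IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/42)));
+//   PostingListAccessor::FinalizeResult result =
+//       std::move(*accessor).Finalize();
+//   // result.id identifies the posting list for later retrieval.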
+class PostingListIntegerIndexDataAccessor : public PostingListAccessor {
+ public:
+ // Creates an empty PostingListIntegerIndexDataAccessor.
+ //
+ // RETURNS:
+ // - On success, a valid instance of PostingListIntegerIndexDataAccessor
+ // - INVALID_ARGUMENT error if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListIntegerIndexDataAccessor>>
+ Create(FlashIndexStorage* storage,
+ PostingListUsedIntegerIndexDataSerializer* serializer);
+
+  // Creates a PostingListIntegerIndexDataAccessor with an existing posting
+  // list identified by existing_posting_list_id.
+ //
+ // RETURNS:
+ // - On success, a valid instance of PostingListIntegerIndexDataAccessor
+ // - INVALID_ARGUMENT if storage has an invalid block_size.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PostingListIntegerIndexDataAccessor>>
+ CreateFromExisting(FlashIndexStorage* storage,
+ PostingListUsedIntegerIndexDataSerializer* serializer,
+ PostingListIdentifier existing_posting_list_id);
+
+ PostingListUsedSerializer* GetSerializer() override { return serializer_; }
+
+  // Retrieves the next batch of data in the posting list chain.
+  //
+  // RETURNS:
+  //   - On success, a vector of integer index data in the posting list chain
+  //   - FAILED_PRECONDITION if called on an instance that was created via
+  //     Create
+  //   - INTERNAL if unable to read the next posting list in the chain or if
+  //     the posting list has been corrupted somehow.
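+  //
+  // An illustrative sketch of draining a whole chain (an empty batch signals
+  // the end of the chain):
+  //
+  //   ICING_ASSIGN_OR_RETURN(std::vector<IntegerIndexData> batch,
+  //                          accessor->GetNextDataBatch());
+  //   while (!batch.empty()) {
+  //     // Consume batch...
+  //     ICING_ASSIGN_OR_RETURN(batch, accessor->GetNextDataBatch());
+  //   }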
+ libtextclassifier3::StatusOr<std::vector<IntegerIndexData>>
+ GetNextDataBatch();
+
+  // Prepends one data. This may result in flushing the posting list to disk
+  // (if the PostingListIntegerIndexDataAccessor holds a max-sized posting list
+  // that is full) or freeing a pre-existing posting list if it is too small to
+  // fit all necessary data.
+ //
+ // RETURNS:
+ // - OK, on success
+ // - INVALID_ARGUMENT if !data.is_valid() or if data is greater than the
+ // previously added data.
+ // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new
+ // posting list.
+ libtextclassifier3::Status PrependData(const IntegerIndexData& data);
+
+ // TODO(b/259743562): [Optimization 1] add GetAndClear, IsFull for split
+
+ private:
+ explicit PostingListIntegerIndexDataAccessor(
+ FlashIndexStorage* storage,
+ std::unique_ptr<uint8_t[]> posting_list_buffer_array,
+ PostingListUsed posting_list_buffer,
+ PostingListUsedIntegerIndexDataSerializer* serializer)
+ : PostingListAccessor(storage, std::move(posting_list_buffer_array),
+ std::move(posting_list_buffer)),
+ serializer_(serializer) {}
+
+ PostingListUsedIntegerIndexDataSerializer* serializer_; // Does not own.
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_
diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc b/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc
new file mode 100644
index 0000000..ca0804e
--- /dev/null
+++ b/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc
@@ -0,0 +1,410 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/numeric/posting-list-integer-index-data-accessor.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/posting_list/flash-index-storage.h"
+#include "icing/file/posting_list/posting-list-identifier.h"
+#include "icing/index/numeric/integer-index-data.h"
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::Eq;
+using ::testing::Lt;
+using ::testing::SizeIs;
+
+class PostingListIntegerIndexDataAccessorTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ test_dir_ = GetTestTempDir() + "/test_dir";
+ file_name_ = test_dir_ + "/test_file.idx.index";
+
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str()));
+
+ serializer_ = std::make_unique<PostingListUsedIntegerIndexDataSerializer>();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ FlashIndexStorage flash_index_storage,
+ FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get()));
+ flash_index_storage_ =
+ std::make_unique<FlashIndexStorage>(std::move(flash_index_storage));
+ }
+
+ void TearDown() override {
+ flash_index_storage_.reset();
+ serializer_.reset();
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ }
+
+ Filesystem filesystem_;
+ std::string test_dir_;
+ std::string file_name_;
+ std::unique_ptr<PostingListUsedIntegerIndexDataSerializer> serializer_;
+ std::unique_ptr<FlashIndexStorage> flash_index_storage_;
+};
+
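+// Creates `num_data` IntegerIndexData with consecutive document ids and keys,
+// cycling section ids downward from kMaxSectionId. For example,
+// CreateData(/*num_data=*/3, /*start_document_id=*/0, /*start_key=*/819)
+// yields (kMaxSectionId, 0, 819), (kMaxSectionId - 1, 1, 820),
+// (kMaxSectionId - 2, 2, 821).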
+std::vector<IntegerIndexData> CreateData(int num_data,
+ DocumentId start_document_id,
+ int64_t start_key) {
+ SectionId section_id = kMaxSectionId;
+
+ std::vector<IntegerIndexData> data;
+ data.reserve(num_data);
+ for (int i = 0; i < num_data; ++i) {
+ data.push_back(IntegerIndexData(section_id, start_document_id, start_key));
+
+ if (section_id == kMinSectionId) {
+ section_id = kMaxSectionId;
+ } else {
+ --section_id;
+ }
+ ++start_document_id;
+ ++start_key;
+ }
+ return data;
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest, DataAddAndRetrieveProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add some integer index data
+ std::vector<IntegerIndexData> data_vec =
+ CreateData(/*num_data=*/5, /*start_document_id=*/0, /*start_key=*/819);
+ for (const IntegerIndexData& data : data_vec) {
+ EXPECT_THAT(pl_accessor->PrependData(data), IsOk());
+ }
+ PostingListAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ EXPECT_THAT(result.status, IsOk());
+ EXPECT_THAT(result.id.block_index(), Eq(1));
+ EXPECT_THAT(result.id.posting_list_index(), Eq(0));
+
+ // Retrieve some data.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result.id));
+ EXPECT_THAT(
+ serializer_->GetData(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(data_vec.rbegin(), data_vec.rend())));
+ EXPECT_THAT(pl_holder.block.next_block_index(), Eq(kInvalidBlockIndex));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest, PreexistingPLKeepOnSameBlock) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ // Add a single data. This will fit in a min-sized posting list.
+ IntegerIndexData data1(/*section_id=*/1, /*document_id=*/0, /*key=*/12345);
+ ICING_ASSERT_OK(pl_accessor->PrependData(data1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+ // Should be allocated to the first block.
+ ASSERT_THAT(result1.id.block_index(), Eq(1));
+ ASSERT_THAT(result1.id.posting_list_index(), Eq(0));
+
+  // Add one more data. A minimum-sized posting list must be able to fit two
+  // data, so this should NOT cause the previous pl to be reallocated.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListIntegerIndexDataAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ IntegerIndexData data2(/*section_id=*/1, /*document_id=*/1, /*key=*/23456);
+ ICING_ASSERT_OK(pl_accessor->PrependData(data2));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result2.status);
+ // Should be in the same posting list.
+ EXPECT_THAT(result2.id, Eq(result1.id));
+
+ // The posting list at result2.id should hold all of the data that have been
+ // added.
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result2.id));
+ EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAre(data2, data1)));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ PreexistingPLReallocateToLargerPL) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+  // Adding 3 data should cause Finalize to allocate a 48-byte posting list,
+  // which can store at most 4 data.
+ std::vector<IntegerIndexData> data_vec1 =
+ CreateData(/*num_data=*/3, /*start_document_id=*/0, /*start_key=*/819);
+ for (const IntegerIndexData& data : data_vec1) {
+ ICING_ASSERT_OK(pl_accessor->PrependData(data));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+ // Should be allocated to the first block.
+ ASSERT_THAT(result1.id.block_index(), Eq(1));
+ ASSERT_THAT(result1.id.posting_list_index(), Eq(0));
+
+ // Now add more data.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListIntegerIndexDataAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+  // The current posting list can fit 1 more data. Adding 12 more data should
+  // result in these data being moved to a larger posting list. Also, the total
+  // size of these data won't exceed the max posting list size, so there will
+  // be a single posting list and no chain.
+ std::vector<IntegerIndexData> data_vec2 = CreateData(
+ /*num_data=*/12,
+ /*start_document_id=*/data_vec1.back().basic_hit().document_id() + 1,
+ /*start_key=*/819);
+
+ for (const IntegerIndexData& data : data_vec2) {
+ ICING_ASSERT_OK(pl_accessor->PrependData(data));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result2.status);
+ // Should be allocated to the second (new) block because the posting list
+ // should grow beyond the size that the first block maintains.
+ EXPECT_THAT(result2.id.block_index(), Eq(2));
+ EXPECT_THAT(result2.id.posting_list_index(), Eq(0));
+
+ // The posting list at result2.id should hold all of the data that have been
+ // added.
+ std::vector<IntegerIndexData> all_data_vec;
+ all_data_vec.reserve(data_vec1.size() + data_vec2.size());
+ all_data_vec.insert(all_data_vec.end(), data_vec1.begin(), data_vec1.end());
+ all_data_vec.insert(all_data_vec.end(), data_vec2.begin(), data_vec2.end());
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(result2.id));
+ EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(all_data_vec.rbegin(),
+ all_data_vec.rend())));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ MultiBlockChainsBlocksProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+  // Block size is 4096, sizeof(BlockHeader) is 12 and sizeof(IntegerIndexData)
+  // is 12, so a max-sized posting list can store (4096 - 12) / 12 = 340 data.
+  // Adding 341 data should cause:
+  // - 2 max-sized posting lists to be allocated to block 1 and block 2.
+  // - Chaining: block 2 -> block 1
+ std::vector<IntegerIndexData> data_vec =
+ CreateData(/*num_data=*/341, /*start_document_id=*/0, /*start_key=*/819);
+ for (const IntegerIndexData& data : data_vec) {
+ ICING_ASSERT_OK(pl_accessor->PrependData(data));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+ PostingListIdentifier second_block_id = result1.id;
+ // Should be allocated to the second block.
+ EXPECT_THAT(second_block_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // We should be able to retrieve all data.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_block_id));
+ // This pl_holder will only hold a posting list with the data that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<IntegerIndexData> second_block_data,
+ serializer_->GetData(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_data, SizeIs(Lt(data_vec.size())));
+ auto first_block_data_start = data_vec.rbegin() + second_block_data.size();
+ EXPECT_THAT(second_block_data,
+ ElementsAreArray(data_vec.rbegin(), first_block_data_start));
+
+ // Now retrieve all of the data that were on the first block.
+ uint32_t first_block_id = pl_holder.block.next_block_index();
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(
+ serializer_->GetData(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_data_start, data_vec.rend())));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ PreexistingMultiBlockReusesBlocksProperly) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+  // Block size is 4096, sizeof(BlockHeader) is 12 and sizeof(IntegerIndexData)
+  // is 12, so a max-sized posting list can store (4096 - 12) / 12 = 340 data.
+  // Adding 341 data will cause:
+  // - 2 max-sized posting lists to be allocated to block 1 and block 2.
+  // - Chaining: block 2 -> block 1
+ std::vector<IntegerIndexData> data_vec1 =
+ CreateData(/*num_data=*/341, /*start_document_id=*/0, /*start_key=*/819);
+ for (const IntegerIndexData& data : data_vec1) {
+ ICING_ASSERT_OK(pl_accessor->PrependData(data));
+ }
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result1.status);
+ PostingListIdentifier first_add_id = result1.id;
+ EXPECT_THAT(first_add_id, Eq(PostingListIdentifier(
+ /*block_index=*/2, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0)));
+
+ // Now add more data. These should fit on the existing second block and not
+ // fill it up.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ pl_accessor,
+ PostingListIntegerIndexDataAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), first_add_id));
+ std::vector<IntegerIndexData> data_vec2 = CreateData(
+ /*num_data=*/10,
+ /*start_document_id=*/data_vec1.back().basic_hit().document_id() + 1,
+ /*start_key=*/819);
+ for (const IntegerIndexData& data : data_vec2) {
+ ICING_ASSERT_OK(pl_accessor->PrependData(data));
+ }
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor).Finalize();
+ ICING_ASSERT_OK(result2.status);
+ PostingListIdentifier second_add_id = result2.id;
+ EXPECT_THAT(second_add_id, Eq(first_add_id));
+
+ // We should be able to retrieve all data.
+ std::vector<IntegerIndexData> all_data_vec;
+ all_data_vec.reserve(data_vec1.size() + data_vec2.size());
+ all_data_vec.insert(all_data_vec.end(), data_vec1.begin(), data_vec1.end());
+ all_data_vec.insert(all_data_vec.end(), data_vec2.begin(), data_vec2.end());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListHolder pl_holder,
+ flash_index_storage_->GetPostingList(second_add_id));
+ // This pl_holder will only hold a posting list with the data that didn't fit
+ // on the first block.
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<IntegerIndexData> second_block_data,
+ serializer_->GetData(&pl_holder.posting_list));
+ ASSERT_THAT(second_block_data, SizeIs(Lt(all_data_vec.size())));
+ auto first_block_data_start =
+ all_data_vec.rbegin() + second_block_data.size();
+ EXPECT_THAT(second_block_data,
+ ElementsAreArray(all_data_vec.rbegin(), first_block_data_start));
+
+ // Now retrieve all of the data that were on the first block.
+ uint32_t first_block_id = pl_holder.block.next_block_index();
+ EXPECT_THAT(first_block_id, Eq(1));
+
+ PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0,
+ /*posting_list_index_bits=*/0);
+ ICING_ASSERT_OK_AND_ASSIGN(pl_holder,
+ flash_index_storage_->GetPostingList(pl_id));
+ EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list),
+ IsOkAndHolds(ElementsAreArray(first_block_data_start,
+ all_data_vec.rend())));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ InvalidDataShouldReturnInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ IntegerIndexData invalid_data;
+ EXPECT_THAT(pl_accessor->PrependData(invalid_data),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ BasicHitIncreasingShouldReturnInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ IntegerIndexData data1(/*section_id=*/3, /*document_id=*/1, /*key=*/12345);
+ ICING_ASSERT_OK(pl_accessor->PrependData(data1));
+
+ IntegerIndexData data2(/*section_id=*/6, /*document_id=*/1, /*key=*/12345);
+ EXPECT_THAT(pl_accessor->PrependData(data2),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ IntegerIndexData data3(/*section_id=*/2, /*document_id=*/0, /*key=*/12345);
+ EXPECT_THAT(pl_accessor->PrependData(data3),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ NewPostingListNoDataAddedShouldReturnInvalidArgument) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ PostingListAccessor::FinalizeResult result =
+ std::move(*pl_accessor).Finalize();
+ EXPECT_THAT(result.status,
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(PostingListIntegerIndexDataAccessorTest,
+ PreexistingPostingListNoDataAddedShouldSucceed) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor1,
+ PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(),
+ serializer_.get()));
+ IntegerIndexData data1(/*section_id=*/3, /*document_id=*/1, /*key=*/12345);
+ ICING_ASSERT_OK(pl_accessor1->PrependData(data1));
+ PostingListAccessor::FinalizeResult result1 =
+ std::move(*pl_accessor1).Finalize();
+ ICING_ASSERT_OK(result1.status);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor2,
+ PostingListIntegerIndexDataAccessor::CreateFromExisting(
+ flash_index_storage_.get(), serializer_.get(), result1.id));
+ PostingListAccessor::FinalizeResult result2 =
+ std::move(*pl_accessor2).Finalize();
+ EXPECT_THAT(result2.status, IsOk());
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc b/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc
new file mode 100644
index 0000000..800fd6b
--- /dev/null
+++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc
@@ -0,0 +1,514 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/numeric/integer-index-data.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/util/logging.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+uint32_t PostingListUsedIntegerIndexDataSerializer::GetBytesUsed(
+ const PostingListUsed* posting_list_used) const {
+ // The special data will be included if they represent actual data. If they
+ // represent the data start offset or the invalid data sentinel, they are not
+ // included.
+ return posting_list_used->size_in_bytes() -
+ GetStartByteOffset(posting_list_used);
+}
+
+uint32_t PostingListUsedIntegerIndexDataSerializer::GetMinPostingListSizeToFit(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) {
+    // If in either the FULL state or ALMOST_FULL state, this posting list *is*
+    // the minimum-sized posting list that can fit these data. So just return
+    // the size of the posting list.
+ return posting_list_used->size_in_bytes();
+ }
+
+  // In NOT_FULL state, BytesUsed contains no special data. The minimum-sized
+  // posting list that would be guaranteed to fit these data would be
+  // ALMOST_FULL, with kInvalidData in special data 0, the uncompressed data in
+ // special data 1 and the n compressed data in the compressed region.
+ // BytesUsed contains one uncompressed data and n compressed data. Therefore,
+ // fitting these data into a posting list would require BytesUsed plus one
+ // extra data.
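+  //
+  // Worked example (sizeof(IntegerIndexData) == 12): a NOT_FULL posting list
+  // holding 3 data has GetBytesUsed == 36, so the minimum posting list
+  // guaranteed to fit them is 36 + 12 = 48 bytes: ALMOST_FULL, with one
+  // uncompressed data in special data 1 and 2 data in the compressed region.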
+ return GetBytesUsed(posting_list_used) + GetDataTypeBytes();
+}
+
+void PostingListUsedIntegerIndexDataSerializer::Clear(
+ PostingListUsed* posting_list_used) const {
+ // Safe to ignore return value because posting_list_used->size_in_bytes() is
+ // a valid argument.
+ SetStartByteOffset(posting_list_used,
+ /*offset=*/posting_list_used->size_in_bytes());
+}
+
+libtextclassifier3::Status PostingListUsedIntegerIndexDataSerializer::MoveFrom(
+ PostingListUsed* dst, PostingListUsed* src) const {
+ ICING_RETURN_ERROR_IF_NULL(dst);
+ ICING_RETURN_ERROR_IF_NULL(src);
+ if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "src MinPostingListSizeToFit %d must be larger than size %d.",
+ GetMinPostingListSizeToFit(src), dst->size_in_bytes()));
+ }
+
+ if (!IsPostingListValid(dst)) {
+ return absl_ports::FailedPreconditionError(
+ "Dst posting list is in an invalid state and can't be used!");
+ }
+ if (!IsPostingListValid(src)) {
+ return absl_ports::InvalidArgumentError(
+ "Cannot MoveFrom an invalid src posting list!");
+ }
+
+ // Pop just enough data that all of src's compressed data fit in
+ // dst posting_list's compressed area. Then we can memcpy that area.
+ std::vector<IntegerIndexData> data_arr;
+ while (IsFull(src) || IsAlmostFull(src) ||
+ (dst->size_in_bytes() - kSpecialDataSize < GetBytesUsed(src))) {
+ if (!GetDataInternal(src, /*limit=*/1, /*pop=*/true, &data_arr).ok()) {
+ return absl_ports::AbortedError(
+ "Unable to retrieve data from src posting list.");
+ }
+ }
+
+ // memcpy the area and set up start byte offset.
+ Clear(dst);
+ memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src),
+ src->posting_list_buffer() + GetStartByteOffset(src),
+ GetBytesUsed(src));
+  // Because we popped all data from src outside of the compressed area and
+  // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() -
+  // kSpecialDataSize, this is guaranteed to be a valid byte offset for the
+  // NOT_FULL state, so ignoring the return value is safe.
+ SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src));
+
+ // Put back remaining data.
+ for (auto riter = data_arr.rbegin(); riter != data_arr.rend(); ++riter) {
+ // PrependData may return:
+    // - INVALID_ARGUMENT: if data is invalid or greater than the previous data
+ // - RESOURCE_EXHAUSTED
+ // RESOURCE_EXHAUSTED should be impossible because we've already assured
+ // that there is enough room above.
+ ICING_RETURN_IF_ERROR(PrependData(dst, *riter));
+ }
+
+ Clear(src);
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status
+PostingListUsedIntegerIndexDataSerializer::PrependDataToAlmostFull(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data) const {
+ SpecialDataType special_data = GetSpecialData(posting_list_used, /*index=*/1);
+ if (special_data.data().basic_hit() < data.basic_hit()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "BasicHit %d being prepended must not be greater than the most recent"
+ "BasicHit %d",
+ data.basic_hit().value(), special_data.data().basic_hit().value()));
+ }
+
+ // TODO(b/259743562): [Optimization 2] compression
+  // Without compression, prepending a new data into an ALMOST_FULL posting
+  // list changes the posting list to FULL state. Therefore, set special data 0
+  // directly.
+ SetSpecialData(posting_list_used, /*index=*/0, SpecialDataType(data));
+ return libtextclassifier3::Status::OK;
+}
+
+void PostingListUsedIntegerIndexDataSerializer::PrependDataToEmpty(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data) const {
+ // First data to be added. Just add verbatim, no compression.
+ if (posting_list_used->size_in_bytes() == kSpecialDataSize) {
+ // First data will be stored at special data 1.
+ // Safe to ignore the return value because 1 < kNumSpecialData
+ SetSpecialData(posting_list_used, /*index=*/1, SpecialDataType(data));
+ // Safe to ignore the return value because sizeof(IntegerIndexData) is a
+ // valid argument.
+ SetStartByteOffset(posting_list_used,
+ /*offset=*/sizeof(IntegerIndexData));
+ } else {
+ // Since this is the first data, size != kSpecialDataSize and
+ // size % sizeof(IntegerIndexData) == 0, we know that there is room to fit
+ // 'data' into the compressed region, so ValueOrDie is safe.
+ uint32_t offset =
+ PrependDataUncompressed(posting_list_used, data,
+ /*offset=*/posting_list_used->size_in_bytes())
+ .ValueOrDie();
+ // Safe to ignore the return value because PrependDataUncompressed is
+ // guaranteed to return a valid offset.
+ SetStartByteOffset(posting_list_used, offset);
+ }
+}
+
+libtextclassifier3::Status
+PostingListUsedIntegerIndexDataSerializer::PrependDataToNotFull(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data,
+ uint32_t offset) const {
+ IntegerIndexData cur;
+ memcpy(&cur, posting_list_used->posting_list_buffer() + offset,
+ sizeof(IntegerIndexData));
+ if (cur.basic_hit() < data.basic_hit()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "BasicHit %d being prepended must not be greater than the most recent"
+ "BasicHit %d",
+ data.basic_hit().value(), cur.basic_hit().value()));
+ }
+
+ // TODO(b/259743562): [Optimization 2] compression
+ if (offset >= kSpecialDataSize + sizeof(IntegerIndexData)) {
+ offset =
+ PrependDataUncompressed(posting_list_used, data, offset).ValueOrDie();
+ SetStartByteOffset(posting_list_used, offset);
+ } else {
+ // The new data must be put in special data 1.
+ SetSpecialData(posting_list_used, /*index=*/1, SpecialDataType(data));
+ // State ALMOST_FULL. Safe to ignore the return value because
+ // sizeof(IntegerIndexData) is a valid argument.
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(IntegerIndexData));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::Status
+PostingListUsedIntegerIndexDataSerializer::PrependData(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data) const {
+ static_assert(
+ sizeof(BasicHit::Value) <= sizeof(uint64_t),
+ "BasicHit::Value cannot be larger than 8 bytes because the delta "
+ "must be able to fit in 8 bytes.");
+
+ if (!data.is_valid()) {
+ return absl_ports::InvalidArgumentError("Cannot prepend an invalid data!");
+ }
+ if (!IsPostingListValid(posting_list_used)) {
+ return absl_ports::FailedPreconditionError(
+ "This PostingListUsed is in an invalid state and can't add any data!");
+ }
+
+ if (IsFull(posting_list_used)) {
+ // State FULL: no space left.
+ return absl_ports::ResourceExhaustedError("No more room for data");
+ } else if (IsAlmostFull(posting_list_used)) {
+ return PrependDataToAlmostFull(posting_list_used, data);
+ } else if (IsEmpty(posting_list_used)) {
+ PrependDataToEmpty(posting_list_used, data);
+ return libtextclassifier3::Status::OK;
+ } else {
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ return PrependDataToNotFull(posting_list_used, data, offset);
+ }
+}
+
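+// Prepends each data in `array` in order, stopping at the first one that
+// fails to fit. Returns the number of data successfully prepended. If not all
+// num_data data fit and keep_prepended is false, rolls back everything this
+// call prepended and returns 0.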
+uint32_t PostingListUsedIntegerIndexDataSerializer::PrependDataArray(
+ PostingListUsed* posting_list_used, const IntegerIndexData* array,
+ uint32_t num_data, bool keep_prepended) const {
+ if (!IsPostingListValid(posting_list_used)) {
+ return 0;
+ }
+
+ uint32_t i;
+ for (i = 0; i < num_data; ++i) {
+ if (!PrependData(posting_list_used, array[i]).ok()) {
+ break;
+ }
+ }
+ if (i != num_data && !keep_prepended) {
+ // Didn't fit. Undo everything and check that we have the same offset as
+ // before. PopFrontData guarantees that it will remove all 'i' data so long
+ // as there are at least 'i' data in the posting list, which we know there
+ // are.
+ PopFrontData(posting_list_used, /*num_data=*/i);
+ return 0;
+ }
+ return i;
+}
+
+libtextclassifier3::StatusOr<std::vector<IntegerIndexData>>
+PostingListUsedIntegerIndexDataSerializer::GetData(
+ const PostingListUsed* posting_list_used) const {
+ std::vector<IntegerIndexData> data_arr_out;
+ ICING_RETURN_IF_ERROR(GetData(posting_list_used, &data_arr_out));
+ return data_arr_out;
+}
+
+libtextclassifier3::Status PostingListUsedIntegerIndexDataSerializer::GetData(
+ const PostingListUsed* posting_list_used,
+ std::vector<IntegerIndexData>* data_arr_out) const {
+ return GetDataInternal(posting_list_used,
+ /*limit=*/std::numeric_limits<uint32_t>::max(),
+ /*pop=*/false, data_arr_out);
+}
+
+libtextclassifier3::Status
+PostingListUsedIntegerIndexDataSerializer::PopFrontData(
+ PostingListUsed* posting_list_used, uint32_t num_data) const {
+ if (num_data == 1 && IsFull(posting_list_used)) {
+    // The PL is in FULL state, which means that we save 2 uncompressed data in
+    // the 2 special positions. But FULL state may be reached from 2 different
+    // states.
+ // (1) In ALMOST_FULL state
+ // +------------------+-----------------+-----+---------------------------+
+ // |Data::Invalid |1st data |(pad)|(compressed) data |
+ // | | | | |
+ // +------------------+-----------------+-----+---------------------------+
+ // When we prepend another data, we can only put it at special data 0, and
+ // thus get a FULL PL
+ // +------------------+-----------------+-----+---------------------------+
+ // |new 1st data |original 1st data|(pad)|(compressed) data |
+ // | | | | |
+ // +------------------+-----------------+-----+---------------------------+
+ //
+ // (2) In NOT_FULL state
+ // +------------------+-----------------+-------+---------+---------------+
+ // |data-start-offset |Data::Invalid |(pad) |1st data |(compressed) |
+ // | | | | |data |
+ // +------------------+-----------------+-------+---------+---------------+
+ // When we prepend another data, we can reach any of the 3 following
+ // scenarios:
+ // (2.1) NOT_FULL
+ // if the space of pad and original 1st data can accommodate the new 1st
+ // data and the encoded delta value.
+ // +------------------+-----------------+-----+--------+------------------+
+ // |data-start-offset |Data::Invalid |(pad)|new |(compressed) data |
+ // | | | |1st data| |
+ // +------------------+-----------------+-----+--------+------------------+
+ // (2.2) ALMOST_FULL
+    // if the space of pad and original 1st data cannot accommodate the new 1st
+    // data and the encoded delta value, but can accommodate the encoded delta
+    // value alone. We can put the new 1st data at special position 1.
+ // +------------------+-----------------+---------+-----------------------+
+ // |Data::Invalid |new 1st data |(pad) |(compressed) data |
+ // | | | | |
+ // +------------------+-----------------+---------+-----------------------+
+ // (2.3) FULL
+    // in a very rare case, it cannot accommodate even the encoded delta value
+    // alone. We can move the original 1st data into special position 1 and the
+    // new 1st data into special position 0. This may happen because we use the
+    // VarInt encoding method, which may make the encoded value longer (about
+    // 4/3 times the original).
+ // +------------------+-----------------+--------------+------------------+
+ // |new 1st data |original 1st data|(pad) |(compressed) data |
+ // | | | | |
+ // +------------------+-----------------+--------------+------------------+
+ //
+    // Suppose now the PL is in FULL state, but we don't know whether it
+    // arrived at this state from NOT_FULL (like (2.3)) or from ALMOST_FULL
+    // (like (1)). We'd return to ALMOST_FULL state like (1) if we simply
+    // popped the new 1st data, but we want to make the prepending operation
+    // "reversible", so there should be some way to return to NOT_FULL if
+    // possible. A simple way to do this is:
+ // - Pop 2 data out of the PL to state ALMOST_FULL or NOT_FULL.
+ // - Add the second data ("original 1st data") back.
+ //
+ // Then we can return to the correct original states of (2.1) or (1). This
+ // makes our prepending operation reversible.
+ std::vector<IntegerIndexData> out;
+
+ // Popping 2 data should never fail because we've just ensured that the
+ // posting list is in the FULL state.
+ ICING_RETURN_IF_ERROR(
+ GetDataInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out));
+
+ // PrependData should never fail because:
+ // - out[1] is a valid data less than all previous data in the posting list.
+ // - There's no way that the posting list could run out of room because it
+ // previously stored these 2 data.
+ PrependData(posting_list_used, out[1]);
+ } else if (num_data > 0) {
+ return GetDataInternal(posting_list_used, /*limit=*/num_data, /*pop=*/true,
+ /*out=*/nullptr);
+ }
+ return libtextclassifier3::Status::OK;
+}
+
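+// An illustrative sketch (not part of this change): how a caller might use
+// PopFrontData to undo a speculative prepend. The surrounding setup
+// (pl_used, ShouldRollback) is hypothetical.
+//
+//   PostingListUsedIntegerIndexDataSerializer serializer;
+//   IntegerIndexData data(/*section_id=*/0, /*document_id=*/9, /*key=*/42);
+//   if (serializer.PrependData(&pl_used, data).ok() && ShouldRollback()) {
+//     // Undoes the prepend and restores the previous state, even across a
+//     // NOT_FULL -> FULL transition (see the comment in PopFrontData above).
+//     ICING_RETURN_IF_ERROR(
+//         serializer.PopFrontData(&pl_used, /*num_data=*/1));
+//   }
+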
+libtextclassifier3::Status
+PostingListUsedIntegerIndexDataSerializer::GetDataInternal(
+ const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+ std::vector<IntegerIndexData>* out) const {
+ // TODO(b/259743562): [Optimization 2] handle compressed data
+
+ uint32_t offset = GetStartByteOffset(posting_list_used);
+ uint32_t count = 0;
+
+ // First traverse the first two special positions.
+ while (count < limit && offset < kSpecialDataSize) {
+ // offset / sizeof(IntegerIndexData) < kNumSpecialData because of the check
+ // above.
+ SpecialDataType special_data =
+ GetSpecialData(posting_list_used,
+ /*index=*/offset / sizeof(IntegerIndexData));
+ if (out != nullptr) {
+ out->push_back(special_data.data());
+ }
+ offset += sizeof(IntegerIndexData);
+ ++count;
+ }
+
+ // - We don't compress the data for now.
+ // - The posting list size is a multiple of the data type size in bytes.
+ // So the offset of the first non-special data is guaranteed to be
+ // kSpecialDataSize in ALMOST_FULL or FULL state. In fact, we must not apply
+ // padding-skipping logic here while still storing uncompressed data, because
+ // in this case 0 bytes are meaningful (e.g. inverted doc id byte = 0).
+ // TODO(b/259743562): [Optimization 2] deal with padding-skipping logic when
+ // applying data compression.
+
+ while (count < limit && offset < posting_list_used->size_in_bytes()) {
+ IntegerIndexData data;
+ memcpy(&data, posting_list_used->posting_list_buffer() + offset,
+ sizeof(IntegerIndexData));
+ offset += sizeof(IntegerIndexData);
+ if (out != nullptr) {
+ out->push_back(data);
+ }
+ ++count;
+ }
+
+ if (pop) {
+ PostingListUsed* mutable_posting_list_used =
+ const_cast<PostingListUsed*>(posting_list_used);
+ // Modify the posting list so that we pop all data actually traversed.
+ if (offset >= kSpecialDataSize &&
+ offset < posting_list_used->size_in_bytes()) {
+ memset(
+ mutable_posting_list_used->posting_list_buffer() + kSpecialDataSize,
+ 0, offset - kSpecialDataSize);
+ }
+ SetStartByteOffset(mutable_posting_list_used, offset);
+ }
+
+ return libtextclassifier3::Status::OK;
+}
+
+PostingListUsedIntegerIndexDataSerializer::SpecialDataType
+PostingListUsedIntegerIndexDataSerializer::GetSpecialData(
+ const PostingListUsed* posting_list_used, uint32_t index) const {
+ // It is ok to temporarily construct a SpecialData with offset = 0, since
+ // we're going to overwrite it via memcpy.
+ SpecialDataType special_data(0);
+ memcpy(&special_data,
+ posting_list_used->posting_list_buffer() +
+ index * sizeof(SpecialDataType),
+ sizeof(SpecialDataType));
+ return special_data;
+}
+
+void PostingListUsedIntegerIndexDataSerializer::SetSpecialData(
+ PostingListUsed* posting_list_used, uint32_t index,
+ const SpecialDataType& special_data) const {
+ memcpy(posting_list_used->posting_list_buffer() +
+ index * sizeof(SpecialDataType),
+ &special_data, sizeof(SpecialDataType));
+}
+
+bool PostingListUsedIntegerIndexDataSerializer::IsPostingListValid(
+ const PostingListUsed* posting_list_used) const {
+ if (IsAlmostFull(posting_list_used)) {
+ // Special data 1 should hold a valid data.
+ if (!GetSpecialData(posting_list_used, /*index=*/1).data().is_valid()) {
+ ICING_LOG(ERROR)
+ << "Both special data cannot be invalid at the same time.";
+ return false;
+ }
+ } else if (!IsFull(posting_list_used)) {
+ // NOT_FULL. Special data 0 should hold a valid offset.
+ SpecialDataType special_data =
+ GetSpecialData(posting_list_used, /*index=*/0);
+ if (special_data.data_start_offset() > posting_list_used->size_in_bytes() ||
+ special_data.data_start_offset() < kSpecialDataSize) {
+ ICING_LOG(ERROR) << "Offset: " << special_data.data_start_offset()
+ << " size: " << posting_list_used->size_in_bytes()
+ << " sp size: " << kSpecialDataSize;
+ return false;
+ }
+ }
+ return true;
+}
+
+uint32_t PostingListUsedIntegerIndexDataSerializer::GetStartByteOffset(
+ const PostingListUsed* posting_list_used) const {
+ if (IsFull(posting_list_used)) {
+ return 0;
+ } else if (IsAlmostFull(posting_list_used)) {
+ return sizeof(IntegerIndexData);
+ } else {
+ return GetSpecialData(posting_list_used, /*index=*/0).data_start_offset();
+ }
+}
+
+bool PostingListUsedIntegerIndexDataSerializer::SetStartByteOffset(
+ PostingListUsed* posting_list_used, uint32_t offset) const {
+ if (offset > posting_list_used->size_in_bytes()) {
+ ICING_LOG(ERROR) << "offset cannot be a value greater than size "
+ << posting_list_used->size_in_bytes() << ". offset is "
+ << offset << ".";
+ return false;
+ }
+ if (offset < kSpecialDataSize && offset > sizeof(IntegerIndexData)) {
+ ICING_LOG(ERROR) << "offset cannot be a value between ("
+ << sizeof(IntegerIndexData) << ", " << kSpecialDataSize
+ << "). offset is " << offset << ".";
+ return false;
+ }
+ if (offset < sizeof(IntegerIndexData) && offset != 0) {
+ ICING_LOG(ERROR) << "offset cannot be a value between (0, "
+ << sizeof(IntegerIndexData) << "). offset is " << offset
+ << ".";
+ return false;
+ }
+
+ if (offset >= kSpecialDataSize) {
+ // NOT_FULL state.
+ SetSpecialData(posting_list_used, /*index=*/0, SpecialDataType(offset));
+ SetSpecialData(posting_list_used, /*index=*/1,
+ SpecialDataType(IntegerIndexData()));
+ } else if (offset == sizeof(IntegerIndexData)) {
+ // ALMOST_FULL state.
+ SetSpecialData(posting_list_used, /*index=*/0,
+ SpecialDataType(IntegerIndexData()));
+ }
+ // Nothing to do for the FULL state - the offset isn't actually stored
+ // anywhere, and both special data hold valid data.
+ return true;
+}
+
+libtextclassifier3::StatusOr<uint32_t>
+PostingListUsedIntegerIndexDataSerializer::PrependDataUncompressed(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data,
+ uint32_t offset) const {
+ if (offset < kSpecialDataSize + sizeof(IntegerIndexData)) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Not enough room to prepend IntegerIndexData at offset %d.", offset));
+ }
+ offset -= sizeof(IntegerIndexData);
+ memcpy(posting_list_used->posting_list_buffer() + offset, &data,
+ sizeof(IntegerIndexData));
+ return offset;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer.h b/icing/index/numeric/posting-list-used-integer-index-data-serializer.h
new file mode 100644
index 0000000..49007e3
--- /dev/null
+++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer.h
@@ -0,0 +1,338 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_
+#define ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_
+
+#include <cstdint>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/file/posting_list/posting-list-common.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/numeric/integer-index-data.h"
+
+namespace icing {
+namespace lib {
+
+// A serializer class to serialize IntegerIndexData to PostingListUsed.
+class PostingListUsedIntegerIndexDataSerializer
+ : public PostingListUsedSerializer {
+ public:
+ using SpecialDataType = SpecialData<IntegerIndexData>;
+ static_assert(sizeof(SpecialDataType) == sizeof(IntegerIndexData), "");
+
+ static constexpr uint32_t kSpecialDataSize =
+ kNumSpecialData * sizeof(SpecialDataType);
+
+ uint32_t GetDataTypeBytes() const override {
+ return sizeof(IntegerIndexData);
+ }
+
+ uint32_t GetMinPostingListSize() const override {
+ static constexpr uint32_t kMinPostingListSize = kSpecialDataSize;
+ static_assert(sizeof(PostingListIndex) <= kMinPostingListSize,
+ "PostingListIndex must be small enough to fit in a "
+ "minimum-sized Posting List.");
+
+ return kMinPostingListSize;
+ }
+
+ uint32_t GetMinPostingListSizeToFit(
+ const PostingListUsed* posting_list_used) const override;
+
+ uint32_t GetBytesUsed(
+ const PostingListUsed* posting_list_used) const override;
+
+ void Clear(PostingListUsed* posting_list_used) const override;
+
+ libtextclassifier3::Status MoveFrom(PostingListUsed* dst,
+ PostingListUsed* src) const override;
+
+ // Prepend an IntegerIndexData to the posting list.
+ //
+ // RETURNS:
+ // - INVALID_ARGUMENT if !data.is_valid() or if data is not less than the
+ // previously added data.
+ // - RESOURCE_EXHAUSTED if there is no more room to add data to the posting
+ // list.
+ libtextclassifier3::Status PrependData(PostingListUsed* posting_list_used,
+ const IntegerIndexData& data) const;
+
+ // Prepend multiple IntegerIndexData to the posting list. Data should be
+ // sorted in ascending order (as defined by the less-than operator for
+ // IntegerIndexData).
+ // If keep_prepended is true, whatever could be prepended is kept, otherwise
+ // the posting list is reverted and left in its original state.
+ //
+ // RETURNS:
+ // The number of data that have been prepended to the posting list. If
+ // keep_prepended is false and reverted, then it returns 0.
+ uint32_t PrependDataArray(PostingListUsed* posting_list_used,
+ const IntegerIndexData* array, uint32_t num_data,
+ bool keep_prepended) const;
+
+ // Retrieves all data stored in the posting list.
+ //
+ // RETURNS:
+ // - On success, a vector of IntegerIndexData sorted in the reverse order of
+ // prepending.
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::StatusOr<std::vector<IntegerIndexData>> GetData(
+ const PostingListUsed* posting_list_used) const;
+
+ // Same as GetData but appends data to data_arr_out.
+ //
+ // RETURNS:
+ // - OK on success, and IntegerIndexData will be appended to data_arr_out,
+ // sorted in the reverse order of prepending.
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status GetData(
+ const PostingListUsed* posting_list_used,
+ std::vector<IntegerIndexData>* data_arr_out) const;
+
+ // Undo the last num_data prepended data. If num_data exceeds the number of
+ // data in the posting list, then we clear all data.
+ //
+ // RETURNS:
+ // - OK on success
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status PopFrontData(PostingListUsed* posting_list_used,
+ uint32_t num_data) const;
+
+ private:
+ // Posting list layout formats:
+ //
+ // NOT_FULL
+ // +-special-data-0--+-special-data-1--+------------+-----------------------+
+ // | | | | |
+ // |data-start-offset| Data::Invalid | 0x00000000 | (compressed) data |
+ // | | | | |
+ // +-----------------+-----------------+------------+-----------------------+
+ //
+ // ALMOST_FULL
+ // +-special-data-0--+-special-data-1--+-----+------------------------------+
+ // | | | | |
+ // | Data::Invalid | 1st data |(pad)| (compressed) data |
+ // | | | | |
+ // +-----------------+-----------------+-----+------------------------------+
+ //
+ // FULL
+ // +-special-data-0--+-special-data-1--+-----+------------------------------+
+ // | | | | |
+ // | 1st data | 2nd data |(pad)| (compressed) data |
+ // | | | | |
+ // +-----------------+-----------------+-----+------------------------------+
+ //
+ // The first two uncompressed (special) data also implicitly encode
+ // information about the size of the compressed data region.
+ //
+ // 1. If the posting list is NOT_FULL, then special_data_0 contains the byte
+ // offset of the start of the compressed data. Thus, the size of the
+ // compressed data is
+ // posting_list_used->size_in_bytes() - special_data_0.data_start_offset().
+ //
+ // 2. If posting list is ALMOST_FULL or FULL, then the compressed data region
+ // starts somewhere between
+ // [kSpecialDataSize, kSpecialDataSize + sizeof(IntegerIndexData) - 1] and
+ // ends at posting_list_used->size_in_bytes() - 1.
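+ // For instance, in the 36-byte example below, after the third data is
+ // added the start offset is 25, so the compressed data size is
+ // 36 - 25 = 11 bytes.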
+ //
+ // EXAMPLE
+ // Posting list storage. Posting list size: 36 bytes
+ //
+ // EMPTY!
+ // +--- byte 0-11 ---+----- 12-23 -----+-------------- 24-35 ---------------+
+ // | | | |
+ // | 36 | Data::Invalid | 0x00000000 |
+ // | | | |
+ // +-----------------+-----------------+------------------------------------+
+ //
+ // Add IntegerIndexData(0x0FFFFCC3, 5)
+ // (DocumentId = 12, SectionId = 3; Key = 5)
+ // (VarInt64(5) is encoded as 10 (b'1010), requires 1 byte)
+ // NOT FULL!
+ // +--- byte 0-11 ---+----- 12-23 -----+------- 24-30 -------+--- 31-35 ----+
+ // | | | | 0x0FFFFCC3 |
+ // | 31 | Data::Invalid | 0x00000000 | VI64(5) |
+ // | | | | |
+ // +-----------------+-----------------+---------------------+--------------+
+ //
+ // Add IntegerIndexData(0x0FFFFB40, -2)
+ // (DocumentId = 18, SectionId = 0; Key = -2)
+ // (VarInt64(-2) is encoded as 3 (b'11), requires 1 byte)
+ // Previous IntegerIndexData BasicHit delta varint encoding:
+ // 0x0FFFFCC3 - 0x0FFFFB40 = 387, VarUnsignedInt(387) requires 2 bytes
+ // +--- byte 0-11 ---+----- 12-23 -----+-- 24-27 ---+--- 28-32 ----+ 33-35 -+
+ // | | | | 0x0FFFFB40 |VUI(387)|
+ // | 28 | Data::Invalid | 0x00 | VI64(-2) |VI64(5) |
+ // | | | | | |
+ // +-----------------+-----------------+------------+--------------+--------+
+ //
+ // Add IntegerIndexData(0x0FFFFA4A, 3)
+ // (DocumentId = 22, SectionId = 10; Key = 3)
+ // (VarInt64(3) is encoded as 6 (b'110), requires 1 byte)
+ // Previous IntegerIndexData BasicHit delta varint encoding:
+ // 0x0FFFFB40 - 0x0FFFFA4A = 246, VarUnsignedInt(246) requires 2 bytes
+ // +--- byte 0-11 ---+----- 12-23 -----+---+--- 25-29 ----+ 30-32 -+ 33-35 -+
+ // | | | | 0x0FFFFA4A |VUI(246)|VUI(387)|
+ // | 25 | Data::Invalid | | VI64(3) |VI64(-2)|VI64(5) |
+ // | | | | | | |
+ // +-----------------+-----------------+---+--------------+--------+--------+
+ //
+ // Add IntegerIndexData(0x0FFFFA01, -4)
+ // (DocumentId = 23, SectionId = 1; Key = -4)
+ // (No VarInt64 for key, since it is stored in special data section)
+ // Previous IntegerIndexData BasicHit delta varint encoding:
+ // 0x0FFFFA4A - 0x0FFFFA01 = 73, VarUnsignedInt(73) requires 1 byte
+ // ALMOST_FULL!
+ // +--- byte 0-11 ---+----- 12-23 -----+-- 24-27 ---+28-29+ 30-32 -+ 33-35 -+
+ // | | 0x0FFFFA01 | |(73) |VUI(246)|VUI(387)|
+ // | Data::Invalid | 0xFFFFFFFF | (pad) |(3) |VI64(-2)|VI64(5) |
+ // | | 0xFFFFFFFC | | | | |
+ // +-----------------+-----------------+------------+-----+--------+--------+
+ //
+ // Add IntegerIndexData(0x0FFFF904, 0)
+ // (DocumentId = 27, SectionId = 4; Key = 0)
+ // (No VarInt64 for key, since it is stored in special data section)
+ // Previous IntegerIndexData:
+ // Since 0x0FFFFA01 - 0x0FFFF904 = 253 and VarInt64(-4) is encoded as 7
+ // (b'111), it requires only 3 bytes after compression. It's able to fit
+ // into the padding section.
+ // Still ALMOST_FULL!
+ // +--- byte 0-11 ---+----- 12-23 -----+---+ 25-27 -+28-29+ 30-32 -+ 33-35 -+
+ // | | 0x0FFFF904 | |VUI(253)|(73) |VUI(246)|VUI(387)|
+ // | Data::Invalid | 0x00000000 | |VI64(-4)|(3) |VI64(-2)|VI64(5) |
+ // | | 0x00000000 | | | | | |
+ // +-----------------+-----------------+---+--------+-----+--------+--------+
+ //
+ // Add IntegerIndexData(0x0FFFF8C3, -1)
+ // (DocumentId = 28, SectionId = 3; Key = -1)
+ // (No VarInt64 for key, since it is stored in special data section)
+ // (No VarUnsignedInt for previous IntegerIndexData BasicHit)
+ // FULL!
+ // +--- byte 0-11 ---+----- 12-23 -----+---+ 25-27 -+28-29+ 30-32 -+ 33-35 -+
+ // | 0x0FFFF8C3 | 0x0FFFF904 | |VUI(253)|(73) |VUI(246)|VUI(387)|
+ // | 0xFFFFFFFF | 0x00000000 | |VI64(-4)|(3) |VI64(-2)|VI64(5) |
+ // | 0xFFFFFFFF | 0x00000000 | | | | | |
+ // +-----------------+-----------------+---+--------+-----+--------+--------+
+
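+ // The VarInt64 / VarUnsignedInt values in the example above assume a
+ // ZigZag-style mapping to unsigned before standard varint encoding. A
+ // minimal sketch of that mapping (for illustration only; see the varint
+ // utils for the encoding actually used):
+ //
+ //   uint64_t ZigZagEncode(int64_t value) {
+ //     // Maps 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ... so that small-magnitude
+ //     // integers of either sign need few varint bytes.
+ //     return (static_cast<uint64_t>(value) << 1) ^
+ //            static_cast<uint64_t>(value >> 63);
+ //   }
+ //   // ZigZagEncode(5) == 10 (b'1010), ZigZagEncode(-2) == 3 (b'11),
+ //   // ZigZagEncode(3) == 6 (b'110), ZigZagEncode(-4) == 7 (b'111).
+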
+ // Helpers to determine what state the posting list is in.
+ bool IsFull(const PostingListUsed* posting_list_used) const {
+ return GetSpecialData(posting_list_used, /*index=*/0).data().is_valid() &&
+ GetSpecialData(posting_list_used, /*index=*/1).data().is_valid();
+ }
+
+ bool IsAlmostFull(const PostingListUsed* posting_list_used) const {
+ return !GetSpecialData(posting_list_used, /*index=*/0).data().is_valid() &&
+ GetSpecialData(posting_list_used, /*index=*/1).data().is_valid();
+ }
+
+ bool IsEmpty(const PostingListUsed* posting_list_used) const {
+ return GetSpecialData(posting_list_used, /*index=*/0).data_start_offset() ==
+ posting_list_used->size_in_bytes() &&
+ !GetSpecialData(posting_list_used, /*index=*/1).data().is_valid();
+ }
+
+ // Returns false if both special data are invalid, or if the data start
+ // offset stored in the special data is less than kSpecialDataSize or
+ // greater than posting_list_used->size_in_bytes(). Returns true otherwise.
+ bool IsPostingListValid(const PostingListUsed* posting_list_used) const;
+
+ // Prepend data to a posting list that is in the ALMOST_FULL state.
+ //
+ // RETURNS:
+ // - OK, if successful
+ // - INVALID_ARGUMENT if data is not less than the previously added data.
+ libtextclassifier3::Status PrependDataToAlmostFull(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data) const;
+
+ // Prepend data to a posting list that is in the EMPTY state. This will always
+ // succeed because there are no pre-existing data and no validly constructed
+ // posting list could fail to fit one data.
+ void PrependDataToEmpty(PostingListUsed* posting_list_used,
+ const IntegerIndexData& data) const;
+
+ // Prepend data to a posting list that is in the NOT_FULL state.
+ //
+ // RETURNS:
+ // - OK, if successful
+ // - INVALID_ARGUMENT if data is not less than the previously added data.
+ libtextclassifier3::Status PrependDataToNotFull(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data,
+ uint32_t offset) const;
+
+ // Returns either 0 (FULL state), sizeof(IntegerIndexData) (ALMOST_FULL state)
+ // or a byte offset between kSpecialDataSize and
+ // posting_list_used->size_in_bytes() (inclusive) (NOT_FULL state).
+ uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const;
+
+ // Sets special data 0 to properly reflect the start byte offset (see the
+ // layout comment for further details).
+ //
+ // Returns false if offset > posting_list_used->size_in_bytes(), or if
+ // offset falls in the range (sizeof(IntegerIndexData), kSpecialDataSize) or
+ // (0, sizeof(IntegerIndexData)). Returns true otherwise.
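+ //
+ // For example, with sizeof(IntegerIndexData) == 12 and a 36-byte posting
+ // list (as in the layout example above), the only valid offsets are 0
+ // (FULL), 12 (ALMOST_FULL), and values in [24, 36] (NOT_FULL).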
+ bool SetStartByteOffset(PostingListUsed* posting_list_used,
+ uint32_t offset) const;
+
+ // Helper for MoveFrom/GetData/PopFrontData. Adds limit number of data to
+ // out, or all data in the posting list if it contains fewer than limit
+ // data. out can be nullptr.
+ //
+ // NOTE: If called with limit=1, pop=true on a posting list that transitioned
+ // from NOT_FULL directly to FULL, GetDataInternal will not return the posting
+ // list to NOT_FULL. Instead it will leave it in a valid state, but it will be
+ // ALMOST_FULL.
+ //
+ // RETURNS:
+ // - OK on success
+ // - INTERNAL_ERROR if the posting list has been corrupted somehow.
+ libtextclassifier3::Status GetDataInternal(
+ const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
+ std::vector<IntegerIndexData>* out) const;
+
+ // Retrieves the value stored in the index-th special data.
+ //
+ // REQUIRES:
+ // 0 <= index < kNumSpecialData.
+ //
+ // RETURNS:
+ // - A valid SpecialData<IntegerIndexData>.
+ SpecialDataType GetSpecialData(const PostingListUsed* posting_list_used,
+ uint32_t index) const;
+
+ // Sets the value stored in the index-th special data to special_data.
+ //
+ // REQUIRES:
+ // 0 <= index < kNumSpecialData.
+ void SetSpecialData(PostingListUsed* posting_list_used, uint32_t index,
+ const SpecialDataType& special_data) const;
+
+ // Prepends data to the memory region [offset - sizeof(IntegerIndexData),
+ // offset - 1] and returns the new beginning of the region.
+ //
+ // RETURNS:
+ // - The new beginning of the padded region, if successful.
+ // - INVALID_ARGUMENT if data will not fit (uncompressed) between
+ // [kSpecialDataSize, offset - 1]
+ libtextclassifier3::StatusOr<uint32_t> PrependDataUncompressed(
+ PostingListUsed* posting_list_used, const IntegerIndexData& data,
+ uint32_t offset) const;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_
diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc b/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc
new file mode 100644
index 0000000..c270137
--- /dev/null
+++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc
@@ -0,0 +1,523 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h"
+
+#include <memory>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/posting_list/posting-list-used.h"
+#include "icing/index/numeric/integer-index-data.h"
+#include "icing/testing/common-matchers.h"
+
+using testing::ElementsAre;
+using testing::ElementsAreArray;
+using testing::Eq;
+using testing::IsEmpty;
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+// TODO(b/259743562): [Optimization 2] update unit tests after applying
+// compression. Remember to create varint/delta encoding
+// overflow (which causes state NOT_FULL -> FULL directly
+// without ALMOST_FULL) test cases, including for
+// PopFrontData.
+
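+// A shape such an overflow test might take once compression lands
+// (hypothetical; the exact buffer sizes and keys that force the encoded
+// delta to outgrow the available pad depend on the final encoding):
+//
+//   TEST(PostingListUsedIntegerIndexDataSerializerTest,
+//        PopFrontDataAfterNotFullToFullTransition) {
+//     // 1. Prepend data so that the next prepend's encoded delta cannot fit
+//     //    in the remaining pad, forcing NOT_FULL -> FULL directly.
+//     // 2. PopFrontData(/*num_data=*/1) and verify the posting list returns
+//     //    to a NOT_FULL state holding the remaining data.
+//   }
+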
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ GetMinPostingListSizeToFitNotNull) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 2551 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/0, /*key=*/2)),
+ IsOk());
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used),
+ Eq(2 * sizeof(IntegerIndexData)));
+
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/1, /*key=*/5)),
+ IsOk());
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used),
+ Eq(3 * sizeof(IntegerIndexData)));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ GetMinPostingListSizeToFitAlmostFull) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 3 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/0, /*key=*/2)),
+ IsOk());
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/1, /*key=*/5)),
+ IsOk());
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), Eq(size));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ GetMinPostingListSizeToFitFull) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 3 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/0, /*key=*/2)),
+ IsOk());
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/1, /*key=*/5)),
+ IsOk());
+ ASSERT_THAT(serializer.PrependData(
+ &pl_used, IntegerIndexData(/*section_id=*/0,
+ /*document_id=*/2, /*key=*/0)),
+ IsOk());
+ EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), Eq(size));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest, PrependDataNotFull) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 2551 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ // Make used.
+ IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk());
+ // Size = sizeof(uncompressed data0)
+ int expected_size = sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(ElementsAre(data0)));
+
+ IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk());
+ // Size = sizeof(uncompressed data1)
+ // + sizeof(uncompressed data0)
+ expected_size += sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data1, data0)));
+
+ IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data2), IsOk());
+ // Size = sizeof(uncompressed data2)
+ // + sizeof(uncompressed data1)
+ // + sizeof(uncompressed data0)
+ expected_size += sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data2, data1, data0)));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest, PrependDataAlmostFull) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 4 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ // Fill up the compressed region.
+ // Transitions:
+ // Adding data0: EMPTY -> NOT_FULL
+ // Adding data1: NOT_FULL -> NOT_FULL
+ IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2);
+ IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk());
+ EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk());
+ int expected_size = 2 * sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data1, data0)));
+
+ // Add one more data to transition NOT_FULL -> ALMOST_FULL
+ IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data2), IsOk());
+ expected_size = 3 * sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data2, data1, data0)));
+
+ // Add one more data to transition ALMOST_FULL -> FULL
+ IntegerIndexData data3(/*section_id=*/0, /*document_id=*/3, /*key=*/-3);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data3), IsOk());
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data3, data2, data1, data0)));
+
+ // The posting list is FULL. Adding another data should fail.
+ IntegerIndexData data4(/*section_id=*/0, /*document_id=*/4, /*key=*/100);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data4),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ PrependDataPostingListUsedMinSize) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ // PL State: EMPTY
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0));
+ EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(IsEmpty()));
+
+ // Add a data. PL should shift to ALMOST_FULL state
+ IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk());
+ // Size = sizeof(uncompressed data0)
+ int expected_size = sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(ElementsAre(data0)));
+
+ // Add another data. PL should shift to FULL state.
+ IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk());
+ // Size = sizeof(uncompressed data1) + sizeof(uncompressed data0)
+ expected_size += sizeof(IntegerIndexData);
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size));
+ EXPECT_THAT(serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAre(data1, data0)));
+
+ // The posting list is FULL. Adding another data should fail.
+ IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0);
+ EXPECT_THAT(serializer.PrependData(&pl_used, data2),
+ StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ PrependDataArrayDoNotKeepPrepended) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 6 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ std::vector<IntegerIndexData> data_in;
+ std::vector<IntegerIndexData> data_pushed;
+
+ // Add 3 data. The PL is in the empty state and should be able to fit all 3
+ // data without issue, transitioning the PL from EMPTY -> NOT_FULL.
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0));
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/false),
+ Eq(data_in.size()));
+ std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+
+ // Add 2 data. The PL should transition from NOT_FULL to ALMOST_FULL.
+ data_in.clear();
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100));
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/false),
+ Eq(data_in.size()));
+ std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+
+ // Add 2 data. The PL should remain ALMOST_FULL since the remaining space can
+ // only fit 1 data.
+ data_in.clear();
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200));
+ data_in.push_back(IntegerIndexData(/*section_id=*/0, /*document_id=*/6,
+ /*key=*/2147483647));
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/false),
+ Eq(0));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+
+ // Add 1 data. The PL should transition from ALMOST_FULL to FULL.
+ data_in.resize(1);
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/false),
+ Eq(data_in.size()));
+ std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ PrependDataArrayKeepPrepended) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 6 * sizeof(IntegerIndexData);
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ std::vector<IntegerIndexData> data_in;
+ std::vector<IntegerIndexData> data_pushed;
+
+ // Add 3 data. The PL is in the empty state and should be able to fit all 3
+ // data without issue, transitioning the PL from EMPTY -> NOT_FULL.
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0));
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/true),
+ Eq(data_in.size()));
+ std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+
+ // Add 4 data. The PL should prepend 3 data and transition from NOT_FULL to
+ // FULL.
+ data_in.clear();
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100));
+ data_in.push_back(
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200));
+ data_in.push_back(IntegerIndexData(/*section_id=*/0, /*document_id=*/6,
+ /*key=*/2147483647));
+ EXPECT_THAT(
+ serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(),
+ /*keep_prepended=*/true),
+ Eq(3));
+ data_in.resize(3);
+ std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed));
+ EXPECT_THAT(serializer.GetBytesUsed(&pl_used),
+ Eq(data_pushed.size() * sizeof(IntegerIndexData)));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend())));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest, MoveFrom) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 3 * serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf1 = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf1.get()), size));
+
+ std::vector<IntegerIndexData> data_arr1 = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used1, data_arr1.data(), data_arr1.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr1.size()));
+
+ std::unique_ptr<char[]> buf2 = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf2.get()), size));
+ std::vector<IntegerIndexData> data_arr2 = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used2, data_arr2.data(), data_arr2.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr2.size()));
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ IsOk());
+ EXPECT_THAT(
+ serializer.GetData(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(data_arr1.rbegin(), data_arr1.rend())));
+ EXPECT_THAT(serializer.GetData(&pl_used1), IsOkAndHolds(IsEmpty()));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest,
+ MoveToNullReturnsFailedPrecondition) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 3 * serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+ std::vector<IntegerIndexData> data_arr = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used, data_arr.data(), data_arr.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr.size()));
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used, /*src=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/nullptr, /*src=*/&pl_used),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest, MoveToPostingListTooSmall) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size1 = 3 * serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf1 = std::make_unique<char[]>(size1);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used1,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf1.get()), size1));
+ std::vector<IntegerIndexData> data_arr1 = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used1, data_arr1.data(), data_arr1.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr1.size()));
+
+ int size2 = serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf2 = std::make_unique<char[]>(size2);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used2,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf2.get()), size2));
+ std::vector<IntegerIndexData> data_arr2 = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used2, data_arr2.data(), data_arr2.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr2.size()));
+
+ EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used1),
+ IsOkAndHolds(ElementsAreArray(data_arr1.rbegin(), data_arr1.rend())));
+ EXPECT_THAT(
+ serializer.GetData(&pl_used2),
+ IsOkAndHolds(ElementsAreArray(data_arr2.rbegin(), data_arr2.rend())));
+}
+
+TEST(PostingListUsedIntegerIndexDataSerializerTest, PopFrontData) {
+ PostingListUsedIntegerIndexDataSerializer serializer;
+
+ int size = 2 * serializer.GetMinPostingListSize();
+ std::unique_ptr<char[]> buf = std::make_unique<char[]>(size);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ PostingListUsed pl_used,
+ PostingListUsed::CreateFromUnitializedRegion(
+ &serializer, static_cast<void*>(buf.get()), size));
+
+ std::vector<IntegerIndexData> data_arr = {
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5),
+ IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0)};
+ ASSERT_THAT(
+ serializer.PrependDataArray(&pl_used, data_arr.data(), data_arr.size(),
+ /*keep_prepended=*/false),
+ Eq(data_arr.size()));
+ ASSERT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+
+ // Now, pop the most recently prepended data. The posting list should
+ // contain the first two data.
+ EXPECT_THAT(serializer.PopFrontData(&pl_used, /*num_data=*/1), IsOk());
+ data_arr.pop_back();
+ EXPECT_THAT(
+ serializer.GetData(&pl_used),
+ IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/section-indexing-handler.h b/icing/index/section-indexing-handler.h
new file mode 100644
index 0000000..ff461cb
--- /dev/null
+++ b/icing/index/section-indexing-handler.h
@@ -0,0 +1,60 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_SECTION_INDEXING_HANDLER_H_
+#define ICING_INDEX_SECTION_INDEXING_HANDLER_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/proto/logging.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+// Parent class for indexing different types of sections in TokenizedDocument.
+class SectionIndexingHandler {
+ public:
+ explicit SectionIndexingHandler(const Clock* clock) : clock_(*clock) {}
+
+ virtual ~SectionIndexingHandler() = default;
+
+ // Handles the indexing process: adds data (hits) into the index of the
+ // specific type (e.g. string index, integer index) for all contents in the
+ // corresponding type of sections in tokenized_document.
+ // For example, IntegerSectionIndexingHandler::Handle should add data into
+ // integer index for all contents in tokenized_document.integer_sections.
+ //
+ // tokenized_document: document object with different types of tokenized
+ // sections.
+ // document_id: id of the document.
+ // put_document_stats: object for collecting stats during indexing. It can be
+ // nullptr.
+ //
+ // Returns:
+ // - OK on success
+ // - Any other errors. It depends on each implementation.
+ virtual libtextclassifier3::Status Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ PutDocumentStatsProto* put_document_stats) = 0;
+
+ protected:
+ const Clock& clock_;
+};
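+
+// A minimal sketch of a concrete handler (hypothetical; real handlers such
+// as StringSectionIndexingHandler live in their own files):
+//
+//   class FooSectionIndexingHandler : public SectionIndexingHandler {
+//    public:
+//     explicit FooSectionIndexingHandler(const Clock* clock)
+//         : SectionIndexingHandler(clock) {}
+//
+//     libtextclassifier3::Status Handle(
+//         const TokenizedDocument& tokenized_document, DocumentId document_id,
+//         PutDocumentStatsProto* put_document_stats) override {
+//       // Add data for the relevant section type of tokenized_document into
+//       // the corresponding index here, using clock_ for timing stats.
+//       return libtextclassifier3::Status::OK;
+//     }
+//   };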
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_SECTION_INDEXING_HANDLER_H_
diff --git a/icing/index/string-section-indexing-handler.cc b/icing/index/string-section-indexing-handler.cc
new file mode 100644
index 0000000..9b1db7e
--- /dev/null
+++ b/icing/index/string-section-indexing-handler.cc
@@ -0,0 +1,146 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/index/string-section-indexing-handler.h"
+
+#include <cstdint>
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/index.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/proto/logging.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/schema/section.h"
+#include "icing/store/document-id.h"
+#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::Status StringSectionIndexingHandler::Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ PutDocumentStatsProto* put_document_stats) {
+ std::unique_ptr<Timer> index_timer = clock_.GetNewTimer();
+
+ if (index_.last_added_document_id() != kInvalidDocumentId &&
+ document_id <= index_.last_added_document_id()) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "DocumentId %d must be greater than last added document_id %d",
+ document_id, index_.last_added_document_id()));
+ }
+ // TODO(b/259744228): revisit last_added_document_id with numeric index for
+ // index rebuilding before rollout.
+ index_.set_last_added_document_id(document_id);
+ uint32_t num_tokens = 0;
+ libtextclassifier3::Status status;
+ for (const TokenizedSection& section :
+ tokenized_document.tokenized_string_sections()) {
+ if (section.metadata.tokenizer ==
+ StringIndexingConfig::TokenizerType::NONE) {
+ ICING_LOG(WARNING)
+ << "Unexpected TokenizerType::NONE found when indexing document.";
+ }
+ // TODO(b/152934343): pass real namespace ids in
+ Index::Editor editor =
+ index_.Edit(document_id, section.metadata.id,
+ section.metadata.term_match_type, /*namespace_id=*/0);
+ for (std::string_view token : section.token_sequence) {
+ ++num_tokens;
+
+ switch (section.metadata.tokenizer) {
+ case StringIndexingConfig::TokenizerType::VERBATIM:
+ // data() is safe to use here because a token created from the
+ // VERBATIM tokenizer is the entire string value. The character at
+ // data() + token.length() is guaranteed to be a null char.
+ status = editor.BufferTerm(token.data());
+ break;
+ case StringIndexingConfig::TokenizerType::NONE:
+ [[fallthrough]];
+ case StringIndexingConfig::TokenizerType::RFC822:
+ [[fallthrough]];
+ case StringIndexingConfig::TokenizerType::URL:
+ [[fallthrough]];
+ case StringIndexingConfig::TokenizerType::PLAIN:
+ std::string normalized_term = normalizer_.NormalizeTerm(token);
+ status = editor.BufferTerm(normalized_term.c_str());
+ }
+
+ if (!status.ok()) {
+ // We've encountered a failure. Bail out. We'll mark this doc as deleted
+ // and signal a failure to the client.
+ ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: "
+ << status.error_message();
+ break;
+ }
+ }
+ if (!status.ok()) {
+ break;
+ }
+ // Add all the seen terms to the index with their term frequency.
+ status = editor.IndexAllBufferedTerms();
+ if (!status.ok()) {
+ ICING_LOG(WARNING) << "Failed to add hits in lite index due to: "
+ << status.error_message();
+ break;
+ }
+ }
+
+ if (put_document_stats != nullptr) {
+ // TODO(b/259744228): switch to set individual index latency.
+ put_document_stats->set_index_latency_ms(
+ index_timer->GetElapsedMilliseconds());
+ put_document_stats->mutable_tokenization_stats()->set_num_tokens_indexed(
+ num_tokens);
+ }
+
+ // If we're either successful or we've hit resource exhausted, then attempt a
+ // merge.
+ if ((status.ok() || absl_ports::IsResourceExhausted(status)) &&
+ index_.WantsMerge()) {
+ ICING_LOG(ERROR) << "Merging the index at docid " << document_id << ".";
+
+ std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer();
+ libtextclassifier3::Status merge_status = index_.Merge();
+
+ if (!merge_status.ok()) {
+ ICING_LOG(ERROR) << "Index merging failed. Clearing index.";
+ if (!index_.Reset().ok()) {
+ return absl_ports::InternalError(IcingStringUtil::StringPrintf(
+ "Unable to reset to clear index after merge failure. Merge "
+ "failure=%d:%s",
+ merge_status.error_code(), merge_status.error_message().c_str()));
+ } else {
+ return absl_ports::DataLossError(IcingStringUtil::StringPrintf(
+ "Forced to reset index after merge failure. Merge failure=%d:%s",
+ merge_status.error_code(), merge_status.error_message().c_str()));
+ }
+ }
+
+ if (put_document_stats != nullptr) {
+ put_document_stats->set_index_merge_latency_ms(
+ merge_timer->GetElapsedMilliseconds());
+ }
+ }
+
+ return status;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/index/string-section-indexing-handler.h b/icing/index/string-section-indexing-handler.h
new file mode 100644
index 0000000..4906f97
--- /dev/null
+++ b/icing/index/string-section-indexing-handler.h
@@ -0,0 +1,67 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_
+#define ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/index/index.h"
+#include "icing/index/section-indexing-handler.h"
+#include "icing/proto/logging.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class StringSectionIndexingHandler : public SectionIndexingHandler {
+ public:
+ explicit StringSectionIndexingHandler(const Clock* clock,
+ const Normalizer* normalizer,
+ Index* index)
+ : SectionIndexingHandler(clock),
+ normalizer_(*normalizer),
+ index_(*index) {}
+
+ ~StringSectionIndexingHandler() override = default;
+
+ // Handles the string indexing process: add hits into the lite index for all
+ // contents in tokenized_document.tokenized_string_sections and merge lite
+ // index into main index if necessary.
+ //
+ // Returns:
+ // - OK on success
+ // - INVALID_ARGUMENT_ERROR if document_id is less than the document_id of a
+ // previously indexed document.
+ // - RESOURCE_EXHAUSTED_ERROR if the index is full and can't add anymore
+ // content.
+ // - DATA_LOSS_ERROR if an attempt to merge the index fails and both indices
+ // are cleared as a result.
+ // - INTERNAL_ERROR if any other errors occur.
+ // - Any main/lite index errors.
+ libtextclassifier3::Status Handle(
+ const TokenizedDocument& tokenized_document, DocumentId document_id,
+ PutDocumentStatsProto* put_document_stats) override;
+
+ private:
+ const Normalizer& normalizer_;
+ Index& index_;
+};
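+
+// Example usage (a sketch; construction and ownership of the clock,
+// normalizer, and index are elided and hypothetical):
+//
+//   StringSectionIndexingHandler handler(&clock, normalizer.get(),
+//                                        index.get());
+//   ICING_RETURN_IF_ERROR(handler.Handle(tokenized_document, document_id,
+//                                        /*put_document_stats=*/nullptr));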
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_
diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc
index 283c6f5..9a7df38 100644
--- a/icing/jni/icing-search-engine-jni.cc
+++ b/icing/jni/icing-search-engine-jni.cc
@@ -83,7 +83,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
}
JNIEXPORT jlong JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeCreate(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate(
JNIEnv* env, jclass clazz, jbyteArray icing_search_engine_options_bytes) {
icing::lib::IcingSearchEngineOptions options;
if (!ParseProtoFromJniByteArray(env, icing_search_engine_options_bytes,
@@ -103,7 +103,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeCreate(
}
JNIEXPORT void JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeDestroy(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDestroy(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -111,7 +111,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDestroy(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeInitialize(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeInitialize(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -123,7 +123,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeInitialize(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema(
JNIEnv* env, jclass clazz, jobject object, jbyteArray schema_bytes,
jboolean ignore_errors_and_delete_documents) {
icing::lib::IcingSearchEngine* icing =
@@ -142,7 +142,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchema(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -153,7 +153,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchemaType(
JNIEnv* env, jclass clazz, jobject object, jstring schema_type) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -166,7 +166,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativePut(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativePut(
JNIEnv* env, jclass clazz, jobject object, jbyteArray document_bytes) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -184,7 +184,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativePut(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGet(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet(
JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri,
jbyteArray result_spec_bytes) {
icing::lib::IcingSearchEngine* icing =
@@ -205,7 +205,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGet(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage(
JNIEnv* env, jclass clazz, jobject object, jbyteArray usage_report_bytes) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -223,7 +223,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetAllNamespaces(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -235,7 +235,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetNextPage(
JNIEnv* env, jclass clazz, jobject object, jlong next_page_token,
jlong java_to_native_start_timestamp_ms) {
icing::lib::IcingSearchEngine* icing =
@@ -252,13 +252,14 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage(
icing::lib::QueryStatsProto* query_stats =
next_page_result_proto.mutable_query_stats();
query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms);
- query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds());
+ query_stats->set_native_to_java_start_timestamp_ms(
+ clock->GetSystemTimeMilliseconds());
return SerializeProtoToJniByteArray(env, next_page_result_proto);
}
JNIEXPORT void JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeInvalidateNextPageToken(
JNIEnv* env, jclass clazz, jobject object, jlong next_page_token) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -269,7 +270,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeSearch(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch(
JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes,
jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes,
jlong java_to_native_start_timestamp_ms) {
@@ -306,13 +307,14 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearch(
icing::lib::QueryStatsProto* query_stats =
search_result_proto.mutable_query_stats();
query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms);
- query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds());
+ query_stats->set_native_to_java_start_timestamp_ms(
+ clock->GetSystemTimeMilliseconds());
return SerializeProtoToJniByteArray(env, search_result_proto);
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeDelete(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDelete(
JNIEnv* env, jclass clazz, jobject object, jstring name_space,
jstring uri) {
icing::lib::IcingSearchEngine* icing =
@@ -327,7 +329,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDelete(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByNamespace(
JNIEnv* env, jclass clazz, jobject object, jstring name_space) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -340,7 +342,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteBySchemaType(
JNIEnv* env, jclass clazz, jobject object, jstring schema_type) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -353,7 +355,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery(
JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes,
jboolean return_deleted_document_info) {
icing::lib::IcingSearchEngine* icing =
@@ -371,7 +373,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk(
JNIEnv* env, jclass clazz, jobject object, jint persist_type_code) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -390,7 +392,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeOptimize(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeOptimize(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -401,7 +403,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeOptimize(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetOptimizeInfo(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -413,7 +415,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetStorageInfo(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -425,7 +427,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeReset(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeReset(
JNIEnv* env, jclass clazz, jobject object) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -436,7 +438,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReset(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions(
JNIEnv* env, jclass clazz, jobject object,
jbyteArray suggestion_spec_bytes) {
icing::lib::IcingSearchEngine* icing =
@@ -455,7 +457,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions(
}
JNIEXPORT jbyteArray JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo(
JNIEnv* env, jclass clazz, jobject object, jint verbosity) {
icing::lib::IcingSearchEngine* icing =
GetIcingSearchEnginePointer(env, object);
@@ -473,7 +475,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo(
}
JNIEXPORT jboolean JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog(
JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
if (!icing::lib::LogSeverity::Code_IsValid(severity)) {
ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity;
@@ -484,7 +486,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog(
}
JNIEXPORT jboolean JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel(
JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
if (!icing::lib::LogSeverity::Code_IsValid(severity)) {
ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity;
@@ -495,8 +497,215 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel(
}
JNIEXPORT jstring JNICALL
-Java_com_google_android_icing_IcingSearchEngine_nativeGetLoggingTag(
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetLoggingTag(
JNIEnv* env, jclass clazz) {
return env->NewStringUTF(icing::lib::kIcingLoggingTag);
}
+
+// TODO(b/240333360): Remove the IcingSearchEngine methods below once a sync
+// from Jetpack to g3 contains the refactored IcingSearchEngine (with
+// IcingSearchEngineImpl).
+JNIEXPORT jlong JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeCreate(
+ JNIEnv* env, jclass clazz, jbyteArray icing_search_engine_options_bytes) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate(
+ env, clazz, icing_search_engine_options_bytes);
+}
+
+JNIEXPORT void JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeDestroy(JNIEnv* env,
+ jclass clazz,
+ jobject object) {
+ Java_com_google_android_icing_IcingSearchEngineImpl_nativeDestroy(env, clazz,
+ object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeInitialize(
+ JNIEnv* env, jclass clazz, jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeInitialize(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema(
+ JNIEnv* env, jclass clazz, jobject object, jbyteArray schema_bytes,
+ jboolean ignore_errors_and_delete_documents) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema(
+ env, clazz, object, schema_bytes, ignore_errors_and_delete_documents);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema(
+ JNIEnv* env, jclass clazz, jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchema(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType(
+ JNIEnv* env, jclass clazz, jobject object, jstring schema_type) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchemaType(
+ env, clazz, object, schema_type);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativePut(
+ JNIEnv* env, jclass clazz, jobject object, jbyteArray document_bytes) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativePut(
+ env, clazz, object, document_bytes);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGet(
+ JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri,
+ jbyteArray result_spec_bytes) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet(
+ env, clazz, object, name_space, uri, result_spec_bytes);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage(
+ JNIEnv* env, jclass clazz, jobject object, jbyteArray usage_report_bytes) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage(
+ env, clazz, object, usage_report_bytes);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces(
+ JNIEnv* env, jclass clazz, jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetAllNamespaces(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage(
+ JNIEnv* env, jclass clazz, jobject object, jlong next_page_token,
+ jlong java_to_native_start_timestamp_ms) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetNextPage(
+ env, clazz, object, next_page_token, java_to_native_start_timestamp_ms);
+}
+
+JNIEXPORT void JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken(
+ JNIEnv* env, jclass clazz, jobject object, jlong next_page_token) {
+ Java_com_google_android_icing_IcingSearchEngineImpl_nativeInvalidateNextPageToken(
+ env, clazz, object, next_page_token);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeSearch(
+ JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes,
+ jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes,
+ jlong java_to_native_start_timestamp_ms) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch(
+ env, clazz, object, search_spec_bytes, scoring_spec_bytes,
+ result_spec_bytes, java_to_native_start_timestamp_ms);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeDelete(JNIEnv* env,
+ jclass clazz,
+ jobject object,
+ jstring name_space,
+ jstring uri) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDelete(
+ env, clazz, object, name_space, uri);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace(
+ JNIEnv* env, jclass clazz, jobject object, jstring name_space) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByNamespace(
+ env, clazz, object, name_space);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType(
+ JNIEnv* env, jclass clazz, jobject object, jstring schema_type) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteBySchemaType(
+ env, clazz, object, schema_type);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery(
+ JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes,
+ jboolean return_deleted_document_info) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery(
+ env, clazz, object, search_spec_bytes, return_deleted_document_info);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk(
+ JNIEnv* env, jclass clazz, jobject object, jint persist_type_code) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk(
+ env, clazz, object, persist_type_code);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeOptimize(JNIEnv* env,
+ jclass clazz,
+ jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeOptimize(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo(
+ JNIEnv* env, jclass clazz, jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetOptimizeInfo(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo(
+ JNIEnv* env, jclass clazz, jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetStorageInfo(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeReset(JNIEnv* env,
+ jclass clazz,
+ jobject object) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeReset(
+ env, clazz, object);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions(
+ JNIEnv* env, jclass clazz, jobject object,
+ jbyteArray suggestion_spec_bytes) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions(
+ env, clazz, object, suggestion_spec_bytes);
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo(
+ JNIEnv* env, jclass clazz, jobject object, jint verbosity) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo(
+ env, clazz, object, verbosity);
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog(
+ JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog(
+ env, clazz, severity, verbosity);
+}
+
+JNIEXPORT jboolean JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel(
+ JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel(
+ env, clazz, severity, verbosity);
+}
+
+JNIEXPORT jstring JNICALL
+Java_com_google_android_icing_IcingSearchEngine_nativeGetLoggingTag(
+ JNIEnv* env, jclass clazz) {
+ return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetLoggingTag(
+ env, clazz);
+}
+
} // extern "C"
diff --git a/icing/join/aggregate-scorer.cc b/icing/join/aggregate-scorer.cc
new file mode 100644
index 0000000..7b17482
--- /dev/null
+++ b/icing/join/aggregate-scorer.cc
@@ -0,0 +1,117 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/join/aggregate-scorer.h"
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "icing/proto/search.pb.h"
+#include "icing/scoring/scored-document-hit.h"
+
+namespace icing {
+namespace lib {
+
+class MinAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ // Avoid dereferencing end() when there are no children.
+ if (children.empty()) return 0.0;
+ return std::min_element(children.begin(), children.end(),
+ [](const ScoredDocumentHit& lhs,
+ const ScoredDocumentHit& rhs) -> bool {
+ return lhs.score() < rhs.score();
+ })
+ ->score();
+ }
+};
+
+class MaxAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ // Avoid dereferencing end() when there are no children.
+ if (children.empty()) return 0.0;
+ return std::max_element(children.begin(), children.end(),
+ [](const ScoredDocumentHit& lhs,
+ const ScoredDocumentHit& rhs) -> bool {
+ return lhs.score() < rhs.score();
+ })
+ ->score();
+ }
+};
+
+class AverageAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ if (children.empty()) return 0.0;
+ // Use std::accumulate rather than std::reduce: std::reduce requires the
+ // binary op to accept its operands in either order, which this lambda
+ // does not.
+ return std::accumulate(
+ children.begin(), children.end(), 0.0,
+ [](const double& prev, const ScoredDocumentHit& item) -> double {
+ return prev + item.score();
+ }) /
+ children.size();
+ }
+};
+
+class CountAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ return children.size();
+ }
+};
+
+class SumAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ // std::accumulate for the same reason as in AverageAggregateScorer.
+ return std::accumulate(
+ children.begin(), children.end(), 0.0,
+ [](const double& prev, const ScoredDocumentHit& item) -> double {
+ return prev + item.score();
+ });
+ }
+};
+
+class DefaultAggregateScorer : public AggregateScorer {
+ public:
+ double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) override {
+ return parent.score();
+ }
+};
+
+std::unique_ptr<AggregateScorer> AggregateScorer::Create(
+ const JoinSpecProto& join_spec) {
+ switch (join_spec.aggregation_score_strategy()) {
+ case JoinSpecProto_AggregationScore_MIN:
+ return std::make_unique<MinAggregateScorer>();
+ case JoinSpecProto_AggregationScore_MAX:
+ return std::make_unique<MaxAggregateScorer>();
+ case JoinSpecProto_AggregationScore_COUNT:
+ return std::make_unique<CountAggregateScorer>();
+ case JoinSpecProto_AggregationScore_AVG:
+ return std::make_unique<AverageAggregateScorer>();
+ case JoinSpecProto_AggregationScore_SUM:
+ return std::make_unique<SumAggregateScorer>();
+ case JoinSpecProto_AggregationScore_UNDEFINED:
+ [[fallthrough]];
+ default:
+ return std::make_unique<DefaultAggregateScorer>();
+ }
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/join/aggregate-scorer.h b/icing/join/aggregate-scorer.h
new file mode 100644
index 0000000..27731b9
--- /dev/null
+++ b/icing/join/aggregate-scorer.h
@@ -0,0 +1,41 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_JOIN_AGGREGATE_SCORER_H_
+#define ICING_JOIN_AGGREGATE_SCORER_H_
+
+#include <memory>
+#include <vector>
+
+#include "icing/proto/search.pb.h"
+#include "icing/scoring/scored-document-hit.h"
+
+namespace icing {
+namespace lib {
+
+class AggregateScorer {
+ public:
+ static std::unique_ptr<AggregateScorer> Create(
+ const JoinSpecProto& join_spec);
+
+ virtual ~AggregateScorer() = default;
+
+ virtual double GetScore(const ScoredDocumentHit& parent,
+ const std::vector<ScoredDocumentHit>& children) = 0;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_JOIN_AGGREGATE_SCORER_H_
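
[Editor's note] A minimal usage sketch of the factory above. The values are
hypothetical, and it assumes ScoredDocumentHit is constructible from a
document id, a section id mask such as kSectionIdMaskAll, and a score (per
icing/scoring/scored-document-hit.h):

    JoinSpecProto join_spec;
    join_spec.set_aggregation_score_strategy(JoinSpecProto_AggregationScore_AVG);
    std::unique_ptr<AggregateScorer> scorer = AggregateScorer::Create(join_spec);

    // Hypothetical hits: one parent and two children with scores 2.0 and 4.0.
    ScoredDocumentHit parent(/*document_id=*/0, kSectionIdMaskAll, /*score=*/1.0);
    std::vector<ScoredDocumentHit> children = {
        ScoredDocumentHit(/*document_id=*/1, kSectionIdMaskAll, /*score=*/2.0),
        ScoredDocumentHit(/*document_id=*/2, kSectionIdMaskAll, /*score=*/4.0)};

    // AVG averages the child scores: (2.0 + 4.0) / 2 = 3.0. Only the default
    // (UNDEFINED) strategy uses the parent score.
    double score = scorer->GetScore(parent, children);
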
diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc
new file mode 100644
index 0000000..7abd821
--- /dev/null
+++ b/icing/join/join-processor.cc
@@ -0,0 +1,180 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/join/join-processor.h"
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/proto/scoring.pb.h"
+#include "icing/proto/search.pb.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/store/document-id.h"
+#include "icing/util/snippet-helpers.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>>
+JoinProcessor::Join(
+ const JoinSpecProto& join_spec,
+ std::vector<ScoredDocumentHit>&& parent_scored_document_hits,
+ std::vector<ScoredDocumentHit>&& child_scored_document_hits) {
+ std::sort(
+ child_scored_document_hits.begin(), child_scored_document_hits.end(),
+ ScoredDocumentHitComparator(
+ /*is_descending=*/join_spec.nested_spec().scoring_spec().order_by() ==
+ ScoringSpecProto::Order::DESC));
+
+ // TODO(b/256022027):
+ // - Aggregate scoring
+ // - Calculate the aggregated score if strategy is AGGREGATION_SCORING.
+ // - Optimization
+ // - Cache property to speed up property retrieval.
+ // - If there is no cache, then we still have the flexibility to fetch it
+ // from actual docs via DocumentStore.
+
+ // Break the children down into a map. The keys of this map are the
+ // DocumentIds of the parent docs that the child ScoredDocumentHits refer to.
+ // The values are vectors of the child ScoredDocumentHits referring to that
+ // parent DocumentId.
+ std::unordered_map<DocumentId, std::vector<ScoredDocumentHit>>
+ parent_to_child_map;
+ // Non-const reference so that the hit can actually be moved into the map
+ // below.
+ for (ScoredDocumentHit& child : child_scored_document_hits) {
+ std::string property_content = FetchPropertyExpressionValue(
+ child.document_id(), join_spec.child_property_expression());
+
+ // Try to split the property content by separators.
+ std::vector<int> separators_in_property_content =
+ GetSeparatorLocations(property_content, "#");
+
+ if (separators_in_property_content.size() != 1) {
+ // Skip the document if the property content is not a qualified id made up
+ // of exactly one namespace and one uri, i.e. if it does not contain
+ // exactly one unescaped '#' separator.
+ continue;
+ }
+
+ std::string ns =
+ property_content.substr(0, separators_in_property_content[0]);
+ std::string uri =
+ property_content.substr(separators_in_property_content[0] + 1);
+
+ UnescapeSeparator(ns, "#");
+ UnescapeSeparator(uri, "#");
+
+ libtextclassifier3::StatusOr<DocumentId> doc_id_or =
+ doc_store_->GetDocumentId(ns, uri);
+
+ if (!doc_id_or.ok()) {
+ // Skip the child document if its parent's DocumentId cannot be resolved.
+ continue;
+ }
+
+ DocumentId parent_doc_id = std::move(doc_id_or).ValueOrDie();
+
+ // This assumes the child docs are already sorted.
+ if (parent_to_child_map[parent_doc_id].size() <
+ join_spec.max_joined_child_count()) {
+ parent_to_child_map[parent_doc_id].push_back(std::move(child));
+ }
+ }
+
+ std::vector<JoinedScoredDocumentHit> joined_scored_document_hits;
+ joined_scored_document_hits.reserve(parent_scored_document_hits.size());
+
+ // Then use the parent-to-child map to build a JoinedScoredDocumentHit for
+ // each parent ScoredDocumentHit.
+ for (ScoredDocumentHit& parent : parent_scored_document_hits) {
+ DocumentId parent_doc_id = kInvalidDocumentId;
+ if (join_spec.parent_property_expression() == kFullyQualifiedIdExpr) {
+ parent_doc_id = parent.document_id();
+ } else {
+ // TODO(b/256022027): So far we only support kFullyQualifiedIdExpr for
+ // parent_property_expression; more expressions could be supported later.
+ return absl_ports::UnimplementedError(
+ join_spec.parent_property_expression() +
+ " must be \"this.fullyQualifiedId()\"");
+ }
+
+ // TODO(b/256022027): Derive final score from
+ // parent_to_child_map[parent_doc_id] and
+ // join_spec.aggregation_score_strategy()
+ double final_score = parent.score();
+ joined_scored_document_hits.emplace_back(
+ final_score, std::move(parent),
+ std::vector<ScoredDocumentHit>(
+ std::move(parent_to_child_map[parent_doc_id])));
+ }
+
+ return joined_scored_document_hits;
+}
+
+// This loads a document and uses a property expression to fetch the value of
+// the property from the document. The property expression may refer to nested
+// document properties. We do not allow for repeated values in this property
+// path, as that would allow for a single document to join to multiple
+// documents.
+//
+// Returns:
+// "" on document load error.
+// "" if the property path is not found in the document.
+// "" if part of the property path is a repeated value.
+std::string JoinProcessor::FetchPropertyExpressionValue(
+ const DocumentId& document_id,
+ const std::string& property_expression) const {
+ // TODO(b/256022027): Add caching of document_id -> {expression -> value}
+ libtextclassifier3::StatusOr<DocumentProto> document_or =
+ doc_store_->Get(document_id);
+ if (!document_or.ok()) {
+ // Return an empty string if the document cannot be loaded.
+ return "";
+ }
+
+ DocumentProto document = std::move(document_or).ValueOrDie();
+
+ return std::string(GetString(&document, property_expression));
+}
+
+std::vector<int> JoinProcessor::GetSeparatorLocations(
+ const std::string& content, const std::string& separator) const {
+ std::vector<int> separators_in_property_content;
+
+ for (int i = 0; i < static_cast<int>(content.length()); ++i) {
+ if (content[i] == '\\') {
+ // Skip the escaped (following) character.
+ ++i;
+ } else if (separator.find(content[i]) != std::string::npos) {
+ // Found an unescaped separator.
+ separators_in_property_content.push_back(i);
+ }
+ }
+ return separators_in_property_content;
+}
+
+void JoinProcessor::UnescapeSeparator(std::string& property,
+ const std::string& separator) {
+ size_t start_pos = 0;
+ while ((start_pos = property.find("\\" + separator, start_pos)) !=
+ std::string::npos) {
+ property.replace(start_pos, separator.length() + 1, separator);
+ start_pos += separator.length();
+ }
+}
+
+} // namespace lib
+} // namespace icing
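
[Editor's note] To make the separator handling above concrete, here is a
sketch of the round trip through the two helpers. The helpers are private, so
this is illustrative rather than directly compilable, and the string values
are hypothetical:

    // Qualified id with an escaped '#' in the namespace: foo\#bar#baz
    std::string content = "foo\\#bar#baz";
    std::vector<int> seps = GetSeparatorLocations(content, "#");  // {8}
    std::string ns = content.substr(0, seps[0]);    // "foo\#bar"
    std::string uri = content.substr(seps[0] + 1);  // "baz"
    UnescapeSeparator(ns, "#");                     // ns is now "foo#bar"
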
diff --git a/icing/join/join-processor.h b/icing/join/join-processor.h
new file mode 100644
index 0000000..c919b22
--- /dev/null
+++ b/icing/join/join-processor.h
@@ -0,0 +1,57 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_JOIN_JOIN_PROCESSOR_H_
+#define ICING_JOIN_JOIN_PROCESSOR_H_
+
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/proto/search.pb.h"
+#include "icing/scoring/scored-document-hit.h"
+#include "icing/store/document-store.h"
+
+namespace icing {
+namespace lib {
+
+class JoinProcessor {
+ public:
+ static constexpr std::string_view kFullyQualifiedIdExpr =
+ "this.fullyQualifiedId()";
+
+ explicit JoinProcessor(const DocumentStore* doc_store)
+ : doc_store_(doc_store) {}
+
+ libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> Join(
+ const JoinSpecProto& join_spec,
+ std::vector<ScoredDocumentHit>&& parent_scored_document_hits,
+ std::vector<ScoredDocumentHit>&& child_scored_document_hits);
+
+ private:
+ std::string FetchPropertyExpressionValue(
+ const DocumentId& document_id,
+ const std::string& property_expression) const;
+
+ void UnescapeSeparator(std::string& property, const std::string& separator);
+
+ std::vector<int> GetSeparatorLocations(const std::string& content,
+ const std::string& separator) const;
+
+ const DocumentStore* doc_store_; // Does not own.
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_JOIN_JOIN_PROCESSOR_H_
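
[Editor's note] A sketch of how the processor is meant to be driven, assuming
an initialized DocumentStore and two ranked hit lists produced by separate
parent and child queries. The names document_store, parent_hits, and
child_hits are illustrative, not from this change:

    JoinProcessor join_processor(document_store.get());

    JoinSpecProto join_spec;
    join_spec.set_parent_property_expression("this.fullyQualifiedId()");
    join_spec.set_child_property_expression("sender");
    join_spec.set_max_joined_child_count(10);

    libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> joined_or =
        join_processor.Join(join_spec, std::move(parent_hits),
                            std::move(child_hits));
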
diff --git a/icing/monkey_test/icing-monkey-test-runner.cc b/icing/monkey_test/icing-monkey-test-runner.cc
index 2dd5a03..a2a6c9b 100644
--- a/icing/monkey_test/icing-monkey-test-runner.cc
+++ b/icing/monkey_test/icing-monkey-test-runner.cc
@@ -40,42 +40,12 @@ using ::testing::Le;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAreArray;
-inline constexpr int kNumTypes = 30;
-const std::vector<int> kPossibleNumProperties = {0,
- 1,
- 2,
- 4,
- 8,
- 16,
- kTotalNumSections / 2,
- kTotalNumSections,
- kTotalNumSections + 1,
- kTotalNumSections * 2};
-inline constexpr int kNumNamespaces = 100;
-inline constexpr int kNumURIs = 1000;
-
-// Merge per 131072 hits
-const int kIndexMergeSize = 1024 * 1024;
-
-// An array of pairs of monkey test APIs with frequencies.
-// If f_sum is the sum of all the frequencies, an operation with frequency f
-// means for every f_sum iterations, the operation is expected to run f times.
-const std::vector<
- std::pair<std::function<void(IcingMonkeyTestRunner*)>, uint32_t>>
- kMonkeyAPISchedules = {{&IcingMonkeyTestRunner::DoPut, 500},
- {&IcingMonkeyTestRunner::DoSearch, 200},
- {&IcingMonkeyTestRunner::DoGet, 70},
- {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50},
- {&IcingMonkeyTestRunner::DoDelete, 50},
- {&IcingMonkeyTestRunner::DoDeleteByNamespace, 50},
- {&IcingMonkeyTestRunner::DoDeleteBySchemaType, 50},
- {&IcingMonkeyTestRunner::DoDeleteByQuery, 20},
- {&IcingMonkeyTestRunner::DoOptimize, 5},
- {&IcingMonkeyTestRunner::ReloadFromDisk, 5}};
-
-SchemaProto GenerateRandomSchema(MonkeyTestRandomEngine* random) {
+SchemaProto GenerateRandomSchema(
+ const IcingMonkeyTestRunnerConfiguration& config,
+ MonkeyTestRandomEngine* random) {
MonkeySchemaGenerator schema_generator(random);
- return schema_generator.GenerateSchema(kNumTypes, kPossibleNumProperties);
+ return schema_generator.GenerateSchema(config.num_types,
+ config.possible_num_properties);
}
SearchSpecProto GenerateRandomSearchSpecProto(
@@ -166,18 +136,20 @@ void SortDocuments(std::vector<DocumentProto>& documents) {
} // namespace
-IcingMonkeyTestRunner::IcingMonkeyTestRunner(uint32_t seed)
- : random_(seed), in_memory_icing_() {
- ICING_LOG(INFO) << "Monkey test runner started with seed: " << seed;
+IcingMonkeyTestRunner::IcingMonkeyTestRunner(
+ const IcingMonkeyTestRunnerConfiguration& config)
+ : config_(config), random_(config.seed), in_memory_icing_() {
+ ICING_LOG(INFO) << "Monkey test runner started with seed: " << config_.seed;
- SchemaProto schema = GenerateRandomSchema(&random_);
+ SchemaProto schema = GenerateRandomSchema(config_, &random_);
ICING_LOG(DBG) << "Schema Generated: " << schema.DebugString();
in_memory_icing_ =
std::make_unique<InMemoryIcingSearchEngine>(&random_, std::move(schema));
document_generator_ = std::make_unique<MonkeyDocumentGenerator>(
- &random_, in_memory_icing_->GetSchema(), kNumNamespaces, kNumURIs);
+ &random_, in_memory_icing_->GetSchema(), config_.possible_num_tokens,
+ config_.num_namespaces, config_.num_uris);
std::string dir = GetTestTempDir() + "/icing/monkey";
filesystem_.DeleteDirectoryRecursively(dir.c_str());
@@ -190,13 +162,13 @@ void IcingMonkeyTestRunner::Run(uint32_t num) {
"CreateIcingSearchEngineWithSchema() first";
uint32_t frequency_sum = 0;
- for (const auto& schedule : kMonkeyAPISchedules) {
+ for (const auto& schedule : config_.monkey_api_schedules) {
frequency_sum += schedule.second;
}
std::uniform_int_distribution<> dist(0, frequency_sum - 1);
for (; num; --num) {
int p = dist(random_);
- for (const auto& schedule : kMonkeyAPISchedules) {
+ for (const auto& schedule : config_.monkey_api_schedules) {
if (p < schedule.second) {
ASSERT_NO_FATAL_FAILURE(schedule.first(this));
break;
@@ -404,6 +376,11 @@ void IcingMonkeyTestRunner::DoSearch() {
search_result = icing_->GetNextPage(search_result.next_page_token());
ASSERT_THAT(search_result.status(), ProtoIsOk());
}
+ // The maximum number of scored documents allowed in Icing is 30000. If we
+ // hit that limit, the results may be truncated and we are not able to
+ // compare them with the in-memory Icing.
+ if (exp_documents.size() >= 30000) {
+ return;
+ }
if (snippet_spec.num_matches_per_property() > 0) {
ASSERT_THAT(num_snippeted,
Eq(std::min<uint32_t>(exp_documents.size(),
@@ -432,7 +409,7 @@ void IcingMonkeyTestRunner::DoOptimize() {
void IcingMonkeyTestRunner::CreateIcingSearchEngine() {
IcingSearchEngineOptions icing_options;
- icing_options.set_index_merge_size(kIndexMergeSize);
+ icing_options.set_index_merge_size(config_.index_merge_size);
icing_options.set_base_dir(icing_dir_->dir());
icing_ = std::make_unique<IcingSearchEngine>(icing_options);
ASSERT_THAT(icing_->Initialize().status(), ProtoIsOk());
diff --git a/icing/monkey_test/icing-monkey-test-runner.h b/icing/monkey_test/icing-monkey-test-runner.h
index 5f5649c..fbaaaaa 100644
--- a/icing/monkey_test/icing-monkey-test-runner.h
+++ b/icing/monkey_test/icing-monkey-test-runner.h
@@ -26,9 +26,42 @@
namespace icing {
namespace lib {
+class IcingMonkeyTestRunner;
+
+struct IcingMonkeyTestRunnerConfiguration {
+ explicit IcingMonkeyTestRunnerConfiguration(uint32_t seed, int num_types,
+ int num_namespaces, int num_uris,
+ int index_merge_size)
+ : seed(seed),
+ num_types(num_types),
+ num_namespaces(num_namespaces),
+ num_uris(num_uris),
+ index_merge_size(index_merge_size) {}
+
+ uint32_t seed;
+ int num_types;
+ int num_namespaces;
+ int num_uris;
+ int index_merge_size;
+
+ // The possible number of properties that may appear in generated schema
+ // types.
+ std::vector<int> possible_num_properties;
+
+ // The possible number of tokens that may appear in generated documents, with
+ // a noise factor from 0.5 to 1 applied.
+ std::vector<int> possible_num_tokens;
+
+ // An array of pairs of monkey test APIs with frequencies.
+ // If f_sum is the sum of all the frequencies, an operation with frequency f
+ // means for every f_sum iterations, the operation is expected to run f times.
+ std::vector<std::pair<std::function<void(IcingMonkeyTestRunner*)>, uint32_t>>
+ monkey_api_schedules;
+};
+
class IcingMonkeyTestRunner {
public:
- IcingMonkeyTestRunner(uint32_t seed = std::random_device()());
+ explicit IcingMonkeyTestRunner(
+ const IcingMonkeyTestRunnerConfiguration& config);
IcingMonkeyTestRunner(const IcingMonkeyTestRunner&) = delete;
IcingMonkeyTestRunner& operator=(const IcingMonkeyTestRunner&) = delete;
@@ -54,6 +87,7 @@ class IcingMonkeyTestRunner {
void DoOptimize();
private:
+ IcingMonkeyTestRunnerConfiguration config_;
MonkeyTestRandomEngine random_;
Filesystem filesystem_;
std::unique_ptr<DestructibleDirectory> icing_dir_;
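
[Editor's note] The schedule table is consumed by a weighted pick in Run(). A
sketch of the full selection step, assuming the loop subtracts each frequency
before trying the next entry (the tail of that loop falls outside the hunk
shown in the .cc above); random_engine and runner are illustrative names:

    uint32_t f_sum = 0;
    for (const auto& schedule : config.monkey_api_schedules) {
      f_sum += schedule.second;
    }
    std::uniform_int_distribution<> dist(0, f_sum - 1);
    int p = dist(random_engine);
    for (const auto& schedule : config.monkey_api_schedules) {
      if (p < schedule.second) {
        // On average this entry runs frequency/f_sum of all iterations.
        schedule.first(runner);
        break;
      }
      p -= schedule.second;
    }
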
diff --git a/icing/monkey_test/icing-search-engine_monkey_test.cc b/icing/monkey_test/icing-search-engine_monkey_test.cc
index ad887b8..a24e57f 100644
--- a/icing/monkey_test/icing-search-engine_monkey_test.cc
+++ b/icing/monkey_test/icing-search-engine_monkey_test.cc
@@ -20,11 +20,71 @@ namespace icing {
namespace lib {
TEST(IcingSearchEngineMonkeyTest, MonkeyTest) {
+ IcingMonkeyTestRunnerConfiguration config(
+ /*seed=*/std::random_device()(),
+ /*num_types=*/30,
+ /*num_namespaces=*/100,
+ /*num_uris=*/1000,
+ /*index_merge_size=*/1024 * 1024);
+ config.possible_num_properties = {0,
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ kTotalNumSections / 2,
+ kTotalNumSections,
+ kTotalNumSections + 1,
+ kTotalNumSections * 2};
+ config.possible_num_tokens = {0, 1, 4, 16, 64, 256};
+ config.monkey_api_schedules = {
+ {&IcingMonkeyTestRunner::DoPut, 500},
+ {&IcingMonkeyTestRunner::DoSearch, 200},
+ {&IcingMonkeyTestRunner::DoGet, 70},
+ {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50},
+ {&IcingMonkeyTestRunner::DoDelete, 50},
+ {&IcingMonkeyTestRunner::DoDeleteByNamespace, 50},
+ {&IcingMonkeyTestRunner::DoDeleteBySchemaType, 50},
+ {&IcingMonkeyTestRunner::DoDeleteByQuery, 20},
+ {&IcingMonkeyTestRunner::DoOptimize, 5},
+ {&IcingMonkeyTestRunner::ReloadFromDisk, 5}};
uint32_t num_iterations = IsAndroidArm() ? 1000 : 5000;
- IcingMonkeyTestRunner runner;
+ IcingMonkeyTestRunner runner(config);
ASSERT_NO_FATAL_FAILURE(runner.CreateIcingSearchEngineWithSchema());
ASSERT_NO_FATAL_FAILURE(runner.Run(num_iterations));
}
+TEST(DISABLED_IcingSearchEngineMonkeyTest, MonkeyManyDocTest) {
+ IcingMonkeyTestRunnerConfiguration config(
+ /*seed=*/std::random_device()(),
+ /*num_types=*/30,
+ /*num_namespaces=*/200,
+ /*num_uris=*/100000,
+ /*index_merge_size=*/1024 * 1024);
+
+ // Due to the large number of documents, each document is kept small so
+ // that the test can finish in a reasonable amount of time.
+ config.possible_num_properties = {0, 1, 2};
+ config.possible_num_tokens = {0, 1, 4};
+
+ // No deletion is performed to preserve a large number of documents.
+ config.monkey_api_schedules = {
+ {&IcingMonkeyTestRunner::DoPut, 500},
+ {&IcingMonkeyTestRunner::DoSearch, 200},
+ {&IcingMonkeyTestRunner::DoGet, 70},
+ {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50},
+ {&IcingMonkeyTestRunner::DoOptimize, 5},
+ {&IcingMonkeyTestRunner::ReloadFromDisk, 5}};
+ IcingMonkeyTestRunner runner(config);
+ ASSERT_NO_FATAL_FAILURE(runner.CreateIcingSearchEngineWithSchema());
+ // Pre-fill with 4 million documents
+ SetLoggingLevel(LogSeverity::WARNING);
+ for (int i = 0; i < 4000000; i++) {
+ ASSERT_NO_FATAL_FAILURE(runner.DoPut());
+ }
+ SetLoggingLevel(LogSeverity::INFO);
+ ASSERT_NO_FATAL_FAILURE(runner.Run(1000));
+}
+
} // namespace lib
} // namespace icing
diff --git a/icing/monkey_test/monkey-test-generators.cc b/icing/monkey_test/monkey-test-generators.cc
index 88fc0b6..7b2ff56 100644
--- a/icing/monkey_test/monkey-test-generators.cc
+++ b/icing/monkey_test/monkey-test-generators.cc
@@ -106,19 +106,8 @@ std::string MonkeyDocumentGenerator::GetUri() const {
}
int MonkeyDocumentGenerator::GetNumTokens() const {
- std::uniform_int_distribution<> int_dist(-1, 4);
- int n = int_dist(*random_);
- if (n == -1) {
- // 1/6 chance of getting zero token for a property
- return 0;
- }
- if (n == 0) {
- // 1/6 chance of getting one token for a property
- return 1;
- }
- // 1/6 chance of getting one of 4, 16, 64, 256
- n = 1 << (2 * n);
-
+ std::uniform_int_distribution<> dist(0, possible_num_tokens_.size() - 1);
+ int n = possible_num_tokens_[dist(*random_)];
// Add some noise
std::uniform_real_distribution<> real_dist(0.5, 1);
float p = real_dist(*random_);
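
[Editor's note] The hunk ends just before the return statement, but from the
surrounding context the token count is presumably n scaled by the noise
factor p. A self-contained sketch of that sampling scheme, using the values
configured in the monkey test above:

    #include <random>
    #include <vector>

    std::mt19937 rng(std::random_device{}());
    std::vector<int> possible_num_tokens = {0, 1, 4, 16, 64, 256};
    std::uniform_int_distribution<> dist(0, possible_num_tokens.size() - 1);
    std::uniform_real_distribution<> real_dist(0.5, 1);
    int n = possible_num_tokens[dist(rng)];
    // Noise factor p in [0.5, 1), so the result lands in [n/2, n).
    int num_tokens = static_cast<int>(n * real_dist(rng));
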
diff --git a/icing/monkey_test/monkey-test-generators.h b/icing/monkey_test/monkey-test-generators.h
index 68c5e92..6349918 100644
--- a/icing/monkey_test/monkey-test-generators.h
+++ b/icing/monkey_test/monkey-test-generators.h
@@ -70,10 +70,12 @@ class MonkeyDocumentGenerator {
public:
explicit MonkeyDocumentGenerator(MonkeyTestRandomEngine* random,
const SchemaProto* schema,
+ std::vector<int> possible_num_tokens,
uint32_t num_namespaces,
uint32_t num_uris = 0)
: random_(random),
schema_(schema),
+ possible_num_tokens_(std::move(possible_num_tokens)),
num_namespaces_(num_namespaces),
num_uris_(num_uris) {}
@@ -104,6 +106,11 @@ class MonkeyDocumentGenerator {
private:
MonkeyTestRandomEngine* random_; // Does not own.
const SchemaProto* schema_; // Does not own.
+
+ // The possible number of tokens that may appear in generated documents, with
+ // a noise factor from 0.5 to 1 applied.
+ std::vector<int> possible_num_tokens_;
+
uint32_t num_namespaces_;
uint32_t num_uris_;
uint32_t num_docs_generated_ = 0;
diff --git a/icing/portable/platform.h b/icing/portable/platform.h
index 150eede..b68c026 100644
--- a/icing/portable/platform.h
+++ b/icing/portable/platform.h
@@ -15,6 +15,8 @@
#ifndef ICING_PORTABLE_PLATFORM_H_
#define ICING_PORTABLE_PLATFORM_H_
+#include "unicode/uvernum.h"
+
namespace icing {
namespace lib {
@@ -34,6 +36,14 @@ inline bool IsReverseJniTokenization() {
return false;
}
+inline bool IsIcuTokenization() {
+ return !IsReverseJniTokenization() && !IsCfStringTokenization();
+}
+
+inline bool IsIcu72PlusTokenization() {
+ return IsIcuTokenization() && U_ICU_VERSION_MAJOR_NUM >= 72;
+}
+
// Whether we're running on android_x86
inline bool IsAndroidX86() {
#if defined(__ANDROID__) && defined(__i386__)
@@ -58,6 +68,15 @@ inline bool IsIosPlatform() {
return false;
}
+// TODO(b/259129263): verify the flag works for different platforms.
+#if defined(__arm__) || defined(__i386__)
+#define ICING_ARCH_BIT_32
+#elif defined(__aarch64__) || defined(__x86_64__)
+#define ICING_ARCH_BIT_64
+#else
+#define ICING_ARCH_BIT_UNKNOWN
+#endif
+
enum Architecture {
UNKNOWN,
BIT_32,
@@ -69,9 +88,9 @@ enum Architecture {
// Architecture macros pulled from
// https://developer.android.com/ndk/guides/cpu-features
inline Architecture GetArchitecture() {
-#if defined(__arm__) || defined(__i386__)
+#if defined(ICING_ARCH_BIT_32)
return BIT_32;
-#elif defined(__aarch64__) || defined(__x86_64__)
+#elif defined(ICING_ARCH_BIT_64)
return BIT_64;
#else
return UNKNOWN;
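
[Editor's note] A sketch of how these helpers are typically consumed in
tests; the GTEST_SKIP guard is hypothetical usage, not part of this change:

    if (IsIcu72PlusTokenization()) {
      // ICU 72 changed some word-segmentation rules, so expectations written
      // against older ICU versions may no longer hold.
      GTEST_SKIP() << "Expectations assume pre-ICU-72 segmentation";
    }
    if (GetArchitecture() == BIT_32) {
      // Shrink workloads on 32-bit devices, which have less address space.
    }
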
diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h b/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h
new file mode 100644
index 0000000..42be07d
--- /dev/null
+++ b/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h
@@ -0,0 +1,108 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_
+#define ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+
+namespace icing {
+namespace lib {
+
+// The node types that SimpleVisitor (below) records.
+enum class NodeType {
+ kFunctionName,
+ kString,
+ kText,
+ kMember,
+ kFunction,
+ kUnaryOperator,
+ kNaryOperator
+};
+
+struct NodeInfo {
+ std::string value;
+ NodeType type;
+
+ bool operator==(const NodeInfo& rhs) const {
+ return value == rhs.value && type == rhs.type;
+ }
+};
+
+MATCHER_P2(EqualsNodeInfo, value, type, "") {
+ if (arg.value != value || arg.type != type) {
+ *result_listener << "(Expected: value=\"" << value
+ << "\", type=" << static_cast<int>(type)
+ << ". Actual: value=\"" << arg.value
+ << "\", type=" << static_cast<int>(arg.type) << ")";
+ return false;
+ }
+ return true;
+}
+
+// A visitor that simply collects the nodes and flattens them in left-side
+// depth-first order.
+class SimpleVisitor : public AbstractSyntaxTreeVisitor {
+ public:
+ void VisitFunctionName(const FunctionNameNode* node) override {
+ nodes_.push_back({node->value(), NodeType::kFunctionName});
+ }
+ void VisitString(const StringNode* node) override {
+ nodes_.push_back({node->value(), NodeType::kString});
+ }
+ void VisitText(const TextNode* node) override {
+ nodes_.push_back({node->value(), NodeType::kText});
+ }
+ void VisitMember(const MemberNode* node) override {
+ for (const std::unique_ptr<TextNode>& child : node->children()) {
+ child->Accept(this);
+ }
+ if (node->function() != nullptr) {
+ node->function()->Accept(this);
+ }
+ nodes_.push_back({"", NodeType::kMember});
+ }
+ void VisitFunction(const FunctionNode* node) override {
+ node->function_name()->Accept(this);
+ for (const std::unique_ptr<Node>& arg : node->args()) {
+ arg->Accept(this);
+ }
+ nodes_.push_back({"", NodeType::kFunction});
+ }
+ void VisitUnaryOperator(const UnaryOperatorNode* node) override {
+ node->child()->Accept(this);
+ nodes_.push_back({node->operator_text(), NodeType::kUnaryOperator});
+ }
+ void VisitNaryOperator(const NaryOperatorNode* node) override {
+ for (const std::unique_ptr<Node>& child : node->children()) {
+ child->Accept(this);
+ }
+ nodes_.push_back({node->operator_text(), NodeType::kNaryOperator});
+ }
+
+ const std::vector<NodeInfo>& nodes() const { return nodes_; }
+
+ private:
+ std::vector<NodeInfo> nodes_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_
diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree.h b/icing/query/advanced_query_parser/abstract-syntax-tree.h
new file mode 100644
index 0000000..dc28ab6
--- /dev/null
+++ b/icing/query/advanced_query_parser/abstract-syntax-tree.h
@@ -0,0 +1,168 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_
+#define ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace icing {
+namespace lib {
+
+class FunctionNameNode;
+class StringNode;
+class TextNode;
+class MemberNode;
+class FunctionNode;
+class UnaryOperatorNode;
+class NaryOperatorNode;
+
+class AbstractSyntaxTreeVisitor {
+ public:
+ virtual ~AbstractSyntaxTreeVisitor() = default;
+
+ virtual void VisitFunctionName(const FunctionNameNode* node) = 0;
+ virtual void VisitString(const StringNode* node) = 0;
+ virtual void VisitText(const TextNode* node) = 0;
+ virtual void VisitMember(const MemberNode* node) = 0;
+ virtual void VisitFunction(const FunctionNode* node) = 0;
+ virtual void VisitUnaryOperator(const UnaryOperatorNode* node) = 0;
+ virtual void VisitNaryOperator(const NaryOperatorNode* node) = 0;
+};
+
+class Node {
+ public:
+ virtual ~Node() = default;
+ virtual void Accept(AbstractSyntaxTreeVisitor* visitor) const = 0;
+};
+
+class TerminalNode : public Node {
+ public:
+ explicit TerminalNode(std::string value) : value_(std::move(value)) {}
+
+ const std::string& value() const { return value_; }
+
+ private:
+ std::string value_;
+};
+
+class FunctionNameNode : public TerminalNode {
+ public:
+ explicit FunctionNameNode(std::string value)
+ : TerminalNode(std::move(value)) {}
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitFunctionName(this);
+ }
+};
+
+class StringNode : public TerminalNode {
+ public:
+ explicit StringNode(std::string value) : TerminalNode(std::move(value)) {}
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitString(this);
+ }
+};
+
+class TextNode : public TerminalNode {
+ public:
+ explicit TextNode(std::string value) : TerminalNode(std::move(value)) {}
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitText(this);
+ }
+};
+
+class MemberNode : public Node {
+ public:
+ explicit MemberNode(std::vector<std::unique_ptr<TextNode>> children,
+ std::unique_ptr<FunctionNode> function)
+ : children_(std::move(children)), function_(std::move(function)) {}
+
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitMember(this);
+ }
+ const std::vector<std::unique_ptr<TextNode>>& children() const {
+ return children_;
+ }
+ const FunctionNode* function() const { return function_.get(); }
+
+ private:
+ std::vector<std::unique_ptr<TextNode>> children_;
+ // Nullable. When it is not nullptr, this member node represents a
+ // function call.
+ std::unique_ptr<FunctionNode> function_;
+};
+
+class FunctionNode : public Node {
+ public:
+ explicit FunctionNode(std::unique_ptr<FunctionNameNode> function_name)
+ : FunctionNode(std::move(function_name), {}) {}
+ explicit FunctionNode(std::unique_ptr<FunctionNameNode> function_name,
+ std::vector<std::unique_ptr<Node>> args)
+ : function_name_(std::move(function_name)), args_(std::move(args)) {}
+
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitFunction(this);
+ }
+ const FunctionNameNode* function_name() const { return function_name_.get(); }
+ const std::vector<std::unique_ptr<Node>>& args() const { return args_; }
+
+ private:
+ std::unique_ptr<FunctionNameNode> function_name_;
+ std::vector<std::unique_ptr<Node>> args_;
+};
+
+class UnaryOperatorNode : public Node {
+ public:
+ explicit UnaryOperatorNode(std::string operator_text,
+ std::unique_ptr<Node> child)
+ : operator_text_(std::move(operator_text)), child_(std::move(child)) {}
+
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitUnaryOperator(this);
+ }
+ const std::string& operator_text() const { return operator_text_; }
+ const Node* child() const { return child_.get(); }
+
+ private:
+ std::string operator_text_;
+ std::unique_ptr<Node> child_;
+};
+
+class NaryOperatorNode : public Node {
+ public:
+ explicit NaryOperatorNode(std::string operator_text,
+ std::vector<std::unique_ptr<Node>> children)
+ : operator_text_(std::move(operator_text)),
+ children_(std::move(children)) {}
+
+ void Accept(AbstractSyntaxTreeVisitor* visitor) const override {
+ visitor->VisitNaryOperator(this);
+ }
+ const std::string& operator_text() const { return operator_text_; }
+ const std::vector<std::unique_ptr<Node>>& children() const {
+ return children_;
+ }
+
+ private:
+ std::string operator_text_;
+ std::vector<std::unique_ptr<Node>> children_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_
diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc b/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc
new file mode 100644
index 0000000..a8599fd
--- /dev/null
+++ b/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc
@@ -0,0 +1,141 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h"
+
+namespace icing {
+namespace lib {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(AbstractSyntaxTreeTest, Simple) {
+ // foo
+ std::unique_ptr<Node> root = std::make_unique<TextNode>("foo");
+ SimpleVisitor visitor;
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText)));
+}
+
+TEST(AbstractSyntaxTreeTest, Composite) {
+ // (foo bar) OR baz
+ std::vector<std::unique_ptr<Node>> and_args;
+ and_args.push_back(std::make_unique<TextNode>("foo"));
+ and_args.push_back(std::make_unique<TextNode>("bar"));
+ auto and_node =
+ std::make_unique<NaryOperatorNode>("AND", std::move(and_args));
+
+ std::vector<std::unique_ptr<Node>> or_args;
+ or_args.push_back(std::move(and_node));
+ or_args.push_back(std::make_unique<TextNode>("baz"));
+ std::unique_ptr<Node> root =
+ std::make_unique<NaryOperatorNode>("OR", std::move(or_args));
+
+ SimpleVisitor visitor;
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(AbstractSyntaxTreeTest, Function) {
+ // foo()
+ std::unique_ptr<Node> root =
+ std::make_unique<FunctionNode>(std::make_unique<FunctionNameNode>("foo"));
+ SimpleVisitor visitor;
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction)));
+
+ // foo("bar")
+ std::vector<std::unique_ptr<Node>> args;
+ args.push_back(std::make_unique<StringNode>("bar"));
+ root = std::make_unique<FunctionNode>(
+ std::make_unique<FunctionNameNode>("foo"), std::move(args));
+ visitor = SimpleVisitor();
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction)));
+
+ // foo(bar("baz"))
+ std::vector<std::unique_ptr<Node>> inner_args;
+ inner_args.push_back(std::make_unique<StringNode>("baz"));
+ args.clear();
+ args.push_back(std::make_unique<FunctionNode>(
+ std::make_unique<FunctionNameNode>("bar"), std::move(inner_args)));
+ root = std::make_unique<FunctionNode>(
+ std::make_unique<FunctionNameNode>("foo"), std::move(args));
+ visitor = SimpleVisitor();
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kFunctionName),
+ EqualsNodeInfo("baz", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(AbstractSyntaxTreeTest, Restriction) {
+ // sender.name:(IMPORTANT OR URGENT)
+ std::vector<std::unique_ptr<TextNode>> member_args;
+ member_args.push_back(std::make_unique<TextNode>("sender"));
+ member_args.push_back(std::make_unique<TextNode>("name"));
+
+ std::vector<std::unique_ptr<Node>> or_args;
+ or_args.push_back(std::make_unique<TextNode>("IMPORTANT"));
+ or_args.push_back(std::make_unique<TextNode>("URGENT"));
+
+ std::vector<std::unique_ptr<Node>> has_args;
+ has_args.push_back(std::make_unique<MemberNode>(std::move(member_args),
+ /*function=*/nullptr));
+ has_args.push_back(
+ std::make_unique<NaryOperatorNode>("OR", std::move(or_args)));
+
+ std::unique_ptr<Node> root =
+ std::make_unique<NaryOperatorNode>(":", std::move(has_args));
+
+ SimpleVisitor visitor;
+ root->Accept(&visitor);
+
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("sender", NodeType::kText),
+ EqualsNodeInfo("name", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("IMPORTANT", NodeType::kText),
+ EqualsNodeInfo("URGENT", NodeType::kText),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator),
+ EqualsNodeInfo(":", NodeType::kNaryOperator)));
+}
+
+} // namespace
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/lexer.cc b/icing/query/advanced_query_parser/lexer.cc
new file mode 100644
index 0000000..18932f6
--- /dev/null
+++ b/icing/query/advanced_query_parser/lexer.cc
@@ -0,0 +1,228 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/lexer.h"
+
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/util/i18n-utils.h"
+
+namespace icing {
+namespace lib {
+
+bool Lexer::ConsumeWhitespace() {
+ if (current_char_ == '\0') {
+ return false;
+ }
+ if (i18n_utils::IsWhitespaceAt(query_, current_index_)) {
+ UChar32 uchar32 = i18n_utils::GetUChar32At(query_.data(), query_.length(),
+ current_index_);
+ int length = i18n_utils::GetUtf8Length(uchar32);
+ Advance(length);
+ return true;
+ }
+ return false;
+}
+
+bool Lexer::ConsumeQuerySingleChar() {
+ if (current_char_ != ':') {
+ return false;
+ }
+ tokens_.push_back({":", TokenType::COMPARATOR});
+ Advance();
+ return true;
+}
+
+bool Lexer::ConsumeScoringSingleChar() {
+ switch (current_char_) {
+ case '+':
+ tokens_.push_back({"", TokenType::PLUS});
+ break;
+ case '*':
+ tokens_.push_back({"", TokenType::TIMES});
+ break;
+ case '/':
+ tokens_.push_back({"", TokenType::DIV});
+ break;
+ default:
+ return false;
+ }
+ Advance();
+ return true;
+}
+
+bool Lexer::ConsumeGeneralSingleChar() {
+ switch (current_char_) {
+ case ',':
+ tokens_.push_back({"", TokenType::COMMA});
+ break;
+ case '.':
+ tokens_.push_back({"", TokenType::DOT});
+ break;
+ case '-':
+ tokens_.push_back({"", TokenType::MINUS});
+ break;
+ case '(':
+ tokens_.push_back({"", TokenType::LPAREN});
+ break;
+ case ')':
+ tokens_.push_back({"", TokenType::RPAREN});
+ break;
+ default:
+ return false;
+ }
+ Advance();
+ return true;
+}
+
+bool Lexer::ConsumeSingleChar() {
+ if (language_ == Language::QUERY) {
+ if (ConsumeQuerySingleChar()) {
+ return true;
+ }
+ } else if (language_ == Language::SCORING) {
+ if (ConsumeScoringSingleChar()) {
+ return true;
+ }
+ }
+ return ConsumeGeneralSingleChar();
+}
+
+bool Lexer::ConsumeComparator() {
+ if (current_char_ != '<' && current_char_ != '>' && current_char_ != '!' &&
+ current_char_ != '=') {
+ return false;
+ }
+ // Now, current_char_ must be one of '<', '>', '!', or '='.
+ // Matching for '<=', '>=', '!=', or '=='.
+ char next_char = PeekNext(1);
+ if (next_char == '=') {
+ tokens_.push_back({{current_char_, next_char}, TokenType::COMPARATOR});
+ Advance(2);
+ return true;
+ }
+ // Now, next_char must not be '='. Let's match for '<' and '>'.
+ if (current_char_ == '<' || current_char_ == '>') {
+ tokens_.push_back({{current_char_}, TokenType::COMPARATOR});
+ Advance();
+ return true;
+ }
+ return false;
+}
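+
+// Example (illustrative): a lone '<' or '>' is a COMPARATOR, but a lone '!'
+// or '=' is not consumed here and falls through to TEXT, so "age=20" lexes
+// as a single TEXT token (see SpecialSymbolAsText in lexer_test.cc).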
+
+bool Lexer::ConsumeAndOr() {
+ if (current_char_ != '&' && current_char_ != '|') {
+ return false;
+ }
+ char next_char = PeekNext(1);
+ if (current_char_ != next_char) {
+ return false;
+ }
+ if (current_char_ == '&') {
+ tokens_.push_back({"", TokenType::AND});
+ } else {
+ tokens_.push_back({"", TokenType::OR});
+ }
+ Advance(2);
+ return true;
+}
+
+bool Lexer::ConsumeStringLiteral() {
+ if (current_char_ != '"') {
+ return false;
+ }
+ std::string text;
+ Advance();
+ while (current_char_ != '\0' && current_char_ != '"') {
+ // On seeing a backslash, always take the following character literally,
+ // even if it is a quotation mark.
+ if (current_char_ == '\\') {
+ text.push_back(current_char_);
+ Advance();
+ if (current_char_ == '\0') {
+ // In this case, we are missing a terminating quotation mark.
+ break;
+ }
+ }
+ text.push_back(current_char_);
+ Advance();
+ }
+ if (current_char_ == '\0') {
+ SyntaxError("missing terminating \" character");
+ return false;
+ }
+ tokens_.push_back({text, TokenType::STRING});
+ Advance();
+ return true;
+}
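+
+// Example (illustrative): for the raw input "foo\"bar" (quotation marks
+// included), ConsumeStringLiteral() emits one STRING token whose text is
+// foo\"bar; the backslash is kept because STRING tokens store the original
+// text without unescaping (see LexerToken in lexer.h).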
+
+bool Lexer::Text() {
+ if (current_char_ == '\0') {
+ return false;
+ }
+ tokens_.push_back({"", TokenType::TEXT});
+ int token_index = tokens_.size() - 1;
+ while (!ConsumeNonText() && current_char_ != '\0') {
+ // On seeing a backslash in TEXT, unescape it by accepting the following
+ // character unconditionally, including whitespace, operator symbols,
+ // parentheses, etc.
+ if (current_char_ == '\\') {
+ Advance();
+ if (current_char_ == '\0') {
+ SyntaxError("missing a escaping character after \\");
+ break;
+ }
+ }
+ tokens_[token_index].text.push_back(current_char_);
+ Advance();
+ if (current_char_ == '(') {
+ // A TEXT followed by a LPAREN is a FUNCTION_NAME.
+ tokens_.back().type = TokenType::FUNCTION_NAME;
+ // No need to break, since ConsumeNonText() will match the '(' and end
+ // the loop.
+ }
+ }
+ if (language_ == Lexer::Language::QUERY) {
+ std::string &text = tokens_[token_index].text;
+ TokenType &type = tokens_[token_index].type;
+ if (text == "AND") {
+ text.clear();
+ type = TokenType::AND;
+ } else if (text == "OR") {
+ text.clear();
+ type = TokenType::OR;
+ } else if (text == "NOT") {
+ text.clear();
+ type = TokenType::NOT;
+ }
+ }
+ return true;
+}
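+
+// Example (illustrative): for the input "fun(bar)", Text() first builds a
+// TEXT token for "fun"; on seeing '(' it retags that token as FUNCTION_NAME,
+// and the ConsumeNonText() call in the loop condition then emits LPAREN. For
+// "fun (bar)" the whitespace ends the token first, so "fun" stays TEXT (see
+// QueryWithFunctionCalls in lexer_test.cc).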
+
+libtextclassifier3::StatusOr<std::vector<Lexer::LexerToken>>
+Lexer::ExtractTokens() {
+ while (current_char_ != '\0') {
+ // Clear out any non-text before matching a Text.
+ while (ConsumeNonText()) {
+ }
+ Text();
+ }
+ if (!error_.empty()) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Syntax Error: ", error_));
+ }
+ return tokens_;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/lexer.h b/icing/query/advanced_query_parser/lexer.h
new file mode 100644
index 0000000..f72affb
--- /dev/null
+++ b/icing/query/advanced_query_parser/lexer.h
@@ -0,0 +1,153 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_
+#define ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_
+
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+
+namespace icing {
+namespace lib {
+
+class Lexer {
+ public:
+ enum class Language { QUERY, SCORING };
+
+ enum class TokenType {
+ COMMA, // ','
+ DOT, // '.'
+ PLUS, // '+' Not allowed in QUERY language.
+ MINUS, // '-'
+ TIMES, // '*' Not allowed in QUERY language.
+ DIV, // '/' Not allowed in QUERY language.
+ LPAREN, // '('
+ RPAREN, // ')'
+ COMPARATOR, // '<=' | '<' | '>=' | '>' | '!=' | '==' | ':'
+ // Not allowed in SCORING language.
+ AND, // 'AND' | '&&' Not allowed in SCORING language.
+ OR, // 'OR' | '||' Not allowed in SCORING language.
+ NOT, // 'NOT' Not allowed in SCORING language.
+ STRING, // String literal surrounded by quotation marks
+ TEXT, // A sequence of chars that are not any above-listed operator
+ FUNCTION_NAME, // A TEXT followed by LPAREN.
+ // Whitespace not inside a string literal will be skipped.
+ // WS: " " | "\t" | "\n" | "\r" | "\f" -> skip ;
+ };
+
+ struct LexerToken {
+ // For STRING, text will contain the raw original text of the token
+ // in between quotation marks, without unescaping.
+ //
+ // For TEXT, text will contain the text of the token after unescaping all
+ // escaped characters.
+ //
+ // For FUNCTION_NAME, this field will contain the name of the function.
+ //
+ // For COMPARATOR, this field will contain the comparator.
+ //
+ // For other types, this field will be empty.
+ std::string text;
+
+ // The type of the token.
+ TokenType type;
+ };
+
+ explicit Lexer(std::string_view query, Language language)
+ : query_(query), language_(language) {
+ Advance();
+ }
+
+ // Get a vector of LexerToken after lexing the query given in the constructor.
+ //
+ // Returns:
+ // A vector of LexerToken on success
+ // INVALID_ARGUMENT on syntax error.
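+ //
+ // Usage sketch (illustrative, mirroring lexer_test.cc):
+ //   Lexer lexer("foo AND bar", Lexer::Language::QUERY);
+ //   auto tokens_or = lexer.ExtractTokens();
+ //   // On success the result holds {TEXT "foo", AND, TEXT "bar"}.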
+ libtextclassifier3::StatusOr<std::vector<LexerToken>> ExtractTokens();
+
+ private:
+ // Advance to current_index_ + n.
+ void Advance(uint32_t n = 1) {
+ if (current_index_ + n >= query_.size()) {
+ current_index_ = query_.size();
+ current_char_ = '\0';
+ } else {
+ current_index_ += n;
+ current_char_ = query_[current_index_];
+ }
+ }
+
+ // Get the character at current_index_ + n.
+ char PeekNext(uint32_t n = 1) {
+ if (current_index_ + n >= query_.size()) {
+ return '\0';
+ } else {
+ return query_[current_index_ + n];
+ }
+ }
+
+ void SyntaxError(std::string error) {
+ current_index_ = query_.size();
+ current_char_ = '\0';
+ error_ = std::move(error);
+ }
+
+ // Try to match a whitespace token and skip it.
+ bool ConsumeWhitespace();
+
+ // Try to match a single-char token other than '<' and '>'.
+ bool ConsumeSingleChar();
+ bool ConsumeQuerySingleChar();
+ bool ConsumeScoringSingleChar();
+ bool ConsumeGeneralSingleChar();
+
+ // Try to match a comparator token other than ':'.
+ bool ConsumeComparator();
+
+ // Try to match '&&' and '||'.
+ // 'AND' and 'OR' are handled in Text() instead, so that 'ANDfoo' and
+ // 'fooOR' are lexed as TEXT rather than as 'AND' or 'OR'.
+ bool ConsumeAndOr();
+
+ // Try to match a string literal.
+ bool ConsumeStringLiteral();
+
+ // Try to match a non-text.
+ bool ConsumeNonText() {
+ return ConsumeWhitespace() || ConsumeSingleChar() ||
+ (language_ == Language::QUERY && ConsumeComparator()) ||
+ (language_ == Language::QUERY && ConsumeAndOr()) ||
+ ConsumeStringLiteral();
+ }
+
+ // Try to match TEXT, FUNCTION_NAME, 'AND', 'OR' and 'NOT'.
+ // Callers should ensure that ConsumeNonText() returns false before
+ // calling this method.
+ bool Text();
+
+ std::string_view query_;
+ std::string error_;
+ Language language_;
+ int32_t current_index_ = -1;
+ char current_char_ = '\0';
+ std::vector<LexerToken> tokens_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_
diff --git a/icing/query/advanced_query_parser/lexer_fuzz_test.cc b/icing/query/advanced_query_parser/lexer_fuzz_test.cc
new file mode 100644
index 0000000..f9190db
--- /dev/null
+++ b/icing/query/advanced_query_parser/lexer_fuzz_test.cc
@@ -0,0 +1,37 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdint>
+#include <memory>
+#include <string_view>
+
+#include "icing/query/advanced_query_parser/lexer.h"
+
+namespace icing {
+namespace lib {
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+ std::string_view text(reinterpret_cast<const char*>(data), size);
+
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>(text, Lexer::Language::QUERY);
+ lexer->ExtractTokens();
+
+ lexer = std::make_unique<Lexer>(text, Lexer::Language::SCORING);
+ lexer->ExtractTokens();
+ return 0;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/lexer_test.cc b/icing/query/advanced_query_parser/lexer_test.cc
new file mode 100644
index 0000000..41e78fe
--- /dev/null
+++ b/icing/query/advanced_query_parser/lexer_test.cc
@@ -0,0 +1,613 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/lexer.h"
+
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+using ::testing::ElementsAre;
+
+MATCHER_P2(EqualsLexerToken, text, type, "") {
+ const Lexer::LexerToken& actual = arg;
+ *result_listener << "actual is {text=" << actual.text
+ << ", type=" << static_cast<int>(actual.type)
+ << "}, but expected was {text=" << text
+ << ", type=" << static_cast<int>(type) << "}.";
+ return actual.text == text && actual.type == type;
+}
+
+MATCHER_P(EqualsLexerToken, type, "") {
+ const Lexer::LexerToken& actual = arg;
+ *result_listener << "actual is {text=" << actual.text
+ << ", type=" << static_cast<int>(actual.type)
+ << "}, but expected was {text=(empty), type="
+ << static_cast<int>(type) << "}.";
+ return actual.text.empty() && actual.type == type;
+}
+
+TEST(LexerTest, SimpleQuery) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("foo", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("fooAND", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("fooAND", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("ORfoo", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("ORfoo", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("fooANDbar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar",
+ Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, PrefixQuery) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("foo*", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo*", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("fooAND*", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("fooAND*", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("*ORfoo", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("*ORfoo", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("fooANDbar*", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar*",
+ Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, SimpleStringQuery) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("\"foo\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::STRING)));
+
+ lexer = std::make_unique<Lexer>("\"fooAND\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooAND",
+ Lexer::TokenType::STRING)));
+
+ lexer = std::make_unique<Lexer>("\"ORfoo\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("ORfoo", Lexer::TokenType::STRING)));
+
+ lexer = std::make_unique<Lexer>("\"fooANDbar\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar",
+ Lexer::TokenType::STRING)));
+}
+
+TEST(LexerTest, TwoTermQuery) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("foo AND bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("foo && bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("foo&&bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("foo OR \"bar\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("bar", Lexer::TokenType::STRING)));
+}
+
+TEST(LexerTest, QueryWithSpecialSymbol) {
+ // With escaping
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("foo\\ \\&\\&bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo &&bar",
+ Lexer::TokenType::TEXT)));
+ lexer = std::make_unique<Lexer>("foo\\&\\&bar&&baz", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo&&bar", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("baz", Lexer::TokenType::TEXT)));
+ lexer = std::make_unique<Lexer>("foo\\\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo\"", Lexer::TokenType::TEXT)));
+
+ // With quotation marks
+ lexer = std::make_unique<Lexer>("\"foo &&bar\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo &&bar",
+ Lexer::TokenType::STRING)));
+ lexer = std::make_unique<Lexer>("\"foo&&bar\"&&baz", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(
+ tokens,
+ ElementsAre(EqualsLexerToken("foo&&bar", Lexer::TokenType::STRING),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("baz", Lexer::TokenType::TEXT)));
+ lexer = std::make_unique<Lexer>("\"foo\\\"\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo\\\"",
+ Lexer::TokenType::STRING)));
+}
+
+TEST(LexerTest, TextInStringShouldBeOriginal) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("\"foo\\nbar\"", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo\\nbar",
+ Lexer::TokenType::STRING)));
+}
+
+TEST(LexerTest, QueryWithFunctionCalls) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("foo AND fun(bar)", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(
+ tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("fun", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN)));
+
+ // Not a function call
+ lexer = std::make_unique<Lexer>("foo AND fun (bar)", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("fun", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN)));
+}
+
+TEST(LexerTest, QueryWithComparator) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("name: foo", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("name", Lexer::TokenType::TEXT),
+ EqualsLexerToken(":", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("foo", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("email.name:foo", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("email", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::DOT),
+ EqualsLexerToken("name", Lexer::TokenType::TEXT),
+ EqualsLexerToken(":", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("foo", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age > 20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken(">", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age>=20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken(">=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age <20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age<= 20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age == 20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken("==", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age != 20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken("!=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("20", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, ComplexQuery) {
+ std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>(
+ "email.sender: (foo* AND bar OR pow(age, 2)>100) || (-baz foo) && "
+ "NOT verbatimSearch(\"hello world\")",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(
+ tokens,
+ ElementsAre(
+ EqualsLexerToken("email", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::DOT),
+ EqualsLexerToken("sender", Lexer::TokenType::TEXT),
+ EqualsLexerToken(":", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("foo*", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("pow", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::COMMA),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(">", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("100", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken(Lexer::TokenType::MINUS),
+ EqualsLexerToken("baz", Lexer::TokenType::TEXT),
+ EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken(Lexer::TokenType::NOT),
+ EqualsLexerToken("verbatimSearch", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("hello world", Lexer::TokenType::STRING),
+ EqualsLexerToken(Lexer::TokenType::RPAREN)));
+}
+
+TEST(LexerTest, UTF8WhiteSpace) {
+ std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>(
+ "\xe2\x80\x88"
+ "foo"
+ "\xe2\x80\x89"
+ "\xe2\x80\x89"
+ "bar"
+ "\xe2\x80\x8a",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, CJKT) {
+ std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>(
+ "我 && 每天 || 走路 OR 去 -上班", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("我", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("每天", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("走路", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("去", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::MINUS),
+ EqualsLexerToken("上班", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("私&& は ||毎日 AND 仕事 -に 歩い て い ます",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("私", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("は", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("毎日", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("仕事", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::MINUS),
+ EqualsLexerToken("に", Lexer::TokenType::TEXT),
+ EqualsLexerToken("歩い", Lexer::TokenType::TEXT),
+ EqualsLexerToken("て", Lexer::TokenType::TEXT),
+ EqualsLexerToken("い", Lexer::TokenType::TEXT),
+ EqualsLexerToken("ます", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("ញុំ&&ដើរទៅ||ធ្វើការ-រាល់ថ្ងៃ",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("ញុំ", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::AND),
+ EqualsLexerToken("ដើរទៅ", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::OR),
+ EqualsLexerToken("ធ្វើការ", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::MINUS),
+ EqualsLexerToken("រាល់ថ្ងៃ", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>(
+ "나는"
+ "\xe2\x80\x88" // White Space
+ "매일"
+ "\xe2\x80\x89" // White Space
+ "출근합니다",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(
+ tokens,
+ ElementsAre(EqualsLexerToken("나는", Lexer::TokenType::TEXT),
+ EqualsLexerToken("매일", Lexer::TokenType::TEXT),
+ EqualsLexerToken("출근합니다", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, SyntaxError) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("\"foo", Lexer::Language::QUERY);
+ EXPECT_THAT(lexer->ExtractTokens(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ lexer = std::make_unique<Lexer>("\"foo\\", Lexer::Language::QUERY);
+ EXPECT_THAT(lexer->ExtractTokens(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ lexer = std::make_unique<Lexer>("foo\\", Lexer::Language::QUERY);
+ EXPECT_THAT(lexer->ExtractTokens(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+// "!", "=", "&" and "|" should be treated as valid symbols in TEXT, if not
+// matched as "!=", "==", "&&", or "||".
+TEST(LexerTest, SpecialSymbolAsText) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("age=20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age=20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("age !20", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT),
+ EqualsLexerToken("!20", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("foo& bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo&", Lexer::TokenType::TEXT),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("foo | bar", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT),
+ EqualsLexerToken("|", Lexer::TokenType::TEXT),
+ EqualsLexerToken("bar", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, ScoringArithmetic) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("1 + 2", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::PLUS),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1+2*3/4", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::PLUS),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::TIMES),
+ EqualsLexerToken("3", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::DIV),
+ EqualsLexerToken("4", Lexer::TokenType::TEXT)));
+
+ // Arithmetic operators will not be produced in query language.
+ lexer = std::make_unique<Lexer>("1 + 2", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken("+", Lexer::TokenType::TEXT),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1+2*3/4", Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1+2*3/4", Lexer::TokenType::TEXT)));
+}
+
+// Currently, in scoring language, the lexer will view these logic operators as
+// TEXTs. In the future, they may be rejected instead.
+TEST(LexerTest, LogicOperatorNotInScoring) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("1 && 2", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken("&&", Lexer::TokenType::TEXT),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1&&2", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1&&2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1&&2 ||3", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1&&2", Lexer::TokenType::TEXT),
+ EqualsLexerToken("||3", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1 AND 2 OR 3 AND NOT 4",
+ Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken("AND", Lexer::TokenType::TEXT),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken("OR", Lexer::TokenType::TEXT),
+ EqualsLexerToken("3", Lexer::TokenType::TEXT),
+ EqualsLexerToken("AND", Lexer::TokenType::TEXT),
+ EqualsLexerToken("NOT", Lexer::TokenType::TEXT),
+ EqualsLexerToken("4", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, ComparatorNotInScoring) {
+ std::unique_ptr<Lexer> lexer =
+ std::make_unique<Lexer>("1 > 2", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken(">", Lexer::TokenType::TEXT),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1>2", Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1>2", Lexer::TokenType::TEXT)));
+
+ lexer = std::make_unique<Lexer>("1>2>=3 <= 4:5== 6<7<=8!= 9",
+ Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1>2>=3", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<=", Lexer::TokenType::TEXT),
+ EqualsLexerToken("4:5==", Lexer::TokenType::TEXT),
+ EqualsLexerToken("6<7<=8!=", Lexer::TokenType::TEXT),
+ EqualsLexerToken("9", Lexer::TokenType::TEXT)));
+
+ // Comparator should be produced in query language.
+ lexer = std::make_unique<Lexer>("1>2>=3 <= 4:5== 6<7<=8!= 9",
+ Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens());
+ EXPECT_THAT(tokens,
+ ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken(">", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken(">=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("3", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("4", Lexer::TokenType::TEXT),
+ EqualsLexerToken(":", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("5", Lexer::TokenType::TEXT),
+ EqualsLexerToken("==", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("6", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("7", Lexer::TokenType::TEXT),
+ EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("8", Lexer::TokenType::TEXT),
+ EqualsLexerToken("!=", Lexer::TokenType::COMPARATOR),
+ EqualsLexerToken("9", Lexer::TokenType::TEXT)));
+}
+
+TEST(LexerTest, ComplexScoring) {
+ std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>(
+ "1/log( (CreationTimestamp(document) + LastUsedTimestamp(document)) / 2 "
+ ") * pow(2.3, DocumentScore())",
+ Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ lexer->ExtractTokens());
+ EXPECT_THAT(
+ tokens,
+ ElementsAre(
+ EqualsLexerToken("1", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::DIV),
+ EqualsLexerToken("log", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("CreationTimestamp",
+ Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("document", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::PLUS),
+ EqualsLexerToken("LastUsedTimestamp",
+ Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("document", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::DIV),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::TIMES),
+ EqualsLexerToken("pow", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken("2", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::DOT),
+ EqualsLexerToken("3", Lexer::TokenType::TEXT),
+ EqualsLexerToken(Lexer::TokenType::COMMA),
+ EqualsLexerToken("DocumentScore", Lexer::TokenType::FUNCTION_NAME),
+ EqualsLexerToken(Lexer::TokenType::LPAREN),
+ EqualsLexerToken(Lexer::TokenType::RPAREN),
+ EqualsLexerToken(Lexer::TokenType::RPAREN)));
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/parser.cc b/icing/query/advanced_query_parser/parser.cc
new file mode 100644
index 0000000..086f038
--- /dev/null
+++ b/icing/query/advanced_query_parser/parser.cc
@@ -0,0 +1,414 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/parser.h"
+
+#include <memory>
+#include <string_view>
+
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+std::unique_ptr<Node> CreateNaryNode(
+ std::string_view operator_text,
+ std::vector<std::unique_ptr<Node>>&& operands) {
+ if (operands.empty()) {
+ return nullptr;
+ }
+ if (operands.size() == 1) {
+ return std::move(operands.at(0));
+ }
+ return std::make_unique<NaryOperatorNode>(std::string(operator_text),
+ std::move(operands));
+}
+
+} // namespace
+
+libtextclassifier3::Status Parser::Consume(Lexer::TokenType token_type) {
+ if (!Match(token_type)) {
+ return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
+ "Unable to consume token %d.", static_cast<int>(token_type)));
+ }
+ ++current_token_;
+ return libtextclassifier3::Status::OK;
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<TextNode>> Parser::ConsumeText() {
+ if (!Match(Lexer::TokenType::TEXT)) {
+ return absl_ports::InvalidArgumentError("Unable to consume token as TEXT.");
+ }
+ auto text_node = std::make_unique<TextNode>(std::move(current_token_->text));
+ ++current_token_;
+ return text_node;
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<FunctionNameNode>>
+Parser::ConsumeFunctionName() {
+ if (!Match(Lexer::TokenType::FUNCTION_NAME)) {
+ return absl_ports::InvalidArgumentError(
+ "Unable to consume token as FUNCTION_NAME.");
+ }
+ auto function_name_node =
+ std::make_unique<FunctionNameNode>(std::move(current_token_->text));
+ ++current_token_;
+ return function_name_node;
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<StringNode>>
+Parser::ConsumeString() {
+ if (!Match(Lexer::TokenType::STRING)) {
+ return absl_ports::InvalidArgumentError(
+ "Unable to consume token as STRING.");
+ }
+ auto node = std::make_unique<StringNode>(std::move(current_token_->text));
+ ++current_token_;
+ return node;
+}
+
+libtextclassifier3::StatusOr<std::string> Parser::ConsumeComparator() {
+ if (!Match(Lexer::TokenType::COMPARATOR)) {
+ return absl_ports::InvalidArgumentError(
+ "Unable to consume token as COMPARATOR.");
+ }
+ std::string comparator = std::move(current_token_->text);
+ ++current_token_;
+ return comparator;
+}
+
+// member
+// : TEXT (DOT TEXT)* (DOT function)?
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<MemberNode>>
+Parser::ConsumeMember() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<TextNode> text_node, ConsumeText());
+ std::vector<std::unique_ptr<TextNode>> children;
+ children.push_back(std::move(text_node));
+
+ while (Match(Lexer::TokenType::DOT)) {
+ Consume(Lexer::TokenType::DOT);
+ if (MatchFunction()) {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<FunctionNode> function_node,
+ ConsumeFunction());
+ // Once a function is matched, we should exit the current rule based on
+ // the grammar.
+ return std::make_unique<MemberNode>(std::move(children),
+ std::move(function_node));
+ }
+ ICING_ASSIGN_OR_RETURN(text_node, ConsumeText());
+ children.push_back(std::move(text_node));
+ }
+ return std::make_unique<MemberNode>(std::move(children),
+ /*function=*/nullptr);
+}
+
+// function
+// : FUNCTION_NAME LPAREN argList? RPAREN
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<FunctionNode>>
+Parser::ConsumeFunction() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<FunctionNameNode> function_name,
+ ConsumeFunctionName());
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::LPAREN));
+
+ std::vector<std::unique_ptr<Node>> args;
+ if (Match(Lexer::TokenType::RPAREN)) {
+ // Got an empty argument list.
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN));
+ } else {
+ ICING_ASSIGN_OR_RETURN(args, ConsumeArgs());
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN));
+ }
+ return std::make_unique<FunctionNode>(std::move(function_name),
+ std::move(args));
+}
+
+// comparable
+// : STRING
+// | member
+// | function
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+Parser::ConsumeComparable() {
+ if (Match(Lexer::TokenType::STRING)) {
+ return ConsumeString();
+ } else if (MatchMember()) {
+ return ConsumeMember();
+ }
+ // The current token sequence isn't a STRING or member. Therefore, it must be
+ // a function.
+ return ConsumeFunction();
+}
+
+// composite
+// : LPAREN expression RPAREN
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeComposite() {
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::LPAREN));
+
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> expression, ConsumeExpression());
+
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN));
+ return expression;
+}
+
+// argList
+// : expression (COMMA expression)*
+// ;
+libtextclassifier3::StatusOr<std::vector<std::unique_ptr<Node>>>
+Parser::ConsumeArgs() {
+ std::vector<std::unique_ptr<Node>> args;
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> arg, ConsumeExpression());
+ args.push_back(std::move(arg));
+ while (Match(Lexer::TokenType::COMMA)) {
+ Consume(Lexer::TokenType::COMMA);
+ ICING_ASSIGN_OR_RETURN(arg, ConsumeExpression());
+ args.push_back(std::move(arg));
+ }
+ return args;
+}
+
+// restriction
+// : comparable (COMPARATOR (comparable | composite))?
+// ;
+// COMPARATOR will not be produced in Scoring Lexer.
+libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+Parser::ConsumeRestriction() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> comparable, ConsumeComparable());
+
+ if (!Match(Lexer::TokenType::COMPARATOR)) {
+ return comparable;
+ }
+ ICING_ASSIGN_OR_RETURN(std::string operator_text, ConsumeComparator());
+ std::unique_ptr<Node> arg;
+ if (MatchComposite()) {
+ ICING_ASSIGN_OR_RETURN(arg, ConsumeComposite());
+ } else if (MatchComparable()) {
+ ICING_ASSIGN_OR_RETURN(arg, ConsumeComparable());
+ } else {
+ return absl_ports::InvalidArgumentError(
+ "ARG: must begin with LPAREN or FIRST(comparable)");
+ }
+ std::vector<std::unique_ptr<Node>> args;
+ args.push_back(std::move(comparable));
+ args.push_back(std::move(arg));
+ return std::make_unique<NaryOperatorNode>(std::move(operator_text),
+ std::move(args));
+}
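+
+// Example (illustrative): "sender.name:foo" parses to
+// NaryOperatorNode(":", {member(sender, name), member(foo)}); when no
+// COMPARATOR follows, the bare comparable is returned unchanged.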
+
+// simple
+// : restriction
+// | composite
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeSimple() {
+ if (MatchComposite()) {
+ return ConsumeComposite();
+ } else if (MatchRestriction()) {
+ return ConsumeRestriction();
+ }
+ return absl_ports::InvalidArgumentError(
+ "SIMPLE: must be a restriction or composite");
+}
+
+// term
+// : NOT? simple
+// | MINUS simple
+// ;
+// NOT will not be produced in Scoring Lexer.
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeTerm() {
+ if (!Match(Lexer::TokenType::NOT) && !Match(Lexer::TokenType::MINUS)) {
+ return ConsumeSimple();
+ }
+ std::string operator_text;
+ if (language_ == Lexer::Language::SCORING) {
+ ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::MINUS));
+ operator_text = "MINUS";
+ } else {
+ if (Match(Lexer::TokenType::NOT)) {
+ Consume(Lexer::TokenType::NOT);
+ } else {
+ Consume(Lexer::TokenType::MINUS);
+ }
+ operator_text = "NOT";
+ }
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> simple, ConsumeSimple());
+ return std::make_unique<UnaryOperatorNode>(operator_text, std::move(simple));
+}
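+
+// Note (illustrative): in the QUERY language both "NOT foo" and "-foo"
+// produce UnaryOperatorNode("NOT", ...), while in the SCORING language only
+// MINUS is accepted here and maps to UnaryOperatorNode("MINUS", ...).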
+
+// factor
+// : term (OR term)*
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeFactor() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> term, ConsumeTerm());
+ std::vector<std::unique_ptr<Node>> terms;
+ terms.push_back(std::move(term));
+
+ while (Match(Lexer::TokenType::OR)) {
+ Consume(Lexer::TokenType::OR);
+ ICING_ASSIGN_OR_RETURN(term, ConsumeTerm());
+ terms.push_back(std::move(term));
+ }
+
+ return CreateNaryNode("OR", std::move(terms));
+}
+
+// sequence
+// : (factor)+
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeSequence() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> factor, ConsumeFactor());
+ std::vector<std::unique_ptr<Node>> factors;
+ factors.push_back(std::move(factor));
+
+ while (MatchFactor()) {
+ ICING_ASSIGN_OR_RETURN(factor, ConsumeFactor());
+ factors.push_back(std::move(factor));
+ }
+
+ return CreateNaryNode("AND", std::move(factors));
+}
+
+// expression
+// : sequence (AND sequence)*
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+Parser::ConsumeQueryExpression() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> sequence, ConsumeSequence());
+ std::vector<std::unique_ptr<Node>> sequences;
+ sequences.push_back(std::move(sequence));
+
+ while (Match(Lexer::TokenType::AND)) {
+ Consume(Lexer::TokenType::AND);
+ ICING_ASSIGN_OR_RETURN(sequence, ConsumeSequence());
+ sequences.push_back(std::move(sequence));
+ }
+
+ return CreateNaryNode("AND", std::move(sequences));
+}
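+
+// Note (illustrative): the implicit AND in sequence ("foo bar") and the
+// explicit AND here ("foo AND bar") both produce an n-ary "AND" node, so the
+// two queries yield the same AST shape (see ImplicitAnd in
+// parser_integration_test.cc).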
+
+// multExpr
+// : term ((TIMES | DIV) term)*
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeMultExpr() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> node, ConsumeTerm());
+ std::vector<std::unique_ptr<Node>> stack;
+ stack.push_back(std::move(node));
+
+ while (Match(Lexer::TokenType::TIMES) || Match(Lexer::TokenType::DIV)) {
+ while (Match(Lexer::TokenType::TIMES)) {
+ Consume(Lexer::TokenType::TIMES);
+ ICING_ASSIGN_OR_RETURN(node, ConsumeTerm());
+ stack.push_back(std::move(node));
+ }
+ node = CreateNaryNode("TIMES", std::move(stack));
+ stack.clear();
+ stack.push_back(std::move(node));
+
+ while (Match(Lexer::TokenType::DIV)) {
+ Consume(Lexer::TokenType::DIV);
+ ICING_ASSIGN_OR_RETURN(node, ConsumeTerm());
+ stack.push_back(std::move(node));
+ }
+ node = CreateNaryNode("DIV", std::move(stack));
+ stack.clear();
+ stack.push_back(std::move(node));
+ }
+
+ return std::move(stack[0]);
+}
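+
+// Example (illustrative): runs of the same operator are grouped into a single
+// n-ary node, so "1*2*3/4" parses as DIV(TIMES(1, 2, 3), 4) rather than a
+// chain of binary nodes:
+//   stack = [1]                      // after the first ConsumeTerm()
+//   stack = [1, 2, 3]                // inner TIMES loop
+//   stack = [TIMES(1, 2, 3)]         // CreateNaryNode("TIMES", ...)
+//   stack = [TIMES(1, 2, 3), 4]      // inner DIV loop
+//   stack = [DIV(TIMES(1, 2, 3), 4)] // CreateNaryNode("DIV", ...)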
+
+// expression
+// : multExpr ((PLUS | MINUS) multExpr)*
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+Parser::ConsumeScoringExpression() {
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> node, ConsumeMultExpr());
+ std::vector<std::unique_ptr<Node>> stack;
+ stack.push_back(std::move(node));
+
+ while (Match(Lexer::TokenType::PLUS) || Match(Lexer::TokenType::MINUS)) {
+ while (Match(Lexer::TokenType::PLUS)) {
+ Consume(Lexer::TokenType::PLUS);
+ ICING_ASSIGN_OR_RETURN(node, ConsumeMultExpr());
+ stack.push_back(std::move(node));
+ }
+ node = CreateNaryNode("PLUS", std::move(stack));
+ stack.clear();
+ stack.push_back(std::move(node));
+
+ while (Match(Lexer::TokenType::MINUS)) {
+ Consume(Lexer::TokenType::MINUS);
+ ICING_ASSIGN_OR_RETURN(node, ConsumeMultExpr());
+ stack.push_back(std::move(node));
+ }
+ node = CreateNaryNode("MINUS", std::move(stack));
+ stack.clear();
+ stack.push_back(std::move(node));
+ }
+
+ return std::move(stack[0]);
+}
+
+libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+Parser::ConsumeExpression() {
+ switch (language_) {
+ case Lexer::Language::QUERY:
+ return ConsumeQueryExpression();
+ case Lexer::Language::SCORING:
+ return ConsumeScoringExpression();
+ }
+}
+
+// query
+// : expression? EOF
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeQuery() {
+ language_ = Lexer::Language::QUERY;
+ std::unique_ptr<Node> node;
+ if (current_token_ != lexer_tokens_.end()) {
+ ICING_ASSIGN_OR_RETURN(node, ConsumeExpression());
+ }
+ if (current_token_ != lexer_tokens_.end()) {
+ return absl_ports::InvalidArgumentError(
+ "Error parsing Query. Must reach EOF after parsing Expression!");
+ }
+ return node;
+}
+
+// scoring
+// : expression EOF
+// ;
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeScoring() {
+ language_ = Lexer::Language::SCORING;
+ std::unique_ptr<Node> node;
+ if (current_token_ == lexer_tokens_.end()) {
+ return absl_ports::InvalidArgumentError("Got empty scoring expression!");
+ }
+ ICING_ASSIGN_OR_RETURN(node, ConsumeExpression());
+ if (current_token_ != lexer_tokens_.end()) {
+ return absl_ports::InvalidArgumentError(
+ "Error parsing the scoring expression. Must reach EOF after parsing "
+ "Expression!");
+ }
+ return node;
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/parser.h b/icing/query/advanced_query_parser/parser.h
new file mode 100644
index 0000000..330b8b9
--- /dev/null
+++ b/icing/query/advanced_query_parser/parser.h
@@ -0,0 +1,140 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_
+#define ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_
+
+#include <memory>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/lexer.h"
+
+namespace icing {
+namespace lib {
+
+class Parser {
+ public:
+ static Parser Create(std::vector<Lexer::LexerToken>&& lexer_tokens) {
+ return Parser(std::move(lexer_tokens));
+ }
+
+ // Returns:
+ // On success, pointer to the root node of the AST
+ // INVALID_ARGUMENT for input that does not conform to the grammar
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeQuery();
+
+ // Returns:
+ // On success, pointer to the root node of the AST
+ // INVALID_ARGUMENT for input that does not conform to the grammar
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeScoring();
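+
+ // Usage sketch (illustrative, mirroring parser_integration_test.cc):
+ //   Lexer lexer("foo OR bar", Lexer::Language::QUERY);
+ //   ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens,
+ //                              lexer.ExtractTokens());
+ //   Parser parser = Parser::Create(std::move(tokens));
+ //   ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root,
+ //                              parser.ConsumeQuery());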
+
+ private:
+ explicit Parser(std::vector<Lexer::LexerToken>&& lexer_tokens)
+ : lexer_tokens_(std::move(lexer_tokens)),
+ current_token_(lexer_tokens_.begin()) {}
+
+ // Match Functions
+ // These functions are used to test whether current_token_ matches a member
+ // of the FIRST set of a particular symbol in our grammar.
+ bool Match(Lexer::TokenType token_type) const {
+ return current_token_ != lexer_tokens_.end() &&
+ current_token_->type == token_type;
+ }
+
+ bool MatchMember() const { return Match(Lexer::TokenType::TEXT); }
+
+ bool MatchFunction() const { return Match(Lexer::TokenType::FUNCTION_NAME); }
+
+ bool MatchComparable() const {
+ return Match(Lexer::TokenType::STRING) || MatchMember() || MatchFunction();
+ }
+
+ bool MatchComposite() const { return Match(Lexer::TokenType::LPAREN); }
+
+ bool MatchRestriction() const { return MatchComparable(); }
+
+ bool MatchSimple() const { return MatchRestriction() || MatchComposite(); }
+
+ bool MatchTerm() const {
+ return MatchSimple() || Match(Lexer::TokenType::NOT) ||
+ Match(Lexer::TokenType::MINUS);
+ }
+
+ bool MatchFactor() const { return MatchTerm(); }
+
+ // Consume Functions
+ // These functions attempt to parse the token sequence starting at
+ // current_token_.
+ // Returns INVALID_ARGUMENT if unable to parse the token sequence starting at
+ // current_token_ as that particular grammar symbol. There are no guarantees
+ // about what state current_token_ and lexer_tokens_ are in when returning an
+ // error.
+ //
+ // Consume functions for terminal symbols. These are the only Consume
+ // functions that will directly modify current_token_.
+ // The Consume functions for terminals will guarantee not to modify
+ // current_token_ and lexer_tokens_ when returning an error.
+ libtextclassifier3::Status Consume(Lexer::TokenType token_type);
+
+ libtextclassifier3::StatusOr<std::unique_ptr<TextNode>> ConsumeText();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<FunctionNameNode>>
+ ConsumeFunctionName();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<StringNode>> ConsumeString();
+
+ libtextclassifier3::StatusOr<std::string> ConsumeComparator();
+
+ // Consume functions for non-terminal symbols.
+ libtextclassifier3::StatusOr<std::unique_ptr<MemberNode>> ConsumeMember();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<FunctionNode>> ConsumeFunction();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeComparable();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeComposite();
+
+ libtextclassifier3::StatusOr<std::vector<std::unique_ptr<Node>>>
+ ConsumeArgs();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeRestriction();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeSimple();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeTerm();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeFactor();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeSequence();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeQueryExpression();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeMultExpr();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>>
+ ConsumeScoringExpression();
+
+ libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeExpression();
+
+ std::vector<Lexer::LexerToken> lexer_tokens_;
+ std::vector<Lexer::LexerToken>::const_iterator current_token_;
+ Lexer::Language language_ = Lexer::Language::QUERY;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_
diff --git a/icing/query/advanced_query_parser/parser_integration_test.cc b/icing/query/advanced_query_parser/parser_integration_test.cc
new file mode 100644
index 0000000..75be15b
--- /dev/null
+++ b/icing/query/advanced_query_parser/parser_integration_test.cc
@@ -0,0 +1,945 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/lexer.h"
+#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::IsNull;
+
+TEST(ParserIntegrationTest, EmptyQuery) {
+ std::string query = "";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+ EXPECT_THAT(tree_root, IsNull());
+}
+
+TEST(ParserIntegrationTest, EmptyScoring) {
+ std::string query = "";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeScoring(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
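// Taken together, EmptyQuery and EmptyScoring pin down an asymmetry callers
// must handle: an empty query parses successfully to a null root, while an
// empty scoring expression is rejected. A hedged caller-side sketch (the
// empty-query policy below is illustrative, not part of this test file):
//
//   ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> root, parser.ConsumeQuery());
//   if (root == nullptr) {
//     // Empty query: there is no AST to evaluate; apply the caller's
//     // empty-query policy here.
//   }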
+
+TEST(ParserIntegrationTest, SingleTerm) {
+ std::string query = "foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserIntegrationTest, ImplicitAnd) {
+ std::string query = "foo bar";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Or) {
+ std::string query = "foo OR bar";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, And) {
+ std::string query = "foo AND bar";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Not) {
+ std::string query = "NOT foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // NOT
+ // |
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, NOT }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("NOT", NodeType::kUnaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Minus) {
+ std::string query = "-foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // NOT
+ // |
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, NOT }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("NOT", NodeType::kUnaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Has) {
+ std::string query = "subject:foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // :
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+  // { text, member, text, member, : }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("subject", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, HasNested) {
+ std::string query = "sender.name:foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // :
+ // / \
+ // member member
+ // / \ |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+  // { text, text, member, text, member, : }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("sender", NodeType::kText),
+ EqualsNodeInfo("name", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, EmptyFunction) {
+ std::string query = "foo()";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserIntegrationTest, FunctionSingleArg) {
+ std::string query = "foo(\"bar\")";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / \
+ // function_name string
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, string, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserIntegrationTest, FunctionMultiArg) {
+ std::string query = "foo(\"bar\", \"baz\")";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / | \
+ // function_name string string
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, string, string, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kString),
+ EqualsNodeInfo("baz", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserIntegrationTest, FunctionNested) {
+ std::string query = "foo(bar())";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / \
+ // function_name function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function_name, function, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserIntegrationTest, FunctionWithTrailingSequence) {
+ std::string query = "foo() OR bar";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // function member
+ // | |
+ // function_name text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Composite) {
+ std::string query = "foo OR (bar baz)";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // member AND
+ // | / \
+ // text member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, text, member, AND, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, CompositeWithTrailingSequence) {
+ std::string query = "(bar baz) OR foo";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // AND member
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, Complex) {
+ std::string query = "foo bar:baz OR pal(\"bat\")";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member OR
+ // | / \
+ // text : function
+ // / \ / \
+ // member member function_name string
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, text, member, :, function_name, string,
+ // function, OR, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator),
+ EqualsNodeInfo("pal", NodeType::kFunctionName),
+ EqualsNodeInfo("bat", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, InvalidHas) {
+ std::string query = "foo:"; // No right hand operand to :
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidComposite) {
+ std::string query = "(foo bar"; // No terminating RPAREN
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidMember) {
+ std::string query = "foo."; // DOT must have succeeding TEXT
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidOr) {
+ std::string query = "foo OR"; // No right hand operand to OR
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidAnd) {
+ std::string query = "foo AND"; // No right hand operand to AND
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidNot) {
+ std::string query = "NOT"; // No right hand operand to NOT
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidMinus) {
+ std::string query = "-"; // No right hand operand to -
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidFunctionCallNoRparen) {
+ std::string query = "foo("; // No terminating RPAREN
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, InvalidFunctionArgsHangingComma) {
+ std::string query = "foo(\"bar\",)"; // no valid arg following COMMA
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserIntegrationTest, ScoringPlus) {
+ std::string scoring = "1 + 1 + 1";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, ScoringMinus) {
+ std::string scoring = "1 - 1 - 1";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // MINUS
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, ScoringUnaryMinus) {
+ std::string scoring = "1 + -1 + 1";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / | \
+ // member MINUS member
+ // | | |
+ // text member text
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kUnaryOperator),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, ScoringPlusMinus) {
+ std::string scoring = "11 + 12 - 13 + 14";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / \
+ // MINUS member
+ // / \ |
+ // PLUS member text
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("11", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("12", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator),
+ EqualsNodeInfo("13", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kNaryOperator),
+ EqualsNodeInfo("14", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
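// Note on the shape asserted above: PLUS and MINUS chains are
// left-associative, so the mixed sequence groups as ((11 + 12) - 13) + 14,
// while a run of the same operator (as in ScoringPlus) collapses into a
// single n-ary node.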
+
+TEST(ParserIntegrationTest, ScoringTimes) {
+ std::string scoring = "1 * 1 * 1";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // TIMES
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, ScoringDiv) {
+ std::string scoring = "1 / 1 / 1";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // DIV
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator)));
+}
+
+TEST(ParserIntegrationTest, ScoringTimesDiv) {
+ std::string scoring = "11 / 12 * 13 / 14 / 15";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // DIV
+ // / | \
+ // TIMES member member
+ // / \ | |
+ // DIV member text text
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("11", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("12", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator),
+ EqualsNodeInfo("13", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator),
+ EqualsNodeInfo("14", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("15", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator)));
+}
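// The same left-associative grouping applies to TIMES and DIV:
// "11 / 12 * 13 / 14 / 15" groups as (((11 / 12) * 13) / 14) / 15, with the
// trailing run of DIVs collapsed into one n-ary DIV node whose children are
// the TIMES subtree, 14, and 15, exactly the shape asserted above.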
+
+TEST(ParserIntegrationTest, ComplexScoring) {
+ // With parentheses in function arguments.
+ std::string scoring = "1 + pow((2 * sin(3)), 4) + -5 / 6";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ std::vector<NodeInfo> node = visitor.nodes();
+ EXPECT_THAT(node,
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("pow", NodeType::kFunctionName),
+ EqualsNodeInfo("2", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("sin", NodeType::kFunctionName),
+ EqualsNodeInfo("3", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator),
+ EqualsNodeInfo("4", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("5", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kUnaryOperator),
+ EqualsNodeInfo("6", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+
+ // Without parentheses in function arguments.
+ scoring = "1 + pow(2 * sin(3), 4) + -5 / 6";
+ lexer = Lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(lexer_tokens, lexer.ExtractTokens());
+ parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(tree_root, parser.ConsumeScoring());
+ visitor = SimpleVisitor();
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(), ElementsAreArray(node));
+}
+
+TEST(ParserIntegrationTest, ScoringMemberFunction) {
+ std::string scoring = "this.CreationTimestamp()";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // member
+ // / \
+ // text function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(
+ visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("CreationTimestamp", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserIntegrationTest, QueryMemberFunction) {
+ std::string query = "this.foo()";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // / \
+ // text function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserIntegrationTest, ScoringComplexMemberFunction) {
+ std::string scoring = "a.b.fun(c, d)";
+ Lexer lexer(scoring, Lexer::Language::SCORING);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // member
+ // / | \
+ // text text function
+ // / | \
+ // function_name member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("a", NodeType::kText),
+ EqualsNodeInfo("b", NodeType::kText),
+ EqualsNodeInfo("fun", NodeType::kFunctionName),
+ EqualsNodeInfo("c", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("d", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserIntegrationTest, QueryComplexMemberFunction) {
+ std::string query = "this.abc.fun(def, ghi)";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // / | \
+ // text text function
+ // / | \
+ // function_name member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("abc", NodeType::kText),
+ EqualsNodeInfo("fun", NodeType::kFunctionName),
+ EqualsNodeInfo("def", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("ghi", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
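Every test above repeats the same lex-then-parse preamble. A hypothetical helper along these lines (not part of this change; it assumes icing's ICING_ASSIGN_OR_RETURN macro) would condense it, and isolates the one step that the unit tests in parser_test.cc below replace by hand-building the token vector:

libtextclassifier3::StatusOr<std::unique_ptr<Node>> ParseQuery(
    std::string query) {
  Lexer lexer(query, Lexer::Language::QUERY);
  ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens,
                         lexer.ExtractTokens());
  Parser parser = Parser::Create(std::move(lexer_tokens));
  return parser.ConsumeQuery();
}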
diff --git a/icing/query/advanced_query_parser/parser_test.cc b/icing/query/advanced_query_parser/parser_test.cc
new file mode 100644
index 0000000..f997329
--- /dev/null
+++ b/icing/query/advanced_query_parser/parser_test.cc
@@ -0,0 +1,1043 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/parser.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/lexer.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+using ::testing::IsNull;
+
+TEST(ParserTest, EmptyQuery) {
+ std::vector<Lexer::LexerToken> lexer_tokens;
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+ EXPECT_THAT(tree_root, IsNull());
+}
+
+TEST(ParserTest, EmptyScoring) {
+ std::vector<Lexer::LexerToken> lexer_tokens;
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeScoring(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, SingleTerm) {
+ // Query: "foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserTest, ImplicitAnd) {
+ // Query: "foo bar"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {"bar", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, Or) {
+ // Query: "foo OR bar"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::OR},
+ {"bar", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, And) {
+ // Query: "foo AND bar"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::AND},
+ {"bar", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, Not) {
+ // Query: "NOT foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"", Lexer::TokenType::NOT}, {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // NOT
+ // |
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, NOT }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("NOT", NodeType::kUnaryOperator)));
+}
+
+TEST(ParserTest, Minus) {
+ // Query: "-foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"", Lexer::TokenType::MINUS}, {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // NOT
+ // |
+ // member
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, NOT }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("NOT", NodeType::kUnaryOperator)));
+}
+
+TEST(ParserTest, Has) {
+ // Query: "subject:foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"subject", Lexer::TokenType::TEXT},
+ {":", Lexer::TokenType::COMPARATOR},
+ {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // :
+ // / \
+ // member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+  // { text, member, text, member, : }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("subject", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, HasNested) {
+ // Query: "sender.name:foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"sender", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DOT},
+ {"name", Lexer::TokenType::TEXT},
+ {":", Lexer::TokenType::COMPARATOR},
+ {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // :
+ // / \
+ // member member
+ // / \ |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+  // { text, text, member, text, member, : }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("sender", NodeType::kText),
+ EqualsNodeInfo("name", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, EmptyFunction) {
+ // Query: "foo()"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserTest, FunctionSingleArg) {
+ // Query: "foo("bar")"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"bar", Lexer::TokenType::STRING},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / \
+ // function_name string
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, string, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserTest, FunctionMultiArg) {
+ // Query: "foo("bar", "baz")"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN},
+ {"bar", Lexer::TokenType::STRING}, {"", Lexer::TokenType::COMMA},
+ {"baz", Lexer::TokenType::STRING}, {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / | \
+ // function_name string string
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, string, string, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kString),
+ EqualsNodeInfo("baz", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserTest, FunctionNested) {
+ // Query: "foo(bar())"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN},
+ {"bar", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::RPAREN}, {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // function
+ // / \
+ // function_name function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function_name, function, function }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("bar", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kFunction)));
+}
+
+TEST(ParserTest, FunctionWithTrailingSequence) {
+ // Query: "foo() OR bar"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::OR},
+ {"bar", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // function member
+ // | |
+ // function_name text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { function_name, function, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, Composite) {
+ // Query: "foo OR (bar baz)"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::OR},
+ {"", Lexer::TokenType::LPAREN}, {"bar", Lexer::TokenType::TEXT},
+ {"baz", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // member AND
+ // | / \
+ // text member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, text, member, AND, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, CompositeWithTrailingSequence) {
+ // Query: "(bar baz) OR foo"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"", Lexer::TokenType::LPAREN}, {"bar", Lexer::TokenType::TEXT},
+ {"baz", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::OR}, {"foo", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // OR
+ // / \
+ // AND member
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, AND, text, member, OR }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator),
+ EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, Complex) {
+ // Query: "foo bar:baz OR pal("bat")"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT},
+ {"bar", Lexer::TokenType::TEXT},
+ {":", Lexer::TokenType::COMPARATOR},
+ {"baz", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::OR},
+ {"pal", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"bat", Lexer::TokenType::STRING},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // AND
+ // / \
+ // member OR
+ // | / \
+ // text : function
+ // / \ / \
+ // member member function_name string
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ // SimpleVisitor ordering
+ // { text, member, text, member, text, member, :, function_name, string,
+ // function, OR, AND }
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("foo", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("bar", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("baz", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo(":", NodeType::kNaryOperator),
+ EqualsNodeInfo("pal", NodeType::kFunctionName),
+ EqualsNodeInfo("bat", NodeType::kString),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("OR", NodeType::kNaryOperator),
+ EqualsNodeInfo("AND", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, InvalidHas) {
+ // Query: "foo:" No right hand operand to :
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {":", Lexer::TokenType::COMPARATOR}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidComposite) {
+ // Query: "(foo bar" No terminating RPAREN
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"", Lexer::TokenType::LPAREN},
+ {"foo", Lexer::TokenType::TEXT},
+ {"bar", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidMember) {
+ // Query: "foo." DOT must have succeeding TEXT
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidOr) {
+ // Query: "foo OR" No right hand operand to OR
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::OR}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidAnd) {
+ // Query: "foo AND" No right hand operand to AND
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::AND}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidNot) {
+ // Query: "NOT" No right hand operand to NOT
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"", Lexer::TokenType::NOT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidMinus) {
+ // Query: "-" No right hand operand to -
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"", Lexer::TokenType::MINUS}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidFunctionCallNoRparen) {
+ // Query: "foo(" No terminating RPAREN
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, InvalidFunctionCallNoLparen) {
+ // Query: "foo bar" foo labeled FUNCTION_NAME despite no LPAREN
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"bar", Lexer::TokenType::FUNCTION_NAME}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
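// This FUNCTION_NAME-without-LPAREN sequence can presumably only be built by
// hand: a real Lexer would emit "foo bar" as two TEXT tokens, so this
// parser-side guard is reachable from these unit tests but not from
// parser_integration_test.cc.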
+
+TEST(ParserTest, InvalidFunctionArgsHangingComma) {
+ // Query: "foo("bar",)" no valid arg following COMMA
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"bar", Lexer::TokenType::STRING},
+ {"", Lexer::TokenType::COMMA},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeQuery(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(ParserTest, ScoringPlus) {
+ // Scoring: "1 + 1 + 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS},
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS},
+ {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringMinus) {
+ // Scoring: "1 - 1 - 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::MINUS},
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::MINUS},
+ {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // MINUS
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringUnaryMinus) {
+ // Scoring: "1 + -1 + 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"1", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS},
+ {"", Lexer::TokenType::MINUS}, {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS}, {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / | \
+ // member MINUS member
+ // | | |
+ // text member text
+ // |
+ // text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kUnaryOperator),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringPlusMinus) {
+ // Scoring: "11 + 12 - 13 + 14"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"11", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS},
+ {"12", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::MINUS},
+ {"13", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS},
+ {"14", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // PLUS
+ // / \
+ // MINUS member
+ // / \ |
+ // PLUS member text
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("11", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("12", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator),
+ EqualsNodeInfo("13", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kNaryOperator),
+ EqualsNodeInfo("14", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringTimes) {
+ // Scoring: "1 * 1 * 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::TIMES},
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::TIMES},
+ {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // TIMES
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringDiv) {
+ // Scoring: "1 / 1 / 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DIV},
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DIV},
+ {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // DIV
+ // / | \
+ // member member member
+ // | | |
+ // text text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ScoringTimesDiv) {
+ // Scoring: "11 / 12 * 13 / 14 / 15"
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"11", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV},
+ {"12", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::TIMES},
+ {"13", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV},
+ {"14", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV},
+ {"15", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // DIV
+ // / | \
+ // TIMES member member
+ // / \ | |
+ // DIV member text text
+ // / \ |
+ // member member text
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("11", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("12", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator),
+ EqualsNodeInfo("13", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator),
+ EqualsNodeInfo("14", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("15", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator)));
+}
+
+TEST(ParserTest, ComplexScoring) {
+ // Scoring: "1 + pow((2 * sin(3)), 4) + -5 / 6"
+ // With parentheses in function arguments.
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS},
+ {"pow", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::LPAREN},
+ {"2", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::TIMES},
+ {"sin", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"3", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::COMMA},
+ {"4", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::PLUS},
+ {"", Lexer::TokenType::MINUS},
+ {"5", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DIV},
+ {"6", Lexer::TokenType::TEXT},
+ };
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+  std::vector<NodeInfo> nodes = visitor.nodes();
+  EXPECT_THAT(nodes,
+ ElementsAre(EqualsNodeInfo("1", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("pow", NodeType::kFunctionName),
+ EqualsNodeInfo("2", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("sin", NodeType::kFunctionName),
+ EqualsNodeInfo("3", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("TIMES", NodeType::kNaryOperator),
+ EqualsNodeInfo("4", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("5", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("MINUS", NodeType::kUnaryOperator),
+ EqualsNodeInfo("6", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("DIV", NodeType::kNaryOperator),
+ EqualsNodeInfo("PLUS", NodeType::kNaryOperator)));
+
+ // Scoring: "1 + pow(2 * sin(3), 4) + -5 / 6"
+ // Without parentheses in function arguments.
+ lexer_tokens = {
+ {"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS},
+ {"pow", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"2", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::TIMES},
+ {"sin", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"3", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::COMMA},
+ {"4", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::RPAREN},
+ {"", Lexer::TokenType::PLUS},
+ {"", Lexer::TokenType::MINUS},
+ {"5", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DIV},
+ {"6", Lexer::TokenType::TEXT},
+ };
+ parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(tree_root, parser.ConsumeScoring());
+ visitor = SimpleVisitor();
+ tree_root->Accept(&visitor);
+  EXPECT_THAT(visitor.nodes(), ElementsAreArray(nodes));
+}
+
+TEST(ParserTest, ScoringMemberFunction) {
+ // Scoring: this.CreationTimestamp()
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"this", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DOT},
+ {"CreationTimestamp", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // member
+ // / \
+ // text function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(
+ visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("CreationTimestamp", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserTest, QueryMemberFunction) {
+ // Query: this.foo()
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"this", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DOT},
+ {"foo", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // / \
+ // text function
+ // |
+ // function_name
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("foo", NodeType::kFunctionName),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserTest, ScoringComplexMemberFunction) {
+ // Scoring: a.b.fun(c, d)
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"a", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DOT},
+ {"b", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::DOT},
+ {"fun", Lexer::TokenType::FUNCTION_NAME},
+ {"", Lexer::TokenType::LPAREN},
+ {"c", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::COMMA},
+ {"d", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ // Expected AST:
+ // member
+ // / | \
+ // text text function
+ // / | \
+ // function_name member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("a", NodeType::kText),
+ EqualsNodeInfo("b", NodeType::kText),
+ EqualsNodeInfo("fun", NodeType::kFunctionName),
+ EqualsNodeInfo("c", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("d", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserTest, QueryComplexMemberFunction) {
+ // Query: this.abc.fun(def, ghi)
+ std::vector<Lexer::LexerToken> lexer_tokens = {
+ {"this", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT},
+ {"abc", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT},
+ {"fun", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN},
+ {"def", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::COMMA},
+ {"ghi", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
+ // Expected AST:
+ // member
+ // / | \
+ // text text function
+ // / | \
+ // function_name member member
+ // | |
+ // text text
+ SimpleVisitor visitor;
+ tree_root->Accept(&visitor);
+ EXPECT_THAT(visitor.nodes(),
+ ElementsAre(EqualsNodeInfo("this", NodeType::kText),
+ EqualsNodeInfo("abc", NodeType::kText),
+ EqualsNodeInfo("fun", NodeType::kFunctionName),
+ EqualsNodeInfo("def", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("ghi", NodeType::kText),
+ EqualsNodeInfo("", NodeType::kMember),
+ EqualsNodeInfo("", NodeType::kFunction),
+ EqualsNodeInfo("", NodeType::kMember)));
+}
+
+TEST(ParserTest, InvalidScoringToken) {
+ // Scoring: "1 + NOT 1"
+ std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT},
+ {"", Lexer::TokenType::PLUS},
+ {"", Lexer::TokenType::NOT},
+ {"1", Lexer::TokenType::TEXT}};
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ EXPECT_THAT(parser.ConsumeScoring(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc
new file mode 100644
index 0000000..21ce55b
--- /dev/null
+++ b/icing/query/advanced_query_parser/query-visitor.cc
@@ -0,0 +1,228 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/query-visitor.h"
+
+#include <cstdint>
+#include <cstdlib>
+#include <limits>
+#include <string_view>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/schema/section-manager.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+bool IsNumericComparator(std::string_view operator_text) {
+ if (operator_text.length() < 1 || operator_text.length() > 2) {
+ return false;
+ }
+ // TODO(tjbarron) decide how/if to support !=
+ return operator_text == "<" || operator_text == ">" ||
+ operator_text == "==" || operator_text == "<=" ||
+ operator_text == ">=";
+}
+
+bool IsSupportedOperator(std::string_view operator_text) {
+ return IsNumericComparator(operator_text);
+}
+
+} // namespace
+
+libtextclassifier3::StatusOr<int64_t> QueryVisitor::RetrieveIntValue() {
+ if (pending_values_.empty() || !pending_values_.top().holds_text()) {
+ return absl_ports::InvalidArgumentError("Unable to retrieve int value.");
+ }
+ std::string& value = pending_values_.top().text;
+ char* value_end;
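+  // strtoll advances value_end to the first unparsed character; anything
+  // short of the full string means the text was not a pure base-10 integer.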
+ int64_t int_value = std::strtoll(value.c_str(), &value_end, /*base=*/10);
+ if (value_end != value.c_str() + value.length()) {
+ return absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unable to parse \"", value, "\" as number."));
+ }
+ pending_values_.pop();
+ return int_value;
+}
+
+libtextclassifier3::StatusOr<std::string> QueryVisitor::RetrieveStringValue() {
+ if (pending_values_.empty() || !pending_values_.top().holds_text()) {
+ return absl_ports::InvalidArgumentError("Unable to retrieve string value.");
+ }
+ std::string string_value = std::move(pending_values_.top().text);
+ pending_values_.pop();
+ return string_value;
+}
+
+struct Int64Range {
+ int64_t low;
+ int64_t high;
+};
+
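+// Maps a numeric comparator and its operand to the inclusive range
+// [low, high] that a document value must fall within to match. For example,
+// "price < 5" maps to [INT64_MIN, 4].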
+libtextclassifier3::StatusOr<Int64Range> GetInt64Range(
+ std::string_view operator_text, int64_t int_value) {
+ Int64Range range = {std::numeric_limits<int64_t>::min(),
+ std::numeric_limits<int64_t>::max()};
+ if (operator_text == "<") {
+ if (int_value == std::numeric_limits<int64_t>::min()) {
+ return absl_ports::InvalidArgumentError(
+ "Cannot specify < INT64_MIN in query expression.");
+ }
+ range.high = int_value - 1;
+ } else if (operator_text == "<=") {
+ range.high = int_value;
+ } else if (operator_text == "==") {
+ range.high = int_value;
+ range.low = int_value;
+ } else if (operator_text == ">=") {
+ range.low = int_value;
+ } else if (operator_text == ">") {
+ if (int_value == std::numeric_limits<int64_t>::max()) {
+ return absl_ports::InvalidArgumentError(
+ "Cannot specify > INT64_MAX in query expression.");
+ }
+ range.low = int_value + 1;
+ }
+ return range;
+}
+
+libtextclassifier3::StatusOr<QueryVisitor::PendingValue>
+QueryVisitor::ProcessNumericComparator(const NaryOperatorNode* node) {
+ // 1. The children should have been processed and added their outputs to
+ // pending_values_. Time to process them.
+  // The top two pending values should be the int value and the property.
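+  // Children are visited left to right, so the property was pushed first and
+  // the int value now sits on top of the stack.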
+ ICING_ASSIGN_OR_RETURN(int64_t int_value, RetrieveIntValue());
+ ICING_ASSIGN_OR_RETURN(std::string property, RetrieveStringValue());
+
+ // 2. Create the iterator.
+ ICING_ASSIGN_OR_RETURN(Int64Range range,
+ GetInt64Range(node->operator_text(), int_value));
+ auto iterator_or =
+ numeric_index_.GetIterator(property, range.low, range.high);
+ if (!iterator_or.ok()) {
+ return std::move(iterator_or).status();
+ }
+ std::unique_ptr<DocHitInfoIterator> iterator =
+ std::move(iterator_or).ValueOrDie();
+ return PendingValue(std::move(iterator));
+}
+
+void QueryVisitor::VisitFunctionName(const FunctionNameNode* node) {
+ pending_error_ = absl_ports::UnimplementedError(
+ "Function Name node visiting not implemented yet.");
+}
+
+void QueryVisitor::VisitString(const StringNode* node) {
+ pending_error_ = absl_ports::UnimplementedError(
+ "String node visiting not implemented yet.");
+}
+
+void QueryVisitor::VisitText(const TextNode* node) {
+ pending_values_.push(PendingValue(node->value()));
+}
+
+void QueryVisitor::VisitMember(const MemberNode* node) {
+ // 1. Put in a placeholder PendingValue
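+  // The placeholder marks where this node's children begin on
+  // pending_values_, so step 3 knows when to stop popping.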
+ pending_values_.push(PendingValue());
+
+ // 2. Visit the children.
+ for (const std::unique_ptr<TextNode>& child : node->children()) {
+ child->Accept(this);
+ if (has_pending_error()) {
+ return;
+ }
+ }
+
+ // 3. The children should have been processed and added their outputs to
+ // pending_values_. Time to process them.
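+  // Popping reassembles the dotted property path from right to left, e.g.
+  // "price" then "subscription" becomes "subscription.price".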
+ std::string member = std::move(pending_values_.top().text);
+ pending_values_.pop();
+ while (!pending_values_.empty() && !pending_values_.top().is_placeholder()) {
+ member = absl_ports::StrCat(pending_values_.top().text, kPropertySeparator,
+ member);
+ pending_values_.pop();
+ }
+
+  // 4. If pending_values_ is empty at this point, then our placeholder has
+  // somehow disappeared.
+  if (pending_values_.empty()) {
+    pending_error_ = absl_ports::InvalidArgumentError(
+        "Error processing arguments for member node.");
+ return;
+ }
+ pending_values_.pop();
+
+ pending_values_.push(PendingValue(std::move(member)));
+}
+
+void QueryVisitor::VisitFunction(const FunctionNode* node) {
+ pending_error_ = absl_ports::UnimplementedError(
+ "Function node visiting not implemented yet.");
+}
+
+void QueryVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) {
+  pending_error_ = absl_ports::UnimplementedError(
+      "Unary operator node visiting not implemented yet.");
+}
+
+void QueryVisitor::VisitNaryOperator(const NaryOperatorNode* node) {
+ if (has_pending_error()) {
+ return;
+ }
+
+ if (!IsSupportedOperator(node->operator_text())) {
+ pending_error_ = absl_ports::UnimplementedError(
+ "No support for any non-numeric operators.");
+ return;
+ }
+
+ // 1. Put in a placeholder PendingValue
+ pending_values_.push(PendingValue());
+
+ // 2. Visit the children.
+ for (const std::unique_ptr<Node>& child : node->children()) {
+ child->Accept(this);
+ if (has_pending_error()) {
+ return;
+ }
+ }
+
+ // 3. Retrieve the pending value for this node.
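+  // ProcessNumericComparator consumes the two values pushed by the children
+  // in step 2 and replaces them with a single iterator.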
+ PendingValue pending_value;
+ if (IsNumericComparator(node->operator_text())) {
+ auto pending_value_or = ProcessNumericComparator(node);
+ if (!pending_value_or.ok()) {
+ pending_error_ = std::move(pending_value_or).status();
+ return;
+ }
+ pending_value = std::move(pending_value_or).ValueOrDie();
+ }
+
+ // 4. Check for the placeholder.
+ if (!pending_values_.top().is_placeholder()) {
+ pending_error_ = absl_ports::InvalidArgumentError(
+ "Error processing arguments for node.");
+ return;
+ }
+ pending_values_.pop();
+
+ pending_values_.push(std::move(pending_value));
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h
new file mode 100644
index 0000000..e834606
--- /dev/null
+++ b/icing/query/advanced_query_parser/query-visitor.h
@@ -0,0 +1,119 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_
+#define ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_
+
+#include <cstdint>
+#include <memory>
+#include <stack>
+#include <string>
+
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+
+namespace icing {
+namespace lib {
+
+// The Visitor used to create the DocHitInfoIterator tree from the AST output by
+// the parser.
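+//
+// Example usage (a sketch of the intended call pattern):
+//   QueryVisitor visitor(&numeric_index);
+//   tree_root->Accept(&visitor);
+//   ICING_ASSIGN_OR_RETURN(std::unique_ptr<DocHitInfoIterator> iterator,
+//                          std::move(visitor).root());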
+class QueryVisitor : public AbstractSyntaxTreeVisitor {
+ public:
+ explicit QueryVisitor(const NumericIndex<int64_t>* numeric_index)
+ : numeric_index_(*numeric_index) {}
+
+ void VisitFunctionName(const FunctionNameNode* node) override;
+ void VisitString(const StringNode* node) override;
+ void VisitText(const TextNode* node) override;
+ void VisitMember(const MemberNode* node) override;
+ void VisitFunction(const FunctionNode* node) override;
+ void VisitUnaryOperator(const UnaryOperatorNode* node) override;
+ void VisitNaryOperator(const NaryOperatorNode* node) override;
+
+  // Returns:
+ // - the DocHitInfoIterator that is the root of the query iterator tree
+ // - INVALID_ARGUMENT if the AST does not conform to supported expressions
+ libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> root() && {
+ if (has_pending_error()) {
+ return pending_error_;
+ }
+ if (pending_values_.size() != 1 ||
+ !pending_values_.top().holds_iterator()) {
+ return absl_ports::InvalidArgumentError(
+ "Visitor does not contain a single root iterator.");
+ }
+ return std::move(pending_values_.top().iterator);
+ }
+
+ private:
+ // A holder for intermediate results when processing child nodes.
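+  // A PendingValue is in exactly one of three states: a placeholder (neither
+  // field set), a text value, or an iterator.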
+ struct PendingValue {
+ PendingValue() = default;
+
+ explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator)
+ : iterator(std::move(iterator)) {}
+
+ explicit PendingValue(std::string text) : text(std::move(text)) {}
+
+ // Placeholder is used to indicate where the children of a particular node
+ // begin.
+ bool is_placeholder() const { return iterator == nullptr && text.empty(); }
+
+ bool holds_text() const { return iterator == nullptr && !text.empty(); }
+
+ bool holds_iterator() const { return iterator != nullptr && text.empty(); }
+
+ std::unique_ptr<DocHitInfoIterator> iterator;
+ std::string text;
+ };
+
+ bool has_pending_error() const { return !pending_error_.ok(); }
+
+  // Processes the PendingValue at the top of pending_values_, parsing it into
+  // an int64_t and popping the top.
+  // Returns:
+  //   - On success, the int value stored in the text at the top
+  //   - INVALID_ARGUMENT if pending_values_ is empty, doesn't hold a text
+  //     value, or the text can't be parsed as an int.
+ libtextclassifier3::StatusOr<int64_t> RetrieveIntValue();
+
+ // Processes the PendingValue at the top of pending_values_ and pops the top.
+ // Returns:
+ // - On success, the string value stored in the text at the top
+  //   - INVALID_ARGUMENT if pending_values_ is empty or doesn't hold a text
+  //     value.
+ libtextclassifier3::StatusOr<std::string> RetrieveStringValue();
+
+ // Processes the NumericComparator represented by node. This must be called
+ // *after* this node's children have been visited. The PendingValues added by
+ // this node's children will be consumed by this function and the PendingValue
+ // for this node will be returned.
+ // Returns:
+  //   - On success, the PendingValue representing this node and its children.
+ // - INVALID_ARGUMENT if unable to retrieve string value or int value
+ // - NOT_FOUND if there is no entry in the numeric index for the property
+ libtextclassifier3::StatusOr<PendingValue> ProcessNumericComparator(
+ const NaryOperatorNode* node);
+
+ std::stack<PendingValue> pending_values_;
+ libtextclassifier3::Status pending_error_;
+
+  const NumericIndex<int64_t>& numeric_index_;  // Does not own!
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_
diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc
new file mode 100644
index 0000000..1e456fe
--- /dev/null
+++ b/icing/query/advanced_query_parser/query-visitor_test.cc
@@ -0,0 +1,557 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/query/advanced_query_parser/query-visitor.h"
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/lexer.h"
+#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+
+constexpr DocumentId kDocumentId0 = 0;
+constexpr DocumentId kDocumentId1 = 1;
+constexpr DocumentId kDocumentId2 = 2;
+
+constexpr SectionId kSectionId0 = 0;
+constexpr SectionId kSectionId1 = 1;
+constexpr SectionId kSectionId2 = 2;
+
+TEST(QueryVisitorTest, SimpleLessThan) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price < 2";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
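+  // Hits are returned in descending DocumentId order.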
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+}
+
+TEST(QueryVisitorTest, SimpleLessThanEq) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price <= 1";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+}
+
+TEST(QueryVisitorTest, SimpleEqual) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price == 2";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+}
+
+TEST(QueryVisitorTest, SimpleGreaterThanEq) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price >= 1";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ElementsAre(kDocumentId2, kDocumentId1));
+}
+
+TEST(QueryVisitorTest, SimpleGreaterThan) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price > 1";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2));
+}
+
+// TODO(b/208654892) Properly handle negative numbers in query expressions.
+TEST(QueryVisitorTest, DISABLED_IntMinLessThanEqual) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values
+  // INT64_MIN, INT64_MAX and INT64_MIN + 1, respectively.
+ int64_t int_min = std::numeric_limits<int64_t>::min();
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(int_min);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(std::numeric_limits<int64_t>::max());
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(int_min + 1);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price <= " + std::to_string(int_min);
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0));
+}
+
+TEST(QueryVisitorTest, IntMaxGreaterThanEqual) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values
+  // INT64_MIN, INT64_MAX and INT64_MAX - 1, respectively.
+ int64_t int_max = std::numeric_limits<int64_t>::max();
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(std::numeric_limits<int64_t>::min());
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(int_max);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(int_max - 1);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price >= " + std::to_string(int_max);
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1));
+}
+
+TEST(QueryVisitorTest, NestedPropertyLessThan) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "subscription.price < 2";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator,
+ std::move(query_visitor).root());
+
+ EXPECT_THAT(GetDocumentIds(root_iterator.get()),
+ ElementsAre(kDocumentId1, kDocumentId0));
+}
+
+TEST(QueryVisitorTest, IntParsingError) {
+ DummyNumericIndex<int64_t> numeric_index;
+
+ std::string query = "subscription.price < fruit";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(QueryVisitorTest, NotEqualsUnsupported) {
+ DummyNumericIndex<int64_t> numeric_index;
+
+ std::string query = "subscription.price != 3";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED));
+}
+
+TEST(QueryVisitorTest, UnrecognizedOperatorTooLongUnsupported) {
+ DummyNumericIndex<int64_t> numeric_index;
+
+ // Create an AST for the query 'subscription.price !<= 3'
+ std::string query = "subscription.price !<= 3";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // There is no support for the 'not less than or equal to' operator.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED));
+}
+
+TEST(QueryVisitorTest, LessThanTooManyOperandsInvalid) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+ // Create an invalid AST for the query '3 < subscription.price 25' where '<'
+ // has three operands
+ auto property_node = std::make_unique<TextNode>("subscription");
+ auto subproperty_node = std::make_unique<TextNode>("price");
+ std::vector<std::unique_ptr<TextNode>> member_args;
+ member_args.push_back(std::move(property_node));
+ member_args.push_back(std::move(subproperty_node));
+ auto member_node = std::make_unique<MemberNode>(std::move(member_args),
+ /*function=*/nullptr);
+
+ auto value_node = std::make_unique<TextNode>("3");
+ auto extra_value_node = std::make_unique<TextNode>("25");
+ std::vector<std::unique_ptr<Node>> args;
+ args.push_back(std::move(value_node));
+ args.push_back(std::move(member_node));
+ args.push_back(std::move(extra_value_node));
+ auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(QueryVisitorTest, LessThanTooFewOperandsInvalid) {
+ DummyNumericIndex<int64_t> numeric_index;
+
+ // Create an invalid AST for the query 'subscription.price <' where '<'
+ // has a single operand
+ auto property_node = std::make_unique<TextNode>("subscription");
+ auto subproperty_node = std::make_unique<TextNode>("price");
+ std::vector<std::unique_ptr<TextNode>> member_args;
+ member_args.push_back(std::move(property_node));
+ member_args.push_back(std::move(subproperty_node));
+ auto member_node = std::make_unique<MemberNode>(std::move(member_args),
+ /*function=*/nullptr);
+
+ std::vector<std::unique_ptr<Node>> args;
+ args.push_back(std::move(member_node));
+ auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args));
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(QueryVisitorTest, LessThanNonExistentPropertyNotFound) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values 0, 1
+  // and 2, respectively.
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0);
+ editor->BufferKey(0);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1);
+ editor->BufferKey(1);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2);
+ editor->BufferKey(2);
+ editor->IndexAllBufferedKeys();
+
+  // Create an AST for the query 'time < 25', where the property 'time' has
+  // no entries in the numeric index.
+ std::string query = "time < 25";
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
+}
+
+TEST(QueryVisitorTest, NeverVisitedReturnsInvalid) {
+ DummyNumericIndex<int64_t> numeric_index;
+ QueryVisitor query_visitor(&numeric_index);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+// TODO(b/208654892) Properly handle negative numbers in query expressions.
+TEST(QueryVisitorTest, DISABLED_IntMinLessThanInvalid) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values
+  // INT64_MIN, INT64_MAX and INT64_MIN + 1, respectively.
+ int64_t int_min = std::numeric_limits<int64_t>::min();
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(int_min);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(std::numeric_limits<int64_t>::max());
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(int_min + 1);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price <" + std::to_string(int_min);
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST(QueryVisitorTest, IntMaxGreaterThanInvalid) {
+  // Set up the numeric index with docs 0, 1 and 2 holding the values
+  // INT64_MIN, INT64_MAX and INT64_MAX - 1, respectively.
+ int64_t int_max = std::numeric_limits<int64_t>::max();
+ DummyNumericIndex<int64_t> numeric_index;
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index.Edit("price", kDocumentId0, kSectionId0);
+ editor->BufferKey(std::numeric_limits<int64_t>::min());
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId1, kSectionId1);
+ editor->BufferKey(int_max);
+ editor->IndexAllBufferedKeys();
+
+ editor = numeric_index.Edit("price", kDocumentId2, kSectionId2);
+ editor->BufferKey(int_max - 1);
+ editor->IndexAllBufferedKeys();
+
+ std::string query = "price >" + std::to_string(int_max);
+ Lexer lexer(query, Lexer::Language::QUERY);
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node,
+ parser.ConsumeQuery());
+
+ // Retrieve the root_iterator from the visitor.
+ QueryVisitor query_visitor(&numeric_index);
+ root_node->Accept(&query_visitor);
+ EXPECT_THAT(std::move(query_visitor).root(),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc
index 90587aa..abef9e4 100644
--- a/icing/query/query-processor.cc
+++ b/icing/query/query-processor.cc
@@ -35,12 +35,12 @@
#include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/search.pb.h"
-
-#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
-#include "icing/query/advanced-query-processor.h"
-#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
-
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/query/advanced_query_parser/lexer.h"
+#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/query/advanced_query_parser/query-visitor.h"
#include "icing/query/query-processor.h"
+#include "icing/query/query-results.h"
#include "icing/query/query-terms.h"
#include "icing/query/query-utils.h"
#include "icing/schema/schema-store.h"
@@ -107,27 +107,31 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame(
} // namespace
libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>>
-QueryProcessor::Create(Index* index,
+QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
const SchemaStore* schema_store) {
ICING_RETURN_ERROR_IF_NULL(index);
+ ICING_RETURN_ERROR_IF_NULL(numeric_index);
ICING_RETURN_ERROR_IF_NULL(language_segmenter);
ICING_RETURN_ERROR_IF_NULL(normalizer);
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
- return std::unique_ptr<QueryProcessor>(new QueryProcessor(
- index, language_segmenter, normalizer, document_store, schema_store));
+ return std::unique_ptr<QueryProcessor>(
+ new QueryProcessor(index, numeric_index, language_segmenter, normalizer,
+ document_store, schema_store));
}
QueryProcessor::QueryProcessor(Index* index,
+ const NumericIndex<int64_t>* numeric_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
const SchemaStore* schema_store)
: index_(*index),
+ numeric_index_(*numeric_index),
language_segmenter_(*language_segmenter),
normalizer_(*normalizer),
document_store_(*document_store),
@@ -145,22 +149,19 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
QueryResults results;
if (search_spec.search_type() ==
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
-#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
ICING_VLOG(1) << "Using EXPERIMENTAL_ICING_ADVANCED_QUERY parser!";
- ICING_ASSIGN_OR_RETURN(
- std::unique_ptr<AdvancedQueryProcessor> advanced_query_processor,
- AdvancedQueryProcessor::Create(&index_, &language_segmenter_,
- &normalizer_, &document_store_,
- &schema_store_));
- ICING_ASSIGN_OR_RETURN(results, advanced_query_processor->ParseSearch(
- search_spec, ranking_strategy));
-#else // !ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
- ICING_LOG(ERROR) << "Requested EXPERIMENTAL_ICING_ADVANCED_QUERY search "
- "type, but advanced query is not compiled in. Falling "
- "back to ICING_RAW_QUERY.";
- ICING_ASSIGN_OR_RETURN(results,
- ParseRawQuery(search_spec, ranking_strategy));
-#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
+ libtextclassifier3::StatusOr<QueryResults> results_or =
+ ParseAdvancedQuery(search_spec);
+ if (results_or.ok()) {
+ results = std::move(results_or).ValueOrDie();
+ } else {
+ ICING_VLOG(1)
+ << "Unable to parse query using advanced query parser. Error: "
+ << results_or.status().error_message()
+ << ". Falling back to old query parser.";
+ ICING_ASSIGN_OR_RETURN(results,
+ ParseRawQuery(search_spec, ranking_strategy));
+ }
} else {
ICING_ASSIGN_OR_RETURN(results,
ParseRawQuery(search_spec, ranking_strategy));
@@ -173,6 +174,30 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch(
return results;
}
+libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery(
+ const SearchSpecProto& search_spec) const {
+ QueryResults results;
+ Lexer lexer(search_spec.query(), Lexer::Language::QUERY);
+ ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeQuery());
+
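+  // A null tree means there was nothing to parse; match all documents.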
+ if (tree_root == nullptr) {
+ results.root_iterator = std::make_unique<DocHitInfoIteratorAllDocumentId>(
+ document_store_.last_added_document_id());
+ return results;
+ }
+
+ QueryVisitor query_visitor(&numeric_index_);
+ tree_root->Accept(&query_visitor);
+ ICING_ASSIGN_OR_RETURN(results.root_iterator,
+ std::move(query_visitor).root());
+ return results;
+}
+
// TODO(cassiewang): Collect query stats to populate the SearchResultsProto
libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseRawQuery(
const SearchSpecProto& search_spec,
diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h
index f544a7a..a4f8973 100644
--- a/icing/query/query-processor.h
+++ b/icing/query/query-processor.h
@@ -15,12 +15,14 @@
#ifndef ICING_QUERY_QUERY_PROCESSOR_H_
#define ICING_QUERY_QUERY_PROCESSOR_H_
+#include <cstdint>
#include <memory>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-filter.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/proto/search.pb.h"
#include "icing/query/query-results.h"
#include "icing/query/query-terms.h"
@@ -45,9 +47,9 @@ class QueryProcessor {
  // A QueryProcessor on success
// FAILED_PRECONDITION if any of the pointers is null.
static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create(
- Index* index, const LanguageSegmenter* language_segmenter,
- const Normalizer* normalizer, const DocumentStore* document_store,
- const SchemaStore* schema_store);
+ Index* index, const NumericIndex<int64_t>* numeric_index,
+ const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
+ const DocumentStore* document_store, const SchemaStore* schema_store);
// Parse the search configurations (including the query, any additional
// filters, etc.) in the SearchSpecProto into one DocHitInfoIterator.
@@ -69,12 +71,23 @@ class QueryProcessor {
private:
explicit QueryProcessor(Index* index,
+ const NumericIndex<int64_t>* numeric_index,
const LanguageSegmenter* language_segmenter,
const Normalizer* normalizer,
const DocumentStore* document_store,
const SchemaStore* schema_store);
  // Parse the query into one DocHitInfoIterator that represents the root of a
+ // query tree in our new Advanced Query Language.
+ //
+ // Returns:
+  //   - On success, one iterator that represents the entire query
+  //   - INVALID_ARGUMENT if the query syntax is incorrect and cannot be
+  //     tokenized
+ libtextclassifier3::StatusOr<QueryResults> ParseAdvancedQuery(
+ const SearchSpecProto& search_spec) const;
+
+  // Parse the query into one DocHitInfoIterator that represents the root of a
// query tree.
//
// Returns:
@@ -90,6 +103,7 @@ class QueryProcessor {
// Not const because we could modify/sort the hit buffer in the lite index at
// query time.
Index& index_;
+ const NumericIndex<int64_t>& numeric_index_;
const LanguageSegmenter& language_segmenter_;
const Normalizer& normalizer_;
const DocumentStore& document_store_;
diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc
index 2015d81..6d776ce 100644
--- a/icing/query/query-processor_benchmark.cc
+++ b/icing/query/query-processor_benchmark.cc
@@ -17,6 +17,8 @@
#include "third_party/absl/flags/flag.h"
#include "icing/document-builder.h"
#include "icing/index/index.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/term.pb.h"
@@ -113,6 +115,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ // TODO(b/249829533): switch to use persistent numeric index.
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>();
+
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -147,9 +152,9 @@ void BM_QueryOneTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index.get(), language_segmenter.get(),
- normalizer.get(), document_store.get(),
- schema_store.get()));
+ QueryProcessor::Create(index.get(), numeric_index.get(),
+ language_segmenter.get(), normalizer.get(),
+ document_store.get(), schema_store.get()));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -233,6 +238,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ // TODO(b/249829533): switch to use persistent numeric index.
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>();
+
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -281,9 +289,9 @@ void BM_QueryFiveTerms(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index.get(), language_segmenter.get(),
- normalizer.get(), document_store.get(),
- schema_store.get()));
+ QueryProcessor::Create(index.get(), numeric_index.get(),
+ language_segmenter.get(), normalizer.get(),
+ document_store.get(), schema_store.get()));
const std::string query_string = absl_ports::StrCat(
input_string_a, " ", input_string_b, " ", input_string_c, " ",
@@ -371,6 +379,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ // TODO(b/249829533): switch to use persistent numeric index.
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>();
+
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -408,9 +419,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index.get(), language_segmenter.get(),
- normalizer.get(), document_store.get(),
- schema_store.get()));
+ QueryProcessor::Create(index.get(), numeric_index.get(),
+ language_segmenter.get(), normalizer.get(),
+ document_store.get(), schema_store.get()));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
@@ -494,6 +505,9 @@ void BM_QueryHiragana(benchmark::State& state) {
std::unique_ptr<Index> index =
CreateIndex(icing_filesystem, filesystem, index_dir);
+ // TODO(b/249829533): switch to use persistent numeric index.
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>();
+
language_segmenter_factory::SegmenterOptions options(ULOC_US);
std::unique_ptr<LanguageSegmenter> language_segmenter =
language_segmenter_factory::Create(std::move(options)).ValueOrDie();
@@ -531,9 +545,9 @@ void BM_QueryHiragana(benchmark::State& state) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index.get(), language_segmenter.get(),
- normalizer.get(), document_store.get(),
- schema_store.get()));
+ QueryProcessor::Create(index.get(), numeric_index.get(),
+ language_segmenter.get(), normalizer.get(),
+ document_store.get(), schema_store.get()));
SearchSpecProto search_spec;
search_spec.set_query(input_string);
diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc
index da35df8..9f1386c 100644
--- a/icing/query/query-processor_test.cc
+++ b/icing/query/query-processor_test.cc
@@ -14,6 +14,7 @@
#include "icing/query/query-processor.h"
+#include <cstdint>
#include <memory>
#include <string>
@@ -26,13 +27,14 @@
#include "icing/index/index.h"
#include "icing/index/iterator/doc-hit-info-iterator-test-util.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/index/numeric/dummy-numeric-index.h"
+#include "icing/index/numeric/numeric-index.h"
#include "icing/jni/jni-cache.h"
#include "icing/legacy/index/icing-filesystem.h"
#include "icing/portable/platform.h"
#include "icing/proto/schema.pb.h"
#include "icing/proto/search.pb.h"
#include "icing/proto/term.pb.h"
-#include "icing/query/query-processor.h"
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
@@ -59,7 +61,6 @@ using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::IsEmpty;
using ::testing::SizeIs;
-using ::testing::Test;
using ::testing::UnorderedElementsAre;
class QueryProcessorTest
@@ -88,11 +89,22 @@ class QueryProcessorTest
icu_data_file_helper::SetUpICUDataFile(
GetTestFilePath("icing/icu.dat")));
}
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult create_result,
+ DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
+ schema_store_.get()));
+ document_store_ = std::move(create_result.document_store);
Index::Options options(index_dir_,
/*index_merge_size=*/1024 * 1024);
ICING_ASSERT_OK_AND_ASSIGN(
index_, Index::Create(options, &filesystem_, &icing_filesystem_));
+ // TODO(b/249829533): switch to using the persistent numeric index.
+ numeric_index_ = std::make_unique<DummyNumericIndex<int64_t>>();
language_segmenter_factory::SegmenterOptions segmenter_options(
ULOC_US, jni_cache_.get());
@@ -102,6 +114,12 @@ class QueryProcessorTest
ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create(
/*max_term_byte_size=*/1000));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ query_processor_,
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ document_store_.get(), schema_store_.get()));
}
libtextclassifier3::Status AddTokenToIndex(
@@ -113,6 +131,16 @@ class QueryProcessorTest
return status.ok() ? editor.IndexAllBufferedTerms() : status;
}
+ libtextclassifier3::Status AddToNumericIndex(DocumentId document_id,
+ const std::string& property,
+ SectionId section_id,
+ int64_t value) {
+ std::unique_ptr<NumericIndex<int64_t>::Editor> editor =
+ numeric_index_->Edit(property, document_id, section_id);
+ ICING_RETURN_IF_ERROR(editor->BufferKey(value));
+ return editor->IndexAllBufferedKeys();
+ }
+
void TearDown() override {
document_store_.reset();
schema_store_.reset();
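[Editor's note] The AddToNumericIndex helper added above wraps the NumericIndex editor flow: obtain an Editor for (property, document, section), buffer the key, then flush with IndexAllBufferedKeys. A hedged usage sketch; the property name "price", section id 0, and value 42 are made up for illustration:

  // Hypothetical call from a test body; the arguments are illustrative only.
  ASSERT_THAT(AddToNumericIndex(document_id, "price", /*section_id=*/0,
                                /*value=*/42),
              IsOk());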
@@ -129,35 +157,44 @@ class QueryProcessorTest
protected:
std::unique_ptr<Index> index_;
+ std::unique_ptr<NumericIndex<int64_t>> numeric_index_;
std::unique_ptr<LanguageSegmenter> language_segmenter_;
std::unique_ptr<Normalizer> normalizer_;
FakeClock fake_clock_;
std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache();
std::unique_ptr<SchemaStore> schema_store_;
std::unique_ptr<DocumentStore> document_store_;
+ std::unique_ptr<QueryProcessor> query_processor_;
};
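[Editor's note] With the members above, the fixture now owns every QueryProcessor dependency and SetUp wires them once, so individual test bodies shrink to schema/document setup plus the query itself. A minimal sketch of the post-refactor shape (it mirrors the EmptyQueryMatchAllDocuments test below; the test name here is invented):

  TEST_P(QueryProcessorTest, SketchOnly) {
    // schema_store_, document_store_, index_, numeric_index_, and
    // query_processor_ were all built by SetUp.
    SchemaProto schema =
        SchemaBuilder()
            .AddType(SchemaTypeConfigBuilder().SetType("email"))
            .Build();
    ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());

    SearchSpecProto search_spec;
    search_spec.set_query("");
    search_spec.set_search_type(GetParam());
    ICING_ASSERT_OK_AND_ASSIGN(
        QueryResults results,
        query_processor_->ParseSearch(search_spec,
                                      ScoringSpecProto::RankingStrategy::NONE));
  }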
TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) {
EXPECT_THAT(
- QueryProcessor::Create(/*index=*/nullptr, language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()),
+ QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ document_store_.get(), schema_store_.get()),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
- QueryProcessor::Create(index_.get(), /*language_segmenter=*/nullptr,
- normalizer_.get(), document_store_.get(),
- schema_store_.get()),
+ QueryProcessor::Create(index_.get(), /*numeric_index=*/nullptr,
+ language_segmenter_.get(), normalizer_.get(),
+ document_store_.get(), schema_store_.get()),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
EXPECT_THAT(
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- /*normalizer=*/nullptr, document_store_.get(),
- schema_store_.get()),
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ /*language_segmenter=*/nullptr, normalizer_.get(),
+ document_store_.get(), schema_store_.get()),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(
- index_.get(), language_segmenter_.get(), normalizer_.get(),
- /*document_store=*/nullptr, schema_store_.get()),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
- EXPECT_THAT(QueryProcessor::Create(index_.get(), language_segmenter_.get(),
+ EXPECT_THAT(
+ QueryProcessor::Create(
+ index_.get(), numeric_index_.get(), language_segmenter_.get(),
+ /*normalizer=*/nullptr, document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ /*document_store=*/nullptr, schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ language_segmenter_.get(),
normalizer_.get(), document_store_.get(),
/*schema_store=*/nullptr),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
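[Editor's note] The six expectations above imply that Create rejects a null for any dependency with FAILED_PRECONDITION. A minimal sketch of such a guard, assuming Icing's canonical-error helpers; the actual check in icing/query/query-processor.cc may use a dedicated null-check macro instead:

  // Illustrative only; mirrors the behavior the test asserts, not
  // necessarily the real implementation.
  if (index == nullptr || numeric_index == nullptr ||
      language_segmenter == nullptr || normalizer == nullptr ||
      document_store == nullptr || schema_store == nullptr) {
    return absl_ports::FailedPreconditionError(
        "QueryProcessor dependencies must not be null");
  }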
@@ -168,18 +205,8 @@ TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1,
document_store_->Put(DocumentBuilder()
.SetKey("namespace", "1")
@@ -193,22 +220,14 @@ TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) {
// We don't need to insert anything in the index since the empty query will
// match all DocumentIds from the DocumentStore
-
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("()");
search_spec.set_search_type(GetParam());
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
// Descending order of valid DocumentIds
EXPECT_THAT(GetDocumentIds(results.root_iterator.get()),
@@ -222,18 +241,8 @@ TEST_P(QueryProcessorTest, EmptyQueryMatchAllDocuments) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1,
document_store_->Put(DocumentBuilder()
.SetKey("namespace", "1")
@@ -247,22 +256,14 @@ TEST_P(QueryProcessorTest, EmptyQueryMatchAllDocuments) {
// We don't need to insert anything in the index since the empty query will
// match all DocumentIds from the DocumentStore
-
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("");
search_spec.set_search_type(GetParam());
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
// Descending order of valid DocumentIds
EXPECT_THAT(GetDocumentIds(results.root_iterator.get()),
@@ -276,18 +277,8 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -310,13 +301,6 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) {
AddTokenToIndex(document_id, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hElLo WORLD");
search_spec.set_term_match_type(term_match_type);
@@ -324,7 +308,7 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -354,18 +338,8 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -385,13 +359,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("he");
search_spec.set_term_match_type(term_match_type);
@@ -399,7 +366,7 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -424,18 +391,8 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -456,13 +413,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) {
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("he");
search_spec.set_term_match_type(term_match_type);
@@ -470,7 +420,7 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -495,18 +445,8 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -526,13 +466,6 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello");
search_spec.set_term_match_type(term_match_type);
@@ -540,7 +473,7 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -565,18 +498,8 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -596,13 +519,6 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello hello");
search_spec.set_term_match_type(term_match_type);
@@ -610,7 +526,7 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -637,18 +553,8 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -671,13 +577,6 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello world");
search_spec.set_term_match_type(term_match_type);
@@ -685,7 +584,7 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -714,18 +613,8 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -745,13 +634,6 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "hello"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("he he");
search_spec.set_term_match_type(term_match_type);
@@ -759,7 +641,7 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
std::vector<TermMatchInfo> matched_terms_stats;
@@ -786,18 +668,8 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -820,13 +692,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("he wo");
search_spec.set_term_match_type(term_match_type);
@@ -834,7 +699,7 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -864,18 +729,8 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -898,13 +753,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) {
AddTokenToIndex(document_id, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello wo");
search_spec.set_term_match_type(term_match_type);
@@ -912,7 +760,7 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -942,18 +790,8 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -981,13 +819,6 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) {
AddTokenToIndex(document_id2, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello OR world");
search_spec.set_term_match_type(term_match_type);
@@ -995,7 +826,7 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1033,18 +864,8 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1072,13 +893,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
AddTokenToIndex(document_id2, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("he OR wo");
search_spec.set_term_match_type(term_match_type);
@@ -1086,7 +900,7 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1123,18 +937,8 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1161,13 +965,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
AddTokenToIndex(document_id2, section_id, TermMatchType::PREFIX, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("hello OR wo");
search_spec.set_term_match_type(TermMatchType::PREFIX);
@@ -1175,7 +972,7 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1212,18 +1009,8 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1264,13 +1051,6 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
{
// OR gets precedence over AND, so this is parsed as ((puppy OR kitten) AND
// dog)
@@ -1281,7 +1061,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Only Document 1 matches since it has puppy AND dog
@@ -1318,7 +1098,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Both Document 1 and 2 match since Document 1 has animal AND puppy, and
@@ -1374,7 +1154,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Only Document 2 matches since it has both kitten and cat
@@ -1407,18 +1187,8 @@ TEST_P(QueryProcessorTest, OneGroup) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1451,13 +1221,6 @@ TEST_P(QueryProcessorTest, OneGroup) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
// Without grouping, this would be parsed as ((puppy OR kitten) AND foo) and
// no documents would match. But with grouping, Document 1 matches puppy
SearchSpecProto search_spec;
@@ -1467,7 +1230,7 @@ TEST_P(QueryProcessorTest, OneGroup) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1486,18 +1249,8 @@ TEST_P(QueryProcessorTest, TwoGroups) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1530,12 +1283,6 @@ TEST_P(QueryProcessorTest, TwoGroups) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
// Without grouping, this would be parsed as (puppy AND (dog OR kitten) AND
// cat) and wouldn't match any documents. But with grouping, Document 1
// matches (puppy AND dog) and Document 2 matches (kitten AND cat).
@@ -1546,7 +1293,7 @@ TEST_P(QueryProcessorTest, TwoGroups) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1567,18 +1314,8 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1611,13 +1348,6 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
// Without grouping, this would be parsed as ((puppy OR kitten) AND foo) and
// no documents would match. But with grouping, Document 1 matches puppy
SearchSpecProto search_spec;
@@ -1627,7 +1357,7 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1646,18 +1376,8 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -1690,13 +1410,6 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
// Document 1 matches puppy and Document 2 matches (kitten AND (cat))
SearchSpecProto search_spec;
// TODO(b/208654892) decide how we want to handle queries of the form foo(...)
@@ -1706,7 +1419,7 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -1727,18 +1440,8 @@ TEST_P(QueryProcessorTest, ExcludeTerm) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that they'll bump the
// last_added_document_id, which will give us the proper exclusion results
@@ -1764,13 +1467,6 @@ TEST_P(QueryProcessorTest, ExcludeTerm) {
AddTokenToIndex(document_id2, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("-hello");
search_spec.set_term_match_type(term_match_type);
@@ -1778,8 +1474,8 @@ TEST_P(QueryProcessorTest, ExcludeTerm) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
// We don't have the section mask to indicate which section "world"
// came from. It doesn't matter which section it was in since the query doesn't
@@ -1795,18 +1491,8 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that they'll bump the
// last_added_document_id, which will give us the proper exclusion results
@@ -1831,13 +1517,6 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) {
AddTokenToIndex(document_id2, section_id, term_match_type, "world"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("-foo");
search_spec.set_term_match_type(term_match_type);
@@ -1845,8 +1524,8 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
// Descending order of valid DocumentIds
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()),
@@ -1861,18 +1540,8 @@ TEST_P(QueryProcessorTest, ExcludeAnd) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that they'll bump the
// last_added_document_id, which will give us the proper exclusion results
@@ -1905,13 +1574,6 @@ TEST_P(QueryProcessorTest, ExcludeAnd) {
ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
{
SearchSpecProto search_spec;
search_spec.set_query("-dog -cat");
@@ -1920,7 +1582,7 @@ TEST_P(QueryProcessorTest, ExcludeAnd) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// The query is interpreted as "exclude all documents that have animal,
@@ -1939,7 +1601,7 @@ TEST_P(QueryProcessorTest, ExcludeAnd) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// The query is interpreted as "exclude all documents that have animal,
@@ -1957,18 +1619,8 @@ TEST_P(QueryProcessorTest, ExcludeOr) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that they'll bump the
// last_added_document_id, which will give us the proper exclusion results
@@ -2001,13 +1653,6 @@ TEST_P(QueryProcessorTest, ExcludeOr) {
ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
{
SearchSpecProto search_spec;
search_spec.set_query("-animal OR -cat");
@@ -2016,7 +1661,7 @@ TEST_P(QueryProcessorTest, ExcludeOr) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// We don't have a section mask indicating which sections in this document
@@ -2036,7 +1681,7 @@ TEST_P(QueryProcessorTest, ExcludeOr) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -2056,18 +1701,8 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// just inserting the documents so that the DocHitInfoIterators will see
// that the document exists and not filter out the DocumentId as deleted.
@@ -2110,12 +1745,6 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
// OR gets precedence over AND, so this is parsed as (animal AND (puppy OR
// kitten))
SearchSpecProto search_spec;
@@ -2125,8 +1754,8 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
// Since need_hit_term_frequency is false, the expected term frequencies
// should all be 0.
Hit::TermFrequencyArray exp_term_frequencies{0};
@@ -2175,18 +1804,8 @@ TEST_P(QueryProcessorTest, DeletedFilter) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -2220,13 +1839,6 @@ TEST_P(QueryProcessorTest, DeletedFilter) {
ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("animal");
search_spec.set_term_match_type(term_match_type);
@@ -2234,7 +1846,7 @@ TEST_P(QueryProcessorTest, DeletedFilter) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -2252,18 +1864,8 @@ TEST_P(QueryProcessorTest, NamespaceFilter) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// namespaces populated.
@@ -2296,13 +1898,6 @@ TEST_P(QueryProcessorTest, NamespaceFilter) {
ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("animal");
search_spec.set_term_match_type(term_match_type);
@@ -2311,7 +1906,7 @@ TEST_P(QueryProcessorTest, NamespaceFilter) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -2331,18 +1926,8 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) {
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.AddType(SchemaTypeConfigBuilder().SetType("message"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2371,13 +1956,6 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) {
AddTokenToIndex(document_id2, section_id, term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
search_spec.set_query("animal");
search_spec.set_term_match_type(term_match_type);
@@ -2386,7 +1964,7 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -2411,18 +1989,8 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) {
.Build();
// First and only indexed property, so it gets a section_id of 0
int subject_section_id = 0;
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2440,13 +2008,6 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) {
"animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>'
search_spec.set_query("subject:animal");
@@ -2455,7 +2016,7 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Descending order of valid DocumentIds
@@ -2496,18 +2057,8 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) {
// alphabetically.
int email_foo_section_id = 1;
int message_foo_section_id = 0;
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2535,13 +2086,6 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>'
search_spec.set_query("foo:animal");
@@ -2550,7 +2094,7 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Ordered by descending DocumentId, so message comes first since it was
@@ -2582,18 +2126,8 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) {
.Build();
int email_foo_section_id = 0;
int message_foo_section_id = 0;
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2621,13 +2155,6 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>', but only look
// within documents of email schema
@@ -2638,7 +2165,7 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Shouldn't include the message document since we're only looking at email
@@ -2686,18 +2213,8 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) {
TOKENIZER_PLAIN)
.SetCardinality(CARDINALITY_OPTIONAL)))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2715,13 +2232,6 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>', but only look
// within documents of email schema
@@ -2731,7 +2241,7 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Even though the section id is the same, we should be able to tell that it
@@ -2763,18 +2273,8 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) {
.Build();
int email_foo_section_id = 0;
int message_foo_section_id = 0;
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2804,13 +2304,6 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>', but only look
// within documents of email schema
@@ -2820,7 +2313,7 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Even though the section id is the same, we should be able to tell that it
@@ -2839,18 +2332,8 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2868,13 +2351,6 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>', but only look
// within documents of email schema
@@ -2884,7 +2360,7 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Even though the section id is the same, we should be able to tell that it
@@ -2909,18 +2385,8 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) {
.SetDataType(TYPE_STRING)
.SetCardinality(CARDINALITY_OPTIONAL)))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -2938,13 +2404,6 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) {
term_match_type, "animal"),
IsOk());
- // Perform query
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>', but only look
// within documents of email schema
@@ -2954,7 +2413,7 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Even though the section id is the same, we should be able to tell that it
@@ -2982,18 +2441,8 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) {
.Build();
int email_foo_section_id = 0;
int message_foo_section_id = 0;
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- DocumentStore::CreateResult create_result,
- DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_,
- schema_store_.get()));
- document_store_ = std::move(create_result.document_store);
-
// These documents don't actually match to the tokens in the index. We're
// inserting the documents to get the appropriate number of documents and
// schema types populated.
@@ -3024,12 +2473,6 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) {
term_match_type, "animal"),
IsOk());
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
-
SearchSpecProto search_spec;
// Create a section filter '<section name>:<query term>'
search_spec.set_query("cat OR foo:animal");
@@ -3038,7 +2481,7 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(
+ query_processor_->ParseSearch(
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
// Ordered by descending DocumentId, so message comes first since it was
@@ -3060,10 +2503,6 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
// Arbitrary value, just has to be less than the document's creation
@@ -3096,10 +2535,10 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) {
// Perform query
ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
+ std::unique_ptr<QueryProcessor> local_query_processor,
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ document_store_.get(), schema_store_.get()));
SearchSpecProto search_spec;
search_spec.set_query("hello");
@@ -3108,8 +2547,8 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ local_query_processor->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
DocHitInfo expectedDocHitInfo(document_id);
expectedDocHitInfo.UpdateSection(/*section_id=*/0);
@@ -3122,10 +2561,6 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) {
SchemaProto schema = SchemaBuilder()
.AddType(SchemaTypeConfigBuilder().SetType("email"))
.Build();
-
- ICING_ASSERT_OK_AND_ASSIGN(
- schema_store_,
- SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
// Arbitrary value, just has to be greater than the document's creation
@@ -3158,10 +2593,10 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) {
// Perform query
ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<QueryProcessor> query_processor,
- QueryProcessor::Create(index_.get(), language_segmenter_.get(),
- normalizer_.get(), document_store_.get(),
- schema_store_.get()));
+ std::unique_ptr<QueryProcessor> local_query_processor,
+ QueryProcessor::Create(index_.get(), numeric_index_.get(),
+ language_segmenter_.get(), normalizer_.get(),
+ document_store_.get(), schema_store_.get()));
SearchSpecProto search_spec;
search_spec.set_query("hello");
@@ -3170,21 +2605,117 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) {
ICING_ASSERT_OK_AND_ASSIGN(
QueryResults results,
- query_processor->ParseSearch(search_spec,
- ScoringSpecProto::RankingStrategy::NONE));
+ local_query_processor->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty());
}
+TEST_P(QueryProcessorTest, NumericFilter) {
+ if (GetParam() !=
+ SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) {
+ GTEST_SKIP() << "Numeric filter is only supported in advanced query.";
+ }
+
+ // Create the schema and document store
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder()
+ .SetType("transaction")
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("price")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName("cost")
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ // SectionIds are assigned alphabetically
+ SectionId cost_section_id = 0;
+ SectionId price_section_id = 1;
+ ASSERT_THAT(schema_store_->SetSchema(schema), IsOk());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_one_id,
+ document_store_->Put(DocumentBuilder()
+ .SetKey("namespace", "1")
+ .SetSchema("transaction")
+ .AddInt64Property("price", 10)
+ .Build()));
+ ICING_ASSERT_OK(
+ AddToNumericIndex(document_one_id, "price", price_section_id, 10));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_two_id,
+ document_store_->Put(DocumentBuilder()
+ .SetKey("namespace", "2")
+ .SetSchema("transaction")
+ .AddInt64Property("price", 25)
+ .Build()));
+ ICING_ASSERT_OK(
+ AddToNumericIndex(document_two_id, "price", price_section_id, 25));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_three_id,
+ document_store_->Put(DocumentBuilder()
+ .SetKey("namespace", "3")
+ .SetSchema("transaction")
+ .AddInt64Property("cost", 2)
+ .Build()));
+ ICING_ASSERT_OK(
+ AddToNumericIndex(document_three_id, "cost", cost_section_id, 2));
+
+ SearchSpecProto search_spec;
+ search_spec.set_query("price < 20");
+ search_spec.set_search_type(GetParam());
+ ICING_ASSERT_OK_AND_ASSIGN(
+ QueryResults results,
+ query_processor_->ParseSearch(search_spec,
+ ScoringSpecProto::RankingStrategy::NONE));
+ EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()),
+ ElementsAre(EqualsDocHitInfo(
+ document_one_id, std::vector<SectionId>{price_section_id})));
+
+ search_spec.set_query("price == 25");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ results, query_processor_->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
+ EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()),
+ ElementsAre(EqualsDocHitInfo(
+ document_two_id, std::vector<SectionId>{price_section_id})));
+
+ search_spec.set_query("cost > 2");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ results, query_processor_->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
+ EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty());
+
+ search_spec.set_query("cost >= 2");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ results, query_processor_->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
+ EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()),
+ ElementsAre(EqualsDocHitInfo(
+ document_three_id, std::vector<SectionId>{cost_section_id})));
+
+ search_spec.set_query("price <= 25");
+ ICING_ASSERT_OK_AND_ASSIGN(
+ results, query_processor_->ParseSearch(
+ search_spec, ScoringSpecProto::RankingStrategy::NONE));
+ EXPECT_THAT(
+ GetDocHitInfos(results.root_iterator.get()),
+ ElementsAre(EqualsDocHitInfo(document_two_id,
+ std::vector<SectionId>{price_section_id}),
+ EqualsDocHitInfo(document_one_id,
+ std::vector<SectionId>{price_section_id})));
+}
+
INSTANTIATE_TEST_SUITE_P(
QueryProcessorTest, QueryProcessorTest,
-#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
testing::Values(
SearchSpecProto::SearchType::ICING_RAW_QUERY,
SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY));
-#else // !ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
- testing::Values(SearchSpecProto::SearchType::ICING_RAW_QUERY));
-#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY
} // namespace
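
Note: the query-processor_test.cc hunks above fold the per-test SchemaStore, DocumentStore, and QueryProcessor construction into the fixture members schema_store_, document_store_, and query_processor_, and QueryProcessor::Create() now takes a numeric index alongside the term index. The new NumericFilter test exercises the advanced-query comparators (<, <=, ==, >=, >) against int64 properties indexed with NUMERIC_MATCH_RANGE. A minimal sketch of the updated factory call, mirroring the signature visible in the TTL hunks above; the concrete type behind numeric_index_ is an assumption, since its declaration is not shown in these hunks:

  // Sketch only: the signature follows the hunks above; the concrete type
  // of numeric_index_ is assumed, not confirmed by this patch.
  ICING_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<QueryProcessor> query_processor,
      QueryProcessor::Create(index_.get(), numeric_index_.get(),
                             language_segmenter_.get(), normalizer_.get(),
                             document_store_.get(), schema_store_.get()));
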
diff --git a/icing/result/result-retriever-v2.cc b/icing/result/result-retriever-v2.cc
index c10c4e8..a51a8e6 100644
--- a/icing/result/result-retriever-v2.cc
+++ b/icing/result/result-retriever-v2.cc
@@ -113,16 +113,17 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage(
int32_t num_total_bytes = 0;
while (results.size() < result_state.num_per_page() &&
!result_state.scored_document_hits_ranker->empty()) {
- ScoredDocumentHit next_best_document_hit =
+ JoinedScoredDocumentHit next_best_document_hit =
result_state.scored_document_hits_ranker->PopNext();
if (group_result_limiter_->ShouldBeRemoved(
- next_best_document_hit, result_state.namespace_group_id_map(),
- doc_store_, result_state.group_result_limits)) {
+ next_best_document_hit.parent_scored_document_hit(),
+ result_state.namespace_group_id_map(), doc_store_,
+ result_state.group_result_limits)) {
continue;
}
- libtextclassifier3::StatusOr<DocumentProto> document_or =
- doc_store_.Get(next_best_document_hit.document_id());
+ libtextclassifier3::StatusOr<DocumentProto> document_or = doc_store_.Get(
+ next_best_document_hit.parent_scored_document_hit().document_id());
if (!document_or.ok()) {
// Skip the document if getting errors.
ICING_LOG(WARNING) << "Fail to fetch document from document store: "
@@ -147,14 +148,38 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage(
SnippetProto snippet_proto = snippet_retriever_->RetrieveSnippet(
snippet_context.query_terms, snippet_context.match_type,
snippet_context.snippet_spec, document,
- next_best_document_hit.hit_section_id_mask());
+ next_best_document_hit.parent_scored_document_hit()
+ .hit_section_id_mask());
*result.mutable_snippet() = std::move(snippet_proto);
++num_results_with_snippets;
}
// Add the document, itself.
*result.mutable_document() = std::move(document);
- result.set_score(next_best_document_hit.score());
+ result.set_score(next_best_document_hit.final_score());
+
+ // Retrieve child documents
+ for (const ScoredDocumentHit& child_scored_document_hit :
+ next_best_document_hit.child_scored_document_hits()) {
+ libtextclassifier3::StatusOr<DocumentProto> child_document_or =
+ doc_store_.Get(child_scored_document_hit.document_id());
+ if (!child_document_or.ok()) {
+        // Skip the child document if retrieval fails.
+        ICING_LOG(WARNING)
+            << "Failed to fetch child document from document store: "
+            << child_document_or.status().error_message();
+ continue;
+ }
+
+ DocumentProto child_document = std::move(child_document_or).ValueOrDie();
+ // TODO(b/256022027): apply projection and add snippet for child doc
+
+ SearchResultProto::ResultProto* child_result =
+ result.add_joined_results();
+ *child_result->mutable_document() = std::move(child_document);
+ child_result->set_score(child_scored_document_hit.score());
+ }
+
size_t result_bytes = result.ByteSizeLong();
results.push_back(std::move(result));
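
Note: RetrieveNextPage() now pops a JoinedScoredDocumentHit instead of a plain ScoredDocumentHit: the group limiter and the document fetch go through the parent hit, the result score comes from final_score(), and each child hit is fetched and appended under joined_results. An illustrative shape for JoinedScoredDocumentHit, inferred from the accessors used in the hunk above; the real class definition is not part of this diff and may differ:

  // Illustrative only: members inferred from final_score(),
  // parent_scored_document_hit(), and child_scored_document_hits() as used
  // in the hunk above.
  class JoinedScoredDocumentHit {
   public:
    double final_score() const { return final_score_; }
    const ScoredDocumentHit& parent_scored_document_hit() const {
      return parent_scored_document_hit_;
    }
    const std::vector<ScoredDocumentHit>& child_scored_document_hits() const {
      return child_scored_document_hits_;
    }

   private:
    double final_score_;
    ScoredDocumentHit parent_scored_document_hit_;
    std::vector<ScoredDocumentHit> child_scored_document_hits_;
  };
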
diff --git a/icing/result/result-retriever-v2_group-result-limiter_test.cc b/icing/result/result-retriever-v2_group-result-limiter_test.cc
index e0a6c79..02c9bc6 100644
--- a/icing/result/result-retriever-v2_group-result-limiter_test.cc
+++ b/icing/result/result-retriever-v2_group-result-limiter_test.cc
@@ -162,7 +162,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 2 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -219,7 +220,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 2 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -292,7 +294,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 4 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -376,7 +379,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 4 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -433,7 +437,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 2 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -534,7 +539,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 6 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -592,7 +598,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 2 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
@@ -687,7 +694,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest,
// Creates a ResultState with 5 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true), result_spec,
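
Note: every construction site in this and the following test files now spells out PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>, which implies the ranker was templated on its element type so that the join path can rank JoinedScoredDocumentHit with the same implementation. A hedged sketch of the declaration these call sites imply; the header itself is not included in this diff:

  // Hedged sketch: inferred from the call sites; the real header may differ.
  template <typename ScoredDataType>
  class PriorityQueueScoredDocumentHitsRanker
      : public ScoredDocumentHitsRanker {
   public:
    explicit PriorityQueueScoredDocumentHitsRanker(
        std::vector<ScoredDataType>&& scored_data_vec, bool is_descending);

    // The base interface's PopNext() returns JoinedScoredDocumentHit after
    // this change (see the result-retriever-v2.cc hunk above), so plain
    // ScoredDocumentHits are converted on the way out.
    JoinedScoredDocumentHit PopNext() override;
  };
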
diff --git a/icing/result/result-retriever-v2_projection_test.cc b/icing/result/result-retriever-v2_projection_test.cc
index ec67caa..d093d1f 100644
--- a/icing/result/result-retriever-v2_projection_test.cc
+++ b/icing/result/result-retriever-v2_projection_test.cc
@@ -222,7 +222,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionTopLevelLeadNodeFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -317,7 +318,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionNestedLeafNodeFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -423,7 +425,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionIntermediateNodeFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -533,7 +536,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleNestedFieldPaths) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -626,7 +630,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionEmptyFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -702,7 +707,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionInvalidFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -779,7 +785,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionValidAndInvalidFieldPath) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -858,7 +865,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesNoWildcards) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -941,7 +949,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesWildcard) {
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -1028,7 +1037,8 @@ TEST_F(ResultRetrieverV2ProjectionTest,
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -1124,7 +1134,8 @@ TEST_F(ResultRetrieverV2ProjectionTest,
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -1224,7 +1235,8 @@ TEST_F(ResultRetrieverV2ProjectionTest,
// 4. Create ResultState with custom ResultSpec.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/SectionRestrictQueryTermsMap{},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
diff --git a/icing/result/result-retriever-v2_snippet_test.cc b/icing/result/result-retriever-v2_snippet_test.cc
index 9384d6b..6123bf4 100644
--- a/icing/result/result-retriever-v2_snippet_test.cc
+++ b/icing/result/result-retriever-v2_snippet_test.cc
@@ -37,12 +37,12 @@
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
-#include "icing/testing/snippet-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
#include "icing/transform/normalizer-factory.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/snippet-helpers.h"
#include "unicode/uloc.h"
namespace icing {
@@ -225,7 +225,8 @@ TEST_F(ResultRetrieverV2SnippetTest,
language_segmenter_.get(), normalizer_.get()));
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true),
@@ -267,7 +268,8 @@ TEST_F(ResultRetrieverV2SnippetTest, SimpleSnippeted) {
*result_spec.mutable_snippet_spec() = CreateSnippetSpec();
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{{"", {"foo", "bar"}}},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -368,7 +370,8 @@ TEST_F(ResultRetrieverV2SnippetTest, OnlyOneDocumentSnippeted) {
*result_spec.mutable_snippet_spec() = std::move(snippet_spec);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{{"", {"foo", "bar"}}},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -437,7 +440,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetAllResults) {
*result_spec.mutable_snippet_spec() = std::move(snippet_spec);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{{"", {"foo", "bar"}}},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -483,7 +487,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetSomeResults) {
*result_spec.mutable_snippet_spec() = std::move(snippet_spec);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{{"", {"foo", "bar"}}},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -534,7 +539,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldNotSnippetAnyResults) {
*result_spec.mutable_snippet_spec() = std::move(snippet_spec);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{{"", {"foo", "bar"}}},
CreateSearchSpec(TermMatchType::EXACT_ONLY),
diff --git a/icing/result/result-retriever-v2_test.cc b/icing/result/result-retriever-v2_test.cc
index 0fb2ba0..6171688 100644
--- a/icing/result/result-retriever-v2_test.cc
+++ b/icing/result/result-retriever-v2_test.cc
@@ -289,7 +289,8 @@ TEST_F(ResultRetrieverV2Test, ShouldRetrieveSimpleResults) {
result5.set_score(1);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
CreateScoringSpec(/*is_descending_order=*/true),
@@ -366,7 +367,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreNonInternalErrors) {
result2.set_score(4);
ResultStateV2 result_state1(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -383,7 +385,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreNonInternalErrors) {
{document_id1, hit_section_id_mask, /*score=*/12},
{document_id2, hit_section_id_mask, /*score=*/4}};
ResultStateV2 result_state2(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -432,7 +435,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreInternalErrors) {
result1.set_score(0);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -479,7 +483,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateResultState) {
language_segmenter_.get(), normalizer_.get()));
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -550,7 +555,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) {
{document_id2, hit_section_id_mask, /*score=*/0}};
std::shared_ptr<ResultStateV2> result_state1 =
std::make_shared<ResultStateV2>(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1),
/*is_descending=*/true),
/*query_terms=*/SectionRestrictQueryTermsMap{},
@@ -576,7 +582,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) {
{document_id5, hit_section_id_mask, /*score=*/0}};
std::shared_ptr<ResultStateV2> result_state2 =
std::make_shared<ResultStateV2>(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2),
/*is_descending=*/true),
/*query_terms=*/SectionRestrictQueryTermsMap{},
@@ -662,7 +669,8 @@ TEST_F(ResultRetrieverV2Test, ShouldLimitNumTotalBytesPerPage) {
ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
result_spec.set_num_total_bytes_per_page_threshold(result1.ByteSizeLong());
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -723,7 +731,8 @@ TEST_F(ResultRetrieverV2Test,
ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
result_spec.set_num_total_bytes_per_page_threshold(threshold);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -783,7 +792,8 @@ TEST_F(ResultRetrieverV2Test,
ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2);
result_spec.set_num_total_bytes_per_page_threshold(threshold);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
diff --git a/icing/result/result-retriever_test.cc b/icing/result/result-retriever_test.cc
index e0b4875..044e0f2 100644
--- a/icing/result/result-retriever_test.cc
+++ b/icing/result/result-retriever_test.cc
@@ -36,12 +36,12 @@
#include "icing/testing/common-matchers.h"
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
-#include "icing/testing/snippet-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
#include "icing/transform/normalizer-factory.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/snippet-helpers.h"
#include "unicode/uloc.h"
namespace icing {
diff --git a/icing/result/result-state-manager_test.cc b/icing/result/result-state-manager_test.cc
index 7025c63..e7acc31 100644
--- a/icing/result/result-state-manager_test.cc
+++ b/icing/result/result-state-manager_test.cc
@@ -183,9 +183,9 @@ TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageOnePage) {
{document_id1, kSectionIdMaskNone, /*score=*/1},
{document_id2, kSectionIdMaskNone, /*score=*/1},
{document_id3, kSectionIdMaskNone, /*score=*/1}};
- std::unique_ptr<ScoredDocumentHitsRanker> ranker =
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
- std::move(scored_document_hits), /*is_descending=*/true);
+ std::unique_ptr<ScoredDocumentHitsRanker> ranker = std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
+ std::move(scored_document_hits), /*is_descending=*/true);
ResultStateManager result_state_manager(
/*max_total_hits=*/std::numeric_limits<int>::max(), document_store(),
@@ -228,9 +228,9 @@ TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageMultiplePages) {
{document_id3, kSectionIdMaskNone, /*score=*/1},
{document_id4, kSectionIdMaskNone, /*score=*/1},
{document_id5, kSectionIdMaskNone, /*score=*/1}};
- std::unique_ptr<ScoredDocumentHitsRanker> ranker =
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
- std::move(scored_document_hits), /*is_descending=*/true);
+ std::unique_ptr<ScoredDocumentHitsRanker> ranker = std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
+ std::move(scored_document_hits), /*is_descending=*/true);
ResultStateManager result_state_manager(
/*max_total_hits=*/std::numeric_limits<int>::max(), document_store(),
@@ -299,7 +299,8 @@ TEST_F(ResultStateManagerTest, EmptyRankerShouldReturnEmptyFirstPage) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -333,7 +334,8 @@ TEST_F(ResultStateManagerTest, ShouldAllowEmptyFirstPage) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), result_spec, document_store(),
@@ -373,7 +375,8 @@ TEST_F(ResultStateManagerTest, ShouldAllowEmptyLastPage) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), result_spec, document_store(),
@@ -417,7 +420,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
query_terms, search_spec, scoring_spec, result_spec, document_store(),
result_retriever()));
@@ -428,7 +432,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
query_terms, search_spec, scoring_spec, result_spec, document_store(),
result_retriever()));
@@ -462,7 +467,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -474,7 +480,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -516,7 +523,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -561,7 +569,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -570,7 +579,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -607,7 +617,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -616,7 +627,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -649,7 +661,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -658,7 +671,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -668,7 +682,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -713,7 +728,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -722,7 +738,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -731,7 +748,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -749,7 +767,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info4,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits4), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -801,7 +820,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -810,7 +830,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -819,7 +840,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -841,7 +863,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info4,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits4), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -850,7 +873,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info5,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits5), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -859,7 +883,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info6,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits6), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -920,7 +945,8 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -929,7 +955,8 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -938,7 +965,8 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -956,7 +984,8 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info4,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits4), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -970,7 +999,8 @@ TEST_F(
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info5,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits5), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1025,7 +1055,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1034,7 +1065,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1043,7 +1075,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1066,7 +1099,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info4,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits4), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1118,7 +1152,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1127,7 +1162,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1136,7 +1172,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1159,7 +1196,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info4,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits4), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1173,7 +1211,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info5,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits5), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1226,7 +1265,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1235,7 +1275,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1251,7 +1292,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info3,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits3), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1324,7 +1366,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1337,7 +1380,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info2,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1),
@@ -1374,7 +1418,8 @@ TEST_F(ResultStateManagerTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/true),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2),
diff --git a/icing/result/result-state-manager_thread-safety_test.cc b/icing/result/result-state-manager_thread-safety_test.cc
index 523f84a..0da37d8 100644
--- a/icing/result/result-state-manager_thread-safety_test.cc
+++ b/icing/result/result-state-manager_thread-safety_test.cc
@@ -160,7 +160,8 @@ TEST_F(ResultStateManagerThreadSafetyTest,
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_,
@@ -260,7 +261,8 @@ TEST_F(ResultStateManagerThreadSafetyTest, InvalidateResultStateWhileUsing) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits), /*is_descending=*/false),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_,
@@ -389,7 +391,8 @@ TEST_F(ResultStateManagerThreadSafetyTest, MultipleResultStates) {
ICING_ASSERT_OK_AND_ASSIGN(
PageResultInfo page_result_info1,
result_state_manager.CacheAndRetrieveFirstPage(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits_copy), /*is_descending=*/false),
/*query_terms=*/{}, SearchSpecProto::default_instance(),
CreateScoringSpec(), CreateResultSpec(kNumPerPage),
diff --git a/icing/result/result-state-v2_test.cc b/icing/result/result-state-v2_test.cc
index 7255958..f32546a 100644
--- a/icing/result/result-state-v2_test.cc
+++ b/icing/result/result-state-v2_test.cc
@@ -130,7 +130,8 @@ TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToSpecs) {
result_spec.set_num_total_bytes_per_page_threshold(4096);
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -152,7 +153,8 @@ TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToDefaultSpecs) {
Eq(std::numeric_limits<int32_t>::max()));
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -178,7 +180,8 @@ TEST_F(ResultStateV2Test, ShouldReturnSnippetContextAccordingToSpecs) {
query_terms_map.emplace("term1", std::unordered_set<std::string>());
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -217,7 +220,8 @@ TEST_F(ResultStateV2Test, NoSnippetingShouldReturnNull) {
query_terms_map.emplace("term1", std::unordered_set<std::string>());
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -253,7 +257,8 @@ TEST_F(ResultStateV2Test, ShouldConstructProjectionTreeMapAccordingToSpecs) {
wildcard_type_property_mask->add_paths("wild.card");
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -319,7 +324,8 @@ TEST_F(ResultStateV2Test,
document_store().GetNamespaceId("namespace3"));
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::vector<ScoredDocumentHit>(),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -352,7 +358,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHits) {
// Creates a ResultState with 5 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -384,7 +391,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHitsWhenDestructed) {
{
// Creates a ResultState with 5 ScoredDocumentHits.
ResultStateV2 result_state1(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits1),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -399,7 +407,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHitsWhenDestructed) {
{
// Creates another ResultState with 2 ScoredDocumentHits.
ResultStateV2 result_state2(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits2),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -428,7 +437,8 @@ TEST_F(ResultStateV2Test, ShouldNotUpdateNumTotalHitsWhenNotRegistered) {
// Creates a ResultState with 5 ScoredDocumentHits.
{
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
@@ -458,7 +468,8 @@ TEST_F(ResultStateV2Test, ShouldDecrementOriginalNumTotalHitsWhenReregister) {
// Creates a ResultState with 5 ScoredDocumentHits.
ResultStateV2 result_state(
- std::make_unique<PriorityQueueScoredDocumentHitsRanker>(
+ std::make_unique<
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>(
std::move(scored_document_hits),
/*is_descending=*/true),
/*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY),
diff --git a/icing/result/snippet-retriever_test.cc b/icing/result/snippet-retriever_test.cc
index 0940b51..80d00d5 100644
--- a/icing/result/snippet-retriever_test.cc
+++ b/icing/result/snippet-retriever_test.cc
@@ -38,7 +38,6 @@
#include "icing/testing/fake-clock.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/jni-test-helpers.h"
-#include "icing/testing/snippet-helpers.h"
#include "icing/testing/test-data.h"
#include "icing/testing/tmp-directory.h"
#include "icing/tokenization/language-segmenter-factory.h"
@@ -46,6 +45,7 @@
#include "icing/transform/map/map-normalizer.h"
#include "icing/transform/normalizer-factory.h"
#include "icing/transform/normalizer.h"
+#include "icing/util/snippet-helpers.h"
#include "unicode/uloc.h"
namespace icing {
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc
new file mode 100644
index 0000000..9d52fde
--- /dev/null
+++ b/icing/scoring/advanced_scoring/advanced-scorer.cc
@@ -0,0 +1,58 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+
+#include <memory>
+
+#include "icing/query/advanced_query_parser/lexer.h"
+#include "icing/query/advanced_query_parser/parser.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
+#include "icing/scoring/advanced_scoring/scoring-visitor.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>>
+AdvancedScorer::Create(const ScoringSpecProto& scoring_spec,
+ double default_score,
+ const DocumentStore* document_store,
+ const SchemaStore* schema_store) {
+ ICING_RETURN_ERROR_IF_NULL(document_store);
+ ICING_RETURN_ERROR_IF_NULL(schema_store);
+
+ Lexer lexer(scoring_spec.advanced_scoring_expression(),
+ Lexer::Language::SCORING);
+ ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens,
+ lexer.ExtractTokens());
+ Parser parser = Parser::Create(std::move(lexer_tokens));
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root,
+ parser.ConsumeScoring());
+
+ ScoringVisitor visitor(default_score);
+ tree_root->Accept(&visitor);
+
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression,
+ std::move(visitor).Expression());
+ if (expression->is_document_type()) {
+ return absl_ports::InvalidArgumentError(
+ "The root scoring expression will always be evaluated to a document, "
+ "but a number is expected.");
+ }
+ return std::unique_ptr<AdvancedScorer>(
+ new AdvancedScorer(std::move(expression), default_score));
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h
new file mode 100644
index 0000000..6557ba6
--- /dev/null
+++ b/icing/scoring/advanced_scoring/advanced-scorer.h
@@ -0,0 +1,66 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
+#define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
+
+#include <memory>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/schema/schema-store.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
+#include "icing/scoring/scorer.h"
+#include "icing/store/document-store.h"
+
+namespace icing {
+namespace lib {
+
+class AdvancedScorer : public Scorer {
+ public:
+ // Returns:
+ // An AdvancedScorer instance on success
+ // FAILED_PRECONDITION on any null pointer input
+ // INVALID_ARGUMENT if it fails to create an instance
+ static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create(
+ const ScoringSpecProto& scoring_spec, double default_score,
+ const DocumentStore* document_store, const SchemaStore* schema_store);
+
+ double GetScore(const DocHitInfo& hit_info,
+ const DocHitInfoIterator* query_it) override {
+ libtextclassifier3::StatusOr<double> result =
+ score_expression_->eval(hit_info, query_it);
+ if (!result.ok()) {
+ ICING_LOG(ERROR) << "Got an error when scoring a document:\n"
+ << result.status().error_message();
+ return default_score_;
+ }
+ return std::move(result).ValueOrDie();
+ }
+
+ private:
+ explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression,
+ double default_score)
+ : score_expression_(std::move(score_expression)),
+ default_score_(default_score) {}
+
+ std::unique_ptr<ScoreExpression> score_expression_;
+ double default_score_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_
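
A hedged usage sketch of the API above, assuming a DocumentStore, SchemaStore, and DocumentId already set up as in the tests below (ICING_ASSERT_OK_AND_ASSIGN is the test macro used throughout this patch):

    ScoringSpecProto spec;
    spec.set_rank_by(
        ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
    spec.set_advanced_scoring_expression("pow(2, 10) + max(1, 2, 3)");

    ICING_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<AdvancedScorer> scorer,
        AdvancedScorer::Create(spec, /*default_score=*/0,
                               document_store.get(), schema_store.get()));
    // Evaluation errors (e.g. division by zero) fall back to default_score.
    double score = scorer->GetScore(DocHitInfo(document_id));  // 1027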
diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
new file mode 100644
index 0000000..0d3a05c
--- /dev/null
+++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc
@@ -0,0 +1,404 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/advanced_scoring/advanced-scorer.h"
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <string_view>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/proto/document.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/proto/scoring.pb.h"
+#include "icing/proto/usage.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/scoring/scorer-factory.h"
+#include "icing/scoring/scorer.h"
+#include "icing/store/document-id.h"
+#include "icing/store/document-store.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+using ::testing::DoubleNear;
+using ::testing::Eq;
+
+class AdvancedScorerTest : public testing::Test {
+ protected:
+ AdvancedScorerTest()
+ : test_dir_(GetTestTempDir() + "/icing"),
+ doc_store_dir_(test_dir_ + "/doc_store"),
+ schema_store_dir_(test_dir_ + "/schema_store") {}
+
+ void SetUp() override {
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ filesystem_.CreateDirectoryRecursively(doc_store_dir_.c_str());
+ filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str());
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentStore::CreateResult create_result,
+ DocumentStore::Create(&filesystem_, doc_store_dir_, &fake_clock_,
+ schema_store_.get()));
+ document_store_ = std::move(create_result.document_store);
+
+ // Creates a simple email schema
+ SchemaProto test_email_schema =
+ SchemaBuilder()
+ .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty(
+ PropertyConfigBuilder()
+ .SetName("subject")
+ .SetDataTypeString(
+ TermMatchType::PREFIX,
+ StringIndexingConfig::TokenizerType::PLAIN)
+ .SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+
+ ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema));
+ }
+
+ void TearDown() override {
+ document_store_.reset();
+ schema_store_.reset();
+ filesystem_.DeleteDirectoryRecursively(test_dir_.c_str());
+ }
+
+ const std::string test_dir_;
+ const std::string doc_store_dir_;
+ const std::string schema_store_dir_;
+ Filesystem filesystem_;
+ std::unique_ptr<SchemaStore> schema_store_;
+ std::unique_ptr<DocumentStore> document_store_;
+ FakeClock fake_clock_;
+};
+
+constexpr double kEps = 0.0000000001;
+constexpr int kDefaultScore = 0;
+constexpr int64_t kDefaultCreationTimestampMs = 1571100001111;
+
+DocumentProto CreateDocument(
+ const std::string& name_space, const std::string& uri,
+ int score = kDefaultScore,
+ int64_t creation_timestamp_ms = kDefaultCreationTimestampMs) {
+ return DocumentBuilder()
+ .SetKey(name_space, uri)
+ .SetSchema("email")
+ .SetScore(score)
+ .SetCreationTimestampMs(creation_timestamp_ms)
+ .Build();
+}
+
+ScoringSpecProto CreateAdvancedScoringSpec(
+ const std::string& advanced_scoring_expression) {
+ ScoringSpecProto scoring_spec;
+ scoring_spec.set_rank_by(
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
+ scoring_spec.set_advanced_scoring_expression(advanced_scoring_expression);
+ return scoring_spec;
+}
+
+TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) {
+ // Empty scoring expression for advanced scoring
+ ScoringSpecProto scoring_spec;
+ scoring_spec.set_rank_by(
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION);
+ EXPECT_THAT(
+ scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ // Non-empty scoring expression for normal scoring
+ scoring_spec = ScoringSpecProto::default_instance();
+ scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
+ scoring_spec.set_advanced_scoring_expression("1");
+ EXPECT_THAT(
+ scorer_factory::Create(scoring_spec, /*default_score=*/10,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(AdvancedScorerTest, SimpleExpression) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(CreateDocument("namespace", "uri")));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("123"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+
+ DocHitInfo docHitInfo = DocHitInfo(document_id);
+
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123));
+}
+
+TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(CreateDocument("namespace", "uri")));
+ DocHitInfo docHitInfo = DocHitInfo(document_id);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5));
+}
+
+TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) {
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(CreateDocument("namespace", "uri")));
+ DocHitInfo docHitInfo = DocHitInfo(document_id);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("log(2.718281828459045)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("sin(3.141592653589793)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("cos(3.141592653589793)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(
+ CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"),
+ /*default_score=*/10, document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps));
+}
+
+// Should be a parsing error
+TEST_F(AdvancedScorerTest, EmptyExpression) {
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec(""),
+ /*default_score=*/10, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) {
+ const double default_score = 123;
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ DocumentId document_id,
+ document_store_->Put(CreateDocument("namespace", "uri")));
+ DocHitInfo docHitInfo = DocHitInfo(document_id);
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<Scorer> scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score,
+ document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer,
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score,
+ document_store_.get(), schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"),
+ default_score, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"),
+ default_score, document_store_.get(),
+ schema_store_.get()));
+ EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps));
+}
+
+// The following tests should trigger a type error while the visitor tries to
+// build a ScoreExpression object.
+TEST_F(AdvancedScorerTest, MathTypeError) {
+ const double default_score = 0;
+
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(
+ AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score,
+ document_store_.get(), schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+
+ EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"),
+ default_score, document_store_.get(),
+ schema_store_.get()),
+ StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
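
The log(10, 1000) expectation above exercises the two-argument form of log, which score-expression.cc (below) evaluates by change of base: log(base, value) = ln(value) / ln(base), so log(10, 1000) = ln(1000) / ln(10) = 3. A standalone check of that identity:

    #include <cassert>
    #include <cmath>

    int main() {
      // Argument 0 is the base, argument 1 is the value.
      double res = std::log(1000.0) / std::log(10.0);
      assert(std::fabs(res - 3.0) < 1e-9);
      return 0;
    }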
diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc
new file mode 100644
index 0000000..cd77046
--- /dev/null
+++ b/icing/scoring/advanced_scoring/score-expression.cc
@@ -0,0 +1,203 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/advanced_scoring/score-expression.h"
+
+namespace icing {
+namespace lib {
+
+libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
+OperatorScoreExpression::Create(
+ OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) {
+ if (children.empty()) {
+ return absl_ports::InvalidArgumentError(
+ "OperatorScoreExpression must have at least one argument.");
+ }
+ for (const auto& child : children) {
+ ICING_RETURN_ERROR_IF_NULL(child);
+ if (child->is_document_type()) {
+ return absl_ports::InvalidArgumentError(
+ "Operators are not supported for document type.");
+ }
+ }
+ if (op == OperatorType::kNegative) {
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError(
+ "Negative operator must have only 1 argument.");
+ }
+ }
+ return std::unique_ptr<OperatorScoreExpression>(
+ new OperatorScoreExpression(op, std::move(children)));
+}
+
+libtextclassifier3::StatusOr<double> OperatorScoreExpression::eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) {
+ // The Create factory guarantees that an operator will have at least one
+ // child.
+ ICING_ASSIGN_OR_RETURN(double res, children_.at(0)->eval(hit_info, query_it));
+
+ if (op_ == OperatorType::kNegative) {
+ return -res;
+ }
+
+ for (int i = 1; i < children_.size(); ++i) {
+ ICING_ASSIGN_OR_RETURN(double v, children_.at(i)->eval(hit_info, query_it));
+ switch (op_) {
+ case OperatorType::kPlus:
+ res += v;
+ break;
+ case OperatorType::kMinus:
+ res -= v;
+ break;
+ case OperatorType::kTimes:
+ res *= v;
+ break;
+ case OperatorType::kDiv:
+ res /= v;
+ break;
+ case OperatorType::kNegative:
+ return absl_ports::InternalError("Should never reach here.");
+ }
+ if (!std::isfinite(res)) {
+ return absl_ports::InvalidArgumentError(
+ "Got a non-finite value while evaluating operator score expression.");
+ }
+ }
+ return res;
+}
+
+const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType>
+ MathFunctionScoreExpression::kFunctionNames = {
+ {"log", FunctionType::kLog}, {"pow", FunctionType::kPow},
+ {"max", FunctionType::kMax}, {"min", FunctionType::kMin},
+ {"sqrt", FunctionType::kSqrt}, {"abs", FunctionType::kAbs},
+ {"sin", FunctionType::kSin}, {"cos", FunctionType::kCos},
+ {"tan", FunctionType::kTan}};
+
+libtextclassifier3::StatusOr<std::unique_ptr<MathFunctionScoreExpression>>
+MathFunctionScoreExpression::Create(
+ FunctionType function_type,
+ std::vector<std::unique_ptr<ScoreExpression>> children) {
+ if (children.empty()) {
+ return absl_ports::InvalidArgumentError(
+ "Math functions must have at least one argument.");
+ }
+ for (const auto& child : children) {
+ ICING_RETURN_ERROR_IF_NULL(child);
+ if (child->is_document_type()) {
+ return absl_ports::InvalidArgumentError(
+ "Math functions are not supported for document type.");
+ }
+ }
+ switch (function_type) {
+ case FunctionType::kLog:
+ if (children.size() != 1 && children.size() != 2) {
+ return absl_ports::InvalidArgumentError(
+ "log must have 1 or 2 arguments.");
+ }
+ break;
+ case FunctionType::kPow:
+ if (children.size() != 2) {
+ return absl_ports::InvalidArgumentError("pow must have 2 arguments.");
+ }
+ break;
+ case FunctionType::kSqrt:
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError("sqrt must have 1 argument.");
+ }
+ break;
+ case FunctionType::kAbs:
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError("abs must have 1 argument.");
+ }
+ break;
+ case FunctionType::kSin:
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError("sin must have 1 argument.");
+ }
+ break;
+ case FunctionType::kCos:
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError("cos must have 1 argument.");
+ }
+ break;
+ case FunctionType::kTan:
+ if (children.size() != 1) {
+ return absl_ports::InvalidArgumentError("tan must have 1 argument.");
+ }
+ break;
+ // max and min support variable length arguments
+ case FunctionType::kMax:
+ [[fallthrough]];
+ case FunctionType::kMin:
+ break;
+ }
+ return std::unique_ptr<MathFunctionScoreExpression>(
+ new MathFunctionScoreExpression(function_type, std::move(children)));
+}
+
+libtextclassifier3::StatusOr<double> MathFunctionScoreExpression::eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) {
+ std::vector<double> values;
+ for (const auto& child : children_) {
+ ICING_ASSIGN_OR_RETURN(double v, child->eval(hit_info, query_it));
+ values.push_back(v);
+ }
+
+ double res = 0;
+ switch (function_type_) {
+ case FunctionType::kLog:
+ if (values.size() == 1) {
+ res = log(values[0]);
+ } else {
+ // argument 0 is log base
+ // argument 1 is the value
+ res = log(values[1]) / log(values[0]);
+ }
+ break;
+ case FunctionType::kPow:
+ res = pow(values[0], values[1]);
+ break;
+ case FunctionType::kMax:
+ res = *std::max_element(values.begin(), values.end());
+ break;
+ case FunctionType::kMin:
+ res = *std::min_element(values.begin(), values.end());
+ break;
+ case FunctionType::kSqrt:
+ res = sqrt(values[0]);
+ break;
+ case FunctionType::kAbs:
+ res = std::abs(values[0]);
+ break;
+ case FunctionType::kSin:
+ res = sin(values[0]);
+ break;
+ case FunctionType::kCos:
+ res = cos(values[0]);
+ break;
+ case FunctionType::kTan:
+ res = tan(values[0]);
+ break;
+ }
+ if (!std::isfinite(res)) {
+ return absl_ports::InvalidArgumentError(
+ "Got a non-finite value while evaluating math function score "
+ "expression.");
+ }
+ return res;
+}
+
+} // namespace lib
+} // namespace icing
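
A minimal sketch of composing these factories directly, without going through the parser; constant expressions ignore the DocHitInfo argument, so any document id works here (kInvalidDocumentId is from icing/store/document-id.h):

    // Build (3 + 4) by hand and evaluate it.
    std::vector<std::unique_ptr<ScoreExpression>> children;
    children.push_back(ConstantScoreExpression::Create(3));
    children.push_back(ConstantScoreExpression::Create(4));
    ICING_ASSERT_OK_AND_ASSIGN(
        std::unique_ptr<OperatorScoreExpression> plus,
        OperatorScoreExpression::Create(
            OperatorScoreExpression::OperatorType::kPlus, std::move(children)));
    ICING_ASSERT_OK_AND_ASSIGN(
        double result,
        plus->eval(DocHitInfo(kInvalidDocumentId), /*query_it=*/nullptr));
    // result == 7. Passing a ThisExpression child instead would make Create
    // return INVALID_ARGUMENT, since operators reject document-type children.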
diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h
new file mode 100644
index 0000000..0e0c538
--- /dev/null
+++ b/icing/scoring/advanced_scoring/score-expression.h
@@ -0,0 +1,154 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
+#define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/index/hit/doc-hit-info.h"
+#include "icing/index/iterator/doc-hit-info-iterator.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+// TODO(b/261474063) Simplify every ScoreExpression node to
+// ConstantScoreExpression if its evaluation does not depend on a document.
+class ScoreExpression {
+ public:
+ virtual ~ScoreExpression() = default;
+
+ // Evaluates the score expression to a double for the current document.
+ //
+ // RETURNS:
+ // - The evaluated result as a double on success.
+ // - INVALID_ARGUMENT if a non-finite value is reached while evaluating the
+ // expression.
+ // - INTERNAL if there are inconsistencies.
+ virtual libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) = 0;
+
+ // Indicates whether the current expression is of document type.
+ virtual bool is_document_type() const { return false; }
+};
+
+class ThisExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<ThisExpression> Create() {
+ return std::unique_ptr<ThisExpression>(new ThisExpression());
+ }
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override {
+ return absl_ports::InternalError(
+ "Should never reach here to evaluate a document type as double. "
+ "There must be inconsistencies.");
+ }
+
+ bool is_document_type() const override { return true; }
+
+ private:
+ ThisExpression() = default;
+};
+
+class ConstantScoreExpression : public ScoreExpression {
+ public:
+ static std::unique_ptr<ConstantScoreExpression> Create(double c) {
+ return std::unique_ptr<ConstantScoreExpression>(
+ new ConstantScoreExpression(c));
+ }
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo&, const DocHitInfoIterator*) override {
+ return c_;
+ }
+
+ private:
+ explicit ConstantScoreExpression(double c) : c_(c) {}
+
+ double c_;
+};
+
+class OperatorScoreExpression : public ScoreExpression {
+ public:
+ enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv };
+
+ // RETURNS:
+ // - An OperatorScoreExpression instance on success.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
+ Create(OperatorType op,
+ std::vector<std::unique_ptr<ScoreExpression>> children);
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;
+
+ private:
+ explicit OperatorScoreExpression(
+ OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children)
+ : op_(op), children_(std::move(children)) {}
+
+ OperatorType op_;
+ std::vector<std::unique_ptr<ScoreExpression>> children_;
+};
+
+class MathFunctionScoreExpression : public ScoreExpression {
+ public:
+ enum class FunctionType {
+ kLog,
+ kPow,
+ kMax,
+ kMin,
+ kSqrt,
+ kAbs,
+ kSin,
+ kCos,
+ kTan
+ };
+
+ static const std::unordered_map<std::string, FunctionType> kFunctionNames;
+
+ // RETURNS:
+ // - A MathFunctionScoreExpression instance on success.
+ // - FAILED_PRECONDITION on any null pointer in children.
+ // - INVALID_ARGUMENT on type errors.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<MathFunctionScoreExpression>>
+ Create(FunctionType function_type,
+ std::vector<std::unique_ptr<ScoreExpression>> children);
+
+ libtextclassifier3::StatusOr<double> eval(
+ const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override;
+
+ private:
+ explicit MathFunctionScoreExpression(
+ FunctionType function_type,
+ std::vector<std::unique_ptr<ScoreExpression>> children)
+ : function_type_(function_type), children_(std::move(children)) {}
+
+ FunctionType function_type_;
+ std::vector<std::unique_ptr<ScoreExpression>> children_;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc
new file mode 100644
index 0000000..7737213
--- /dev/null
+++ b/icing/scoring/advanced_scoring/scoring-visitor.cc
@@ -0,0 +1,159 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/advanced_scoring/scoring-visitor.h"
+
+#include "icing/absl_ports/str_cat.h"
+
+namespace icing {
+namespace lib {
+
+void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) {
+ pending_error_ = absl_ports::InternalError(
+ "FunctionNameNode should be handled in VisitFunction!");
+}
+
+void ScoringVisitor::VisitString(const StringNode* node) {
+ pending_error_ =
+ absl_ports::InvalidArgumentError("Scoring does not support String!");
+}
+
+void ScoringVisitor::VisitText(const TextNode* node) {
+ pending_error_ =
+ absl_ports::InternalError("TextNode should be handled in VisitMember!");
+}
+
+void ScoringVisitor::VisitMember(const MemberNode* node) {
+ std::string value;
+ if (node->children().size() == 1) {
+ // If a member has only one child, then it can be a numeric literal,
+ // or "this" if the member is a reference to a member function.
+ value = node->children()[0]->value();
+ if (value == "this") {
+ stack.push_back(ThisExpression::Create());
+ return;
+ }
+ } else if (node->children().size() == 2) {
+ // If a member has two children, then it can only represent a floating-point
+ // number, so we join them with "." to build the numeric literal.
+ value = absl_ports::StrCat(node->children()[0]->value(), ".",
+ node->children()[1]->value());
+ } else {
+ pending_error_ = absl_ports::InvalidArgumentError(
+ "MemberNode must have 1 or 2 children.");
+ return;
+ }
+ char* end;
+ double number = std::strtod(value.c_str(), &end);
+ if (end != value.c_str() + value.length()) {
+ // While it would be doable to support property references in the scoring
+ // grammar, we currently don't have an efficient way to support such a
+ // lookup (we'd have to read each document). As such, it's simpler to just
+ // restrict the scoring language to not include properties.
+ pending_error_ = absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Expect a numeric literal, but got ", value));
+ return;
+ }
+ stack.push_back(ConstantScoreExpression::Create(number));
+}
+
+void ScoringVisitor::VisitFunction(const FunctionNode* node) {
+ std::vector<std::unique_ptr<ScoreExpression>> children;
+ for (const auto& arg : node->args()) {
+ arg->Accept(this);
+ if (has_pending_error()) {
+ return;
+ }
+ children.push_back(pop_stack());
+ }
+ const std::string& function_name = node->function_name()->value();
+ libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression =
+ absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unknown function: ", function_name));
+
+ // Math functions
+ if (MathFunctionScoreExpression::kFunctionNames.find(function_name) !=
+ MathFunctionScoreExpression::kFunctionNames.end()) {
+ expression = MathFunctionScoreExpression::Create(
+ MathFunctionScoreExpression::kFunctionNames.at(function_name),
+ std::move(children));
+ }
+
+ if (!expression.ok()) {
+ pending_error_ = expression.status();
+ return;
+ }
+ stack.push_back(std::move(expression).ValueOrDie());
+}
+
+void ScoringVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) {
+ if (node->operator_text() != "MINUS") {
+ pending_error_ = absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unknown unary operator: ", node->operator_text()));
+ return;
+ }
+ node->child()->Accept(this);
+ if (has_pending_error()) {
+ return;
+ }
+ std::vector<std::unique_ptr<ScoreExpression>> children;
+ children.push_back(pop_stack());
+
+ libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
+ expression = OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kNegative,
+ std::move(children));
+ if (!expression.ok()) {
+ pending_error_ = expression.status();
+ return;
+ }
+ stack.push_back(std::move(expression).ValueOrDie());
+}
+
+void ScoringVisitor::VisitNaryOperator(const NaryOperatorNode* node) {
+ std::vector<std::unique_ptr<ScoreExpression>> children;
+ for (const auto& arg : node->children()) {
+ arg->Accept(this);
+ if (has_pending_error()) {
+ return;
+ }
+ children.push_back(pop_stack());
+ }
+
+ libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>>
+ expression = absl_ports::InvalidArgumentError(
+ absl_ports::StrCat("Unknown Nary operator: ", node->operator_text()));
+
+ if (node->operator_text() == "PLUS") {
+ expression = OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kPlus, std::move(children));
+ } else if (node->operator_text() == "MINUS") {
+ expression = OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kMinus, std::move(children));
+ } else if (node->operator_text() == "TIMES") {
+ expression = OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kTimes, std::move(children));
+ } else if (node->operator_text() == "DIV") {
+ expression = OperatorScoreExpression::Create(
+ OperatorScoreExpression::OperatorType::kDiv, std::move(children));
+ }
+ if (!expression.ok()) {
+ pending_error_ = expression.status();
+ return;
+ }
+ stack.push_back(std::move(expression).ValueOrDie());
+}
+
+} // namespace lib
+} // namespace icing
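
The numeric-literal check in VisitMember hinges on strtod consuming the entire joined string ("3" and "14" become "3.14"); a self-contained restatement of that predicate:

    #include <cstdlib>
    #include <string>

    // Returns true iff `value` is a complete numeric literal, mirroring the
    // end-pointer check in ScoringVisitor::VisitMember.
    bool IsNumericLiteral(const std::string& value, double* number) {
      char* end = nullptr;
      *number = std::strtod(value.c_str(), &end);
      return end == value.c_str() + value.length();
    }

    // IsNumericLiteral("3.14", &d) -> true
    // IsNumericLiteral("this", &d) -> false (strtod stops at 't')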
diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h
new file mode 100644
index 0000000..47a03fd
--- /dev/null
+++ b/icing/scoring/advanced_scoring/scoring-visitor.h
@@ -0,0 +1,77 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
+#define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/legacy/core/icing-string-util.h"
+#include "icing/proto/scoring.pb.h"
+#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
+#include "icing/scoring/advanced_scoring/score-expression.h"
+
+namespace icing {
+namespace lib {
+
+class ScoringVisitor : public AbstractSyntaxTreeVisitor {
+ public:
+ explicit ScoringVisitor(double default_score)
+ : default_score_(default_score) {}
+
+ void VisitFunctionName(const FunctionNameNode* node) override;
+ void VisitString(const StringNode* node) override;
+ void VisitText(const TextNode* node) override;
+ void VisitMember(const MemberNode* node) override;
+ void VisitFunction(const FunctionNode* node) override;
+ void VisitUnaryOperator(const UnaryOperatorNode* node) override;
+ void VisitNaryOperator(const NaryOperatorNode* node) override;
+
+ // RETURNS:
+ // - A ScoreExpression instance able to evaluate the expression on success.
+ // - INVALID_ARGUMENT if the AST does not conform to supported expressions,
+ // such as type errors.
+ // - INTERNAL if there are inconsistencies.
+ libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
+ Expression() && {
+ if (has_pending_error()) {
+ return pending_error_;
+ }
+ if (stack.size() != 1) {
+ return absl_ports::InternalError(IcingStringUtil::StringPrintf(
+ "Expect to get only one result from "
+ "ScoringVisitor, but got %zu. There must be inconsistencies.",
+ stack.size()));
+ }
+ return std::move(stack[0]);
+ }
+
+ private:
+ bool has_pending_error() const { return !pending_error_.ok(); }
+
+ std::unique_ptr<ScoreExpression> pop_stack() {
+ std::unique_ptr<ScoreExpression> result = std::move(stack.back());
+ stack.pop_back();
+ return result;
+ }
+
+ double default_score_;
+ libtextclassifier3::Status pending_error_;
+ std::vector<std::unique_ptr<ScoreExpression>> stack;
+};
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
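
Note that Expression() is rvalue-reference-qualified: it hands out the single expression left on the stack, so callers must invoke it on an rvalue, as AdvancedScorer::Create does. The calling convention, in short:

    ScoringVisitor visitor(/*default_score=*/0);
    tree_root->Accept(&visitor);
    // Compiles only on an rvalue; the visitor gives up its result.
    libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression =
        std::move(visitor).Expression();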
diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.cc b/icing/scoring/priority-queue-scored-document-hits-ranker.cc
deleted file mode 100644
index 691b088..0000000
--- a/icing/scoring/priority-queue-scored-document-hits-ranker.cc
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright (C) 2022 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "icing/scoring/priority-queue-scored-document-hits-ranker.h"
-
-#include <queue>
-#include <vector>
-
-#include "icing/scoring/scored-document-hit.h"
-
-namespace icing {
-namespace lib {
-
-PriorityQueueScoredDocumentHitsRanker::PriorityQueueScoredDocumentHitsRanker(
- std::vector<ScoredDocumentHit>&& scored_document_hits, bool is_descending)
- : comparator_(/*is_ascending=*/!is_descending),
- scored_document_hits_pq_(comparator_, std::move(scored_document_hits)) {}
-
-ScoredDocumentHit PriorityQueueScoredDocumentHitsRanker::PopNext() {
- ScoredDocumentHit ret = scored_document_hits_pq_.top();
- scored_document_hits_pq_.pop();
- return ret;
-}
-
-void PriorityQueueScoredDocumentHitsRanker::TruncateHitsTo(int new_size) {
- if (new_size < 0 || scored_document_hits_pq_.size() <= new_size) {
- return;
- }
-
- // Copying the best new_size results.
- std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>,
- Comparator>
- new_pq(comparator_);
- for (int i = 0; i < new_size; ++i) {
- new_pq.push(scored_document_hits_pq_.top());
- scored_document_hits_pq_.pop();
- }
- scored_document_hits_pq_ = std::move(new_pq);
-}
-
-} // namespace lib
-} // namespace icing
diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.h b/icing/scoring/priority-queue-scored-document-hits-ranker.h
index 3ef2ae5..0798d7d 100644
--- a/icing/scoring/priority-queue-scored-document-hits-ranker.h
+++ b/icing/scoring/priority-queue-scored-document-hits-ranker.h
@@ -26,21 +26,37 @@ namespace lib {
// ScoredDocumentHitsRanker interface implementation, based on
// std::priority_queue. We can get next top hit in O(lgN) time.
+template <typename ScoredDataType,
+ typename Converter = typename ScoredDataType::Converter>
class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker {
public:
explicit PriorityQueueScoredDocumentHitsRanker(
- std::vector<ScoredDocumentHit>&& scored_document_hits,
- bool is_descending = true);
+ std::vector<ScoredDataType>&& scored_data_vec, bool is_descending = true);
~PriorityQueueScoredDocumentHitsRanker() override = default;
- ScoredDocumentHit PopNext() override;
+ // Note: the ranker may store either ScoredDocumentHit or
+ // JoinedScoredDocumentHit, so scored_data_pq_ is templated.
+ // - JoinedScoredDocumentHit is a superset of ScoredDocumentHit, so we unify
+ // the return type of PopNext around the superset type
+ // JoinedScoredDocumentHit, and rankers storing ScoredDocumentHit convert
+ // it to JoinedScoredDocumentHit before returning. This keeps the
+ // implementation simple, especially for ResultRetriever, which now only
+ // needs to deal with a single return format.
+ // - JoinedScoredDocumentHit is ~2x the size of ScoredDocumentHit. Since the
+ // ranker (which contains a priority queue of data) is cached in
+ // ResultState, storing the scored hits directly in JoinedScoredDocumentHit
+ // format would double the memory usage. Therefore, we keep the flexibility
+ // to store ScoredDocumentHit or any other type of data, but require
+ // PopNext to convert it to JoinedScoredDocumentHit.
+ JoinedScoredDocumentHit PopNext() override;
void TruncateHitsTo(int new_size) override;
- int size() const override { return scored_document_hits_pq_.size(); }
+ int size() const override { return scored_data_pq_.size(); }
- bool empty() const override { return scored_document_hits_pq_.empty(); }
+ bool empty() const override { return scored_data_pq_.empty(); }
private:
// Comparator for std::priority_queue. Since std::priority_queue is a max heap
@@ -49,8 +65,8 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker {
public:
explicit Comparator(bool is_ascending) : is_ascending_(is_ascending) {}
- bool operator()(const ScoredDocumentHit& lhs,
- const ScoredDocumentHit& rhs) const {
+ bool operator()(const ScoredDataType& lhs,
+ const ScoredDataType& rhs) const {
// STL comparator requirement: equal MUST return false.
// If writing `return is_ascending_ == !(lhs < rhs)`:
// - When lhs == rhs, !(lhs < rhs) is true
@@ -68,11 +84,44 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker {
Comparator comparator_;
// Use priority queue to get top K hits in O(KlgN) time.
- std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>,
- Comparator>
- scored_document_hits_pq_;
+ std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator>
+ scored_data_pq_;
+
+ Converter converter_;
};
+template <typename ScoredDataType, typename Converter>
+PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>::
+ PriorityQueueScoredDocumentHitsRanker(
+ std::vector<ScoredDataType>&& scored_data_vec, bool is_descending)
+ : comparator_(/*is_ascending=*/!is_descending),
+ scored_data_pq_(comparator_, std::move(scored_data_vec)) {}
+
+template <typename ScoredDataType, typename Converter>
+JoinedScoredDocumentHit
+PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>::PopNext() {
+ ScoredDataType next_scored_data = scored_data_pq_.top();
+ scored_data_pq_.pop();
+ return converter_(std::move(next_scored_data));
+}
+
+template <typename ScoredDataType, typename Converter>
+void PriorityQueueScoredDocumentHitsRanker<
+ ScoredDataType, Converter>::TruncateHitsTo(int new_size) {
+ if (new_size < 0 || scored_data_pq_.size() <= new_size) {
+ return;
+ }
+
+ // Copying the best new_size results.
+ std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator>
+ new_pq(comparator_);
+ for (int i = 0; i < new_size; ++i) {
+ new_pq.push(scored_data_pq_.top());
+ scored_data_pq_.pop();
+ }
+ scored_data_pq_ = std::move(new_pq);
+}
+
} // namespace lib
} // namespace icing
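
With the template in place, instantiation sites name the stored type explicitly, and PopNext always hands back the superset type; a sketch of the contract, assuming ScoredDocumentHit::Converter is default-constructible, as the default template argument implies:

    std::vector<ScoredDocumentHit> hits;  // populated by the scorer
    PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
        std::move(hits), /*is_descending=*/true);
    while (!ranker.empty()) {
      JoinedScoredDocumentHit next = ranker.PopNext();  // converted on pop
    }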
diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc
index a575eaf..ace2350 100644
--- a/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc
+++ b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc
@@ -31,9 +31,19 @@ using ::testing::Eq;
using ::testing::IsEmpty;
using ::testing::SizeIs;
-std::vector<ScoredDocumentHit> PopAll(
- PriorityQueueScoredDocumentHitsRanker& ranker) {
- std::vector<ScoredDocumentHit> hits;
+class Converter {
+ public:
+ JoinedScoredDocumentHit operator()(ScoredDocumentHit hit) const {
+ return converter_(std::move(hit));
+ }
+
+ private:
+ ScoredDocumentHit::Converter converter_;
+} converter;
+
+std::vector<JoinedScoredDocumentHit> PopAll(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>& ranker) {
+ std::vector<JoinedScoredDocumentHit> hits;
while (!ranker.empty()) {
hits.push_back(ranker.PopNext());
}
@@ -48,7 +58,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldGetCorrectSizeAndEmpty) {
ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2},
/*is_descending=*/true);
EXPECT_THAT(ranker.size(), Eq(3));
@@ -79,18 +89,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInDescendingOrder) {
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/true);
EXPECT_THAT(ranker, SizeIs(5));
- std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker);
- EXPECT_THAT(scored_document_hits,
- ElementsAre(EqualsScoredDocumentHit(scored_hit_4),
- EqualsScoredDocumentHit(scored_hit_3),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_1),
- EqualsScoredDocumentHit(scored_hit_0)));
+ std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker);
+ EXPECT_THAT(
+ scored_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_3)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_1)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_0))));
}
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) {
@@ -105,18 +116,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) {
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/false);
EXPECT_THAT(ranker, SizeIs(5));
- std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker);
- EXPECT_THAT(scored_document_hits,
- ElementsAre(EqualsScoredDocumentHit(scored_hit_0),
- EqualsScoredDocumentHit(scored_hit_1),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_3),
- EqualsScoredDocumentHit(scored_hit_4)));
+ std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker);
+ EXPECT_THAT(
+ scored_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_0)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_1)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_3)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_4))));
}
TEST(PriorityQueueScoredDocumentHitsRankerTest,
@@ -132,28 +144,30 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest,
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_2, scored_hit_4, scored_hit_1, scored_hit_0, scored_hit_2,
scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/true);
EXPECT_THAT(ranker, SizeIs(8));
- std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker);
- EXPECT_THAT(scored_document_hits,
- ElementsAre(EqualsScoredDocumentHit(scored_hit_4),
- EqualsScoredDocumentHit(scored_hit_4),
- EqualsScoredDocumentHit(scored_hit_3),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_1),
- EqualsScoredDocumentHit(scored_hit_0)));
+ std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker);
+ EXPECT_THAT(
+ scored_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_4)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_3)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_1)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_0))));
}
TEST(PriorityQueueScoredDocumentHitsRankerTest,
ShouldRankEmptyScoredDocumentHits) {
- PriorityQueueScoredDocumentHitsRanker ranker(/*scored_document_hits=*/{},
- /*is_descending=*/true);
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
+ /*scored_document_hits=*/{},
+ /*is_descending=*/true);
EXPECT_THAT(ranker, IsEmpty());
}
@@ -169,18 +183,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToNewSize) {
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/true);
ASSERT_THAT(ranker, SizeIs(5));
ranker.TruncateHitsTo(/*new_size=*/3);
EXPECT_THAT(ranker, SizeIs(3));
- std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker);
- EXPECT_THAT(scored_document_hits,
- ElementsAre(EqualsScoredDocumentHit(scored_hit_4),
- EqualsScoredDocumentHit(scored_hit_3),
- EqualsScoredDocumentHit(scored_hit_2)));
+ std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker);
+ EXPECT_THAT(
+ scored_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_3)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2))));
}
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) {
@@ -195,7 +210,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) {
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/true);
ASSERT_THAT(ranker, SizeIs(5));
@@ -216,7 +231,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) {
ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone,
/*score=*/1);
- PriorityQueueScoredDocumentHitsRanker ranker(
+ PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker(
{scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3},
/*is_descending=*/true);
ASSERT_THAT(ranker, SizeIs(Eq(5)));
@@ -224,13 +239,14 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) {
ranker.TruncateHitsTo(/*new_size=*/-1);
EXPECT_THAT(ranker, SizeIs(Eq(5)));
// Contents are not affected.
- std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker);
- EXPECT_THAT(scored_document_hits,
- ElementsAre(EqualsScoredDocumentHit(scored_hit_4),
- EqualsScoredDocumentHit(scored_hit_3),
- EqualsScoredDocumentHit(scored_hit_2),
- EqualsScoredDocumentHit(scored_hit_1),
- EqualsScoredDocumentHit(scored_hit_0)));
+ std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker);
+ EXPECT_THAT(
+ scored_document_hits,
+ ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_3)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_2)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_1)),
+ EqualsJoinedScoredDocumentHit(converter(scored_hit_0))));
}
} // namespace
diff --git a/icing/scoring/scored-document-hit.cc b/icing/scoring/scored-document-hit.cc
new file mode 100644
index 0000000..f519a16
--- /dev/null
+++ b/icing/scoring/scored-document-hit.cc
@@ -0,0 +1,30 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/scored-document-hit.h"
+
+namespace icing {
+namespace lib {
+
+JoinedScoredDocumentHit ScoredDocumentHit::Converter::operator()(
+ ScoredDocumentHit&& scored_doc_hit) const {
+ double final_score = scored_doc_hit.score();
+ return JoinedScoredDocumentHit(
+ final_score,
+ /*parent_scored_document_hit=*/std::move(scored_doc_hit),
+ /*child_scored_document_hits=*/{});
+}
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/scoring/scored-document-hit.h b/icing/scoring/scored-document-hit.h
index 96ca6aa..141049e 100644
--- a/icing/scoring/scored-document-hit.h
+++ b/icing/scoring/scored-document-hit.h
@@ -24,11 +24,19 @@
namespace icing {
namespace lib {
+class JoinedScoredDocumentHit;
+
// A data class containing information about the document, hit sections, and a
// score. The score is calculated against both the document and the hit
// sections.
class ScoredDocumentHit {
public:
+ class Converter {
+ public:
+ JoinedScoredDocumentHit operator()(
+ ScoredDocumentHit&& scored_doc_hit) const;
+ };
+
ScoredDocumentHit(DocumentId document_id, SectionIdMask hit_section_id_mask,
double score)
: document_id_(document_id),
@@ -85,6 +93,65 @@ class ScoredDocumentHitComparator {
bool is_descending_;
};
+// A data class containing information about a composite document after joining,
+// including final score, parent ScoredDocumentHit, and a vector of all child
+// ScoredDocumentHits. The final score is calculated by the strategy specified
+// in the join spec/rank strategy. It could be an aggregated score, the raw
+// parent doc score, or anything else.
+//
+// ScoredDocumentHitsRanker may store ScoredDocumentHit or
+// JoinedScoredDocumentHit.
+// - We could have created a common virtual base class for them and had
+//   ScoredDocumentHitsRanker use the abstract type.
+// - However, Icing lib caches ScoredDocumentHitsRanker (which contains a list
+//   of (Joined)ScoredDocumentHits) in ResultState. Inheriting from the
+//   virtual class adds 8 bytes for the vtable pointer to each class, which
+//   increases memory usage by 40% and 15% respectively.
+// - Also, since JoinedScoredDocumentHit is a superset of ScoredDocumentHit,
+//   we avoid the common virtual class and instead implement a conversion
+//   function (original type -> JoinedScoredDocumentHit) for each class, so
+//   ScoredDocumentHitsRanker::PopNext can return a common type (i.e.
+//   JoinedScoredDocumentHit).
+class JoinedScoredDocumentHit {
+ public:
+ class Converter {
+ public:
+ JoinedScoredDocumentHit operator()(
+ JoinedScoredDocumentHit&& scored_doc_hit) const {
+ return scored_doc_hit;
+ }
+ };
+
+ explicit JoinedScoredDocumentHit(
+ double final_score, ScoredDocumentHit&& parent_scored_document_hit,
+ std::vector<ScoredDocumentHit>&& child_scored_document_hits)
+ : final_score_(final_score),
+ parent_scored_document_hit_(std::move(parent_scored_document_hit)),
+ child_scored_document_hits_(std::move(child_scored_document_hits)) {}
+
+ bool operator<(const JoinedScoredDocumentHit& other) const {
+ if (final_score_ != other.final_score_) {
+ return final_score_ < other.final_score_;
+ }
+ return parent_scored_document_hit_ < other.parent_scored_document_hit_;
+ }
+
+ double final_score() const { return final_score_; }
+
+ const ScoredDocumentHit& parent_scored_document_hit() const {
+ return parent_scored_document_hit_;
+ }
+
+ const std::vector<ScoredDocumentHit>& child_scored_document_hits() const {
+ return child_scored_document_hits_;
+ }
+
+ private:
+ double final_score_;
+ ScoredDocumentHit parent_scored_document_hit_;
+ std::vector<ScoredDocumentHit> child_scored_document_hits_;
+} __attribute__((packed));
+
} // namespace lib
} // namespace icing
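To make the converter approach described above concrete, here is a hedged sketch of how code templated on the hit type can normalize either class to JoinedScoredDocumentHit without a virtual base; only the nested Converter contract comes from this patch, the helper itself is illustrative:

// ScoredDocumentHitType is ScoredDocumentHit or JoinedScoredDocumentHit; both
// nest a Converter whose operator() yields a JoinedScoredDocumentHit, so a
// caller gets one uniform type and neither class pays for a vtable pointer.
template <typename ScoredDocumentHitType>
JoinedScoredDocumentHit ToJoined(ScoredDocumentHitType hit) {
  typename ScoredDocumentHitType::Converter converter;
  return converter(std::move(hit));
}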
diff --git a/icing/scoring/scored-document-hit_test.cc b/icing/scoring/scored-document-hit_test.cc
new file mode 100644
index 0000000..cb9703b
--- /dev/null
+++ b/icing/scoring/scored-document-hit_test.cc
@@ -0,0 +1,77 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/scoring/scored-document-hit.h"
+
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/testing/common-matchers.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::DoubleEq;
+using ::testing::IsEmpty;
+
+TEST(ScoredDocumentHitTest, ScoredDocumentHitConvertToJoinedScoredDocumentHit) {
+ ScoredDocumentHit::Converter converter;
+
+ double score = 2.0;
+ ScoredDocumentHit scored_document_hit(/*document_id=*/5,
+ /*section_id_mask=*/49, score);
+
+ JoinedScoredDocumentHit joined_scored_document_hit =
+ converter(ScoredDocumentHit(scored_document_hit));
+ EXPECT_THAT(joined_scored_document_hit.final_score(), DoubleEq(score));
+ EXPECT_THAT(joined_scored_document_hit.parent_scored_document_hit(),
+ EqualsScoredDocumentHit(scored_document_hit));
+ EXPECT_THAT(joined_scored_document_hit.child_scored_document_hits(),
+ IsEmpty());
+}
+
+TEST(ScoredDocumentHitTest,
+ JoinedScoredDocumentHitConvertToJoinedScoredDocumentHit) {
+ JoinedScoredDocumentHit::Converter converter;
+
+ ScoredDocumentHit parent_scored_document_hit(/*document_id=*/5,
+ /*section_id_mask=*/49,
+ /*score=*/1.0);
+ std::vector<ScoredDocumentHit> child_scored_document_hits{
+ ScoredDocumentHit(/*document_id=*/1,
+ /*section_id_mask=*/1,
+ /*score=*/2.0),
+ ScoredDocumentHit(/*document_id=*/2,
+ /*section_id_mask=*/2,
+ /*score=*/3.0),
+ ScoredDocumentHit(/*document_id=*/3,
+ /*section_id_mask=*/3,
+ /*score=*/4.0)};
+
+ JoinedScoredDocumentHit joined_scored_document_hit(
+ /*final_score=*/12345.6789, std::move(parent_scored_document_hit),
+ std::move(child_scored_document_hits));
+ EXPECT_THAT(converter(JoinedScoredDocumentHit(joined_scored_document_hit)),
+ EqualsJoinedScoredDocumentHit(joined_scored_document_hit));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/scoring/scored-document-hits-ranker.h b/icing/scoring/scored-document-hits-ranker.h
index 0287452..9b76ce7 100644
--- a/icing/scoring/scored-document-hits-ranker.h
+++ b/icing/scoring/scored-document-hits-ranker.h
@@ -30,10 +30,19 @@ class ScoredDocumentHitsRanker {
public:
virtual ~ScoredDocumentHitsRanker() = default;
- // Pop the next top ScoredDocumentHit and return. It is undefined to call
- // PopNext on an empty ranker, so the caller should check if it is not empty
- // before calling.
- virtual ScoredDocumentHit PopNext() = 0;
+  // Pops the next top JoinedScoredDocumentHit and returns it. It is undefined
+  // to call PopNext on an empty ranker, so the caller should check that it is
+  // not empty before calling.
+ //
+  // Note: a ranker may store ScoredDocumentHit or JoinedScoredDocumentHit. We
+  // could templatize this interface, but since JoinedScoredDocumentHit is a
+  // superset of ScoredDocumentHit, we unify the return type of PopNext on the
+  // superset type JoinedScoredDocumentHit to keep things simple; rankers
+  // storing ScoredDocumentHit convert it to JoinedScoredDocumentHit before
+  // returning. This simplifies the implementation, especially for
+  // ResultRetriever, which now only needs to deal with a single return
+  // format.
+ virtual JoinedScoredDocumentHit PopNext() = 0;
// Truncates the remaining ScoredDocumentHits to the given size. The best
// ScoredDocumentHits (according to the ranking policy) should be kept.
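Since PopNext on an empty ranker is undefined, callers drain it behind an empty() check, as the tests above do with PopAll. A hedged drain helper over this interface (the function name is illustrative):

// Pops every remaining hit in ranked order; guards each PopNext with empty().
std::vector<JoinedScoredDocumentHit> DrainRanker(
    ScoredDocumentHitsRanker& ranker) {
  std::vector<JoinedScoredDocumentHit> hits;
  while (!ranker.empty()) {
    hits.push_back(ranker.PopNext());
  }
  return hits;
}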
diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer-factory.cc
index 14a004e..600fe6b 100644
--- a/icing/scoring/scorer.cc
+++ b/icing/scoring/scorer-factory.cc
@@ -12,16 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "icing/scoring/scorer.h"
+#include "icing/scoring/scorer-factory.h"
#include <memory>
+#include <unordered_map>
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/scoring.pb.h"
+#include "icing/scoring/advanced_scoring/advanced-scorer.h"
#include "icing/scoring/bm25f-calculator.h"
+#include "icing/scoring/scorer.h"
#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
@@ -156,12 +159,22 @@ class NoScorer : public Scorer {
double default_score_;
};
-libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
+namespace scorer_factory {
+
+libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
const ScoringSpecProto& scoring_spec, double default_score,
const DocumentStore* document_store, const SchemaStore* schema_store) {
ICING_RETURN_ERROR_IF_NULL(document_store);
ICING_RETURN_ERROR_IF_NULL(schema_store);
+ if (!scoring_spec.advanced_scoring_expression().empty() &&
+ scoring_spec.rank_by() !=
+ ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION) {
+ return absl_ports::InvalidArgumentError(
+ "Advanced scoring is not enabled, but the advanced scoring expression "
+ "is not empty!");
+ }
+
switch (scoring_spec.rank_by()) {
case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE:
return std::make_unique<DocumentScoreScorer>(document_store,
@@ -192,6 +205,13 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP:
return std::make_unique<UsageScorer>(
document_store, scoring_spec.rank_by(), default_score);
+ case ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION:
+ if (scoring_spec.advanced_scoring_expression().empty()) {
+ return absl_ports::InvalidArgumentError(
+ "Advanced scoring is enabled, but the expression is empty!");
+ }
+ return AdvancedScorer::Create(scoring_spec, default_score, document_store,
+ schema_store);
case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE:
ICING_LOG(WARNING)
<< "JOIN_AGGREGATE_SCORE not implemented, falling back to NoScorer";
@@ -201,5 +221,7 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create(
}
}
+} // namespace scorer_factory
+
} // namespace lib
} // namespace icing
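The two InvalidArgument checks above enforce that an advanced scoring expression and the ADVANCED_SCORING_EXPRESSION strategy come as a pair. A hedged helper expressing the same invariant in isolation (the name is hypothetical):

// The spec is consistent iff an expression is present exactly when rank_by
// selects ADVANCED_SCORING_EXPRESSION; scorer_factory::Create rejects both
// mismatches with INVALID_ARGUMENT.
bool IsAdvancedScoringSpecConsistent(const ScoringSpecProto& spec) {
  bool has_expression = !spec.advanced_scoring_expression().empty();
  bool rank_by_advanced =
      spec.rank_by() ==
      ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION;
  return has_expression == rank_by_advanced;
}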
diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h
new file mode 100644
index 0000000..8c19c75
--- /dev/null
+++ b/icing/scoring/scorer-factory.h
@@ -0,0 +1,46 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_SCORING_SCORER_FACTORY_H_
+#define ICING_SCORING_SCORER_FACTORY_H_
+
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/scoring/scorer.h"
+#include "icing/store/document-store.h"
+
+namespace icing {
+namespace lib {
+
+namespace scorer_factory {
+
+// Factory function to create a Scorer which does not take ownership of any
+// input components (DocumentStore), and all pointers must refer to valid
+// objects that outlive the created Scorer instance. The default score will be
+// returned only when the scorer fails to find or calculate a score for the
+// document.
+//
+// Returns:
+// A Scorer on success
+// FAILED_PRECONDITION on any null pointer input
+// INVALID_ARGUMENT if fails to create an instance
+libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
+ const ScoringSpecProto& scoring_spec, double default_score,
+ const DocumentStore* document_store, const SchemaStore* schema_store);
+
+} // namespace scorer_factory
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_SCORING_SCORER_FACTORY_H_
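A short usage sketch of the free factory function that replaces Scorer::Create; it assumes it runs inside a function returning a Status, and that document_store and schema_store are valid pointers outliving the scorer:

ScoringSpecProto spec;
spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE);
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<Scorer> scorer,
    scorer_factory::Create(spec, /*default_score=*/0, document_store,
                           schema_store));
// The default score is returned when no score can be found or computed.
DocHitInfo doc_hit_info(/*document_id_in=*/0);
double score = scorer->GetScore(doc_hit_info);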
diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h
index abdd5ca..ec48502 100644
--- a/icing/scoring/scorer.h
+++ b/icing/scoring/scorer.h
@@ -16,13 +16,11 @@
#define ICING_SCORING_SCORER_H_
#include <memory>
+#include <unordered_map>
-#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/proto/scoring.pb.h"
-#include "icing/store/document-id.h"
-#include "icing/store/document-store.h"
namespace icing {
namespace lib {
@@ -32,20 +30,6 @@ class Scorer {
public:
virtual ~Scorer() = default;
- // Factory function to create a Scorer which does not take ownership of any
- // input components (DocumentStore), and all pointers must refer to valid
- // objects that outlive the created Scorer instance. The default score will be
- // returned only when the scorer fails to find or calculate a score for the
- // document.
- //
- // Returns:
- // A Scorer on success
- // FAILED_PRECONDITION on any null pointer input
- // INVALID_ARGUMENT if fails to create an instance
- static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create(
- const ScoringSpecProto& scoring_spec, double default_score,
- const DocumentStore* document_store, const SchemaStore* schema_store);
-
// Returns a non-negative score of a document. The score can be a
// document-associated score which comes from the DocumentProto directly, an
// accumulated score, a relevance score, or even an inferred score. If it
diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc
index 5432cde..7bbb8b7 100644
--- a/icing/scoring/scorer_test.cc
+++ b/icing/scoring/scorer_test.cc
@@ -28,6 +28,7 @@
#include "icing/proto/usage.pb.h"
#include "icing/schema-builder.h"
#include "icing/schema/schema-store.h"
+#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/section-weights.h"
#include "icing/store/document-id.h"
#include "icing/store/document-store.h"
@@ -128,28 +129,29 @@ ScoringSpecProto CreateScoringSpecForRankingStrategy(
TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) {
EXPECT_THAT(
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/0, /*document_store=*/nullptr,
- schema_store()),
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, /*document_store=*/nullptr, schema_store()),
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) {
- EXPECT_THAT(
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/0, document_store(),
- /*schema_store=*/nullptr),
- StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
+ EXPECT_THAT(scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, document_store(),
+ /*schema_store=*/nullptr),
+ StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION));
}
TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/10, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
// Non existent document id
DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1);
@@ -171,9 +173,10 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/10, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -204,9 +207,10 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/10, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
@@ -233,9 +237,10 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/10, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0));
@@ -256,9 +261,10 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5));
@@ -281,9 +287,10 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) {
document_store()->Put(test_document));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE),
- /*default_score=*/10, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE),
+ /*default_score=*/10, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10));
@@ -313,9 +320,10 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) {
document_store()->Put(test_document2));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo1 = DocHitInfo(document_id1);
DocHitInfo docHitInfo2 = DocHitInfo(document_id2);
@@ -340,19 +348,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -384,19 +395,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -428,19 +442,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT),
+ /*default_score=*/0, document_store(), schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -472,22 +489,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -535,22 +555,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -598,22 +621,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
// Create 3 scorers for 3 different usage types.
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer2,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE2_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE2_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer3,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE3_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE3_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0));
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0));
@@ -649,9 +675,10 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) {
TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::NONE),
- /*default_score=*/3, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE),
+ /*default_score=*/3, document_store(),
+ schema_store()));
DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0);
DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1);
@@ -661,10 +688,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) {
EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3));
ICING_ASSERT_OK_AND_ASSIGN(
- scorer,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::NONE),
- /*default_score=*/111, document_store(), schema_store()));
+ scorer, scorer_factory::Create(
+ CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::NONE),
+ /*default_score=*/111, document_store(), schema_store()));
docHitInfo1 = DocHitInfo(/*document_id_in=*/4);
docHitInfo2 = DocHitInfo(/*document_id_in=*/5);
@@ -688,10 +715,11 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Scorer> scorer1,
- Scorer::Create(CreateScoringSpecForRankingStrategy(
- ScoringSpecProto::RankingStrategy::
- USAGE_TYPE1_LAST_USED_TIMESTAMP),
- /*default_score=*/0, document_store(), schema_store()));
+ scorer_factory::Create(CreateScoringSpecForRankingStrategy(
+ ScoringSpecProto::RankingStrategy::
+ USAGE_TYPE1_LAST_USED_TIMESTAMP),
+ /*default_score=*/0, document_store(),
+ schema_store()));
DocHitInfo docHitInfo = DocHitInfo(document_id);
// Create usage report for the maximum allowable timestamp.
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc
index d5e64d8..571a112 100644
--- a/icing/scoring/scoring-processor.cc
+++ b/icing/scoring/scoring-processor.cc
@@ -15,6 +15,7 @@
#include "icing/scoring/scoring-processor.h"
#include <memory>
+#include <unordered_map>
#include <utility>
#include <vector>
@@ -25,6 +26,7 @@
#include "icing/proto/scoring.pb.h"
#include "icing/scoring/ranker.h"
#include "icing/scoring/scored-document-hit.h"
+#include "icing/scoring/scorer-factory.h"
#include "icing/scoring/scorer.h"
#include "icing/store/document-store.h"
#include "icing/util/status-macros.h"
@@ -50,10 +52,11 @@ ScoringProcessor::Create(const ScoringSpecProto& scoring_spec,
ICING_ASSIGN_OR_RETURN(
std::unique_ptr<Scorer> scorer,
- Scorer::Create(scoring_spec,
- is_descending_order ? kDefaultScoreInDescendingOrder
- : kDefaultScoreInAscendingOrder,
- document_store, schema_store));
+ scorer_factory::Create(scoring_spec,
+ is_descending_order
+ ? kDefaultScoreInDescendingOrder
+ : kDefaultScoreInAscendingOrder,
+ document_store, schema_store));
// Using `new` to access a non-public constructor.
return std::unique_ptr<ScoringProcessor>(
new ScoringProcessor(std::move(scorer)));
diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc
index 9a33682..62599c8 100644
--- a/icing/store/document-store.cc
+++ b/icing/store/document-store.cc
@@ -1678,8 +1678,8 @@ DocumentStore::OptimizeInto(const std::string& new_directory,
}
TokenizedDocument tokenized_document(
std::move(tokenized_document_or).ValueOrDie());
- new_document_id_or =
- new_doc_store->Put(document_to_keep, tokenized_document.num_tokens());
+ new_document_id_or = new_doc_store->Put(
+ document_to_keep, tokenized_document.num_string_tokens());
} else {
// TODO(b/144458732): Implement a more robust version of
// TC_ASSIGN_OR_RETURN that can support error logging.
diff --git a/icing/store/dynamic-trie-key-mapper.h b/icing/store/dynamic-trie-key-mapper.h
index 35d2200..63e8488 100644
--- a/icing/store/dynamic-trie-key-mapper.h
+++ b/icing/store/dynamic-trie-key-mapper.h
@@ -60,12 +60,15 @@ class DynamicTrieKeyMapper : public KeyMapper<T, Formatter> {
Create(const Filesystem& filesystem, std::string_view base_dir,
int maximum_size_bytes);
- // Deletes all the files associated with the DynamicTrieKeyMapper. Returns
- // success or any encountered IO errors
+ // Deletes all the files associated with the DynamicTrieKeyMapper.
//
// base_dir : Base directory used to save all the files required to persist
// DynamicTrieKeyMapper. Should be the same as passed into
// Create().
+ //
+  // Returns:
+ // OK on success
+ // INTERNAL_ERROR on I/O error
static libtextclassifier3::Status Delete(const Filesystem& filesystem,
std::string_view base_dir);
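A hedged snippet exercising the contract documented above; Delete is safe to call even when no files exist yet (per the key-mapper tests), and base_dir must match the one passed to Create:

// Wipe the mapper's files, then start from an empty mapper.
ICING_RETURN_IF_ERROR(
    DynamicTrieKeyMapper<DocumentId>::Delete(filesystem, base_dir));
ICING_ASSIGN_OR_RETURN(
    std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
    DynamicTrieKeyMapper<DocumentId>::Create(filesystem, base_dir,
                                             kMaxDynamicTrieKeyMapperSize));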
diff --git a/icing/store/dynamic-trie-key-mapper_test.cc b/icing/store/dynamic-trie-key-mapper_test.cc
index add88bb..fd56170 100644
--- a/icing/store/dynamic-trie-key-mapper_test.cc
+++ b/icing/store/dynamic-trie-key-mapper_test.cc
@@ -14,21 +14,21 @@
#include "icing/store/dynamic-trie-key-mapper.h"
+#include <string>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
#include "icing/store/document-id.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/tmp-directory.h"
-using ::testing::_;
-using ::testing::HasSubstr;
-using ::testing::IsEmpty;
-using ::testing::Pair;
-using ::testing::UnorderedElementsAre;
-
namespace icing {
namespace lib {
+
namespace {
+
constexpr int kMaxDynamicTrieKeyMapperSize = 3 * 1024 * 1024; // 3 MiB
class DynamicTrieKeyMapperTest : public testing::Test {
@@ -43,168 +43,25 @@ class DynamicTrieKeyMapperTest : public testing::Test {
Filesystem filesystem_;
};
-std::unordered_map<std::string, DocumentId> GetAllKeyValuePairs(
- const DynamicTrieKeyMapper<DocumentId>* key_mapper) {
- std::unordered_map<std::string, DocumentId> ret;
-
- std::unique_ptr<typename KeyMapper<DocumentId>::Iterator> itr =
- key_mapper->GetIterator();
- while (itr->Advance()) {
- ret.emplace(itr->GetKey(), itr->GetValue());
- }
- return ret;
-}
-
TEST_F(DynamicTrieKeyMapperTest, InvalidBaseDir) {
- ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create(
- filesystem_, "/dev/null", kMaxDynamicTrieKeyMapperSize)
- .status()
- .error_message(),
- HasSubstr("Failed to create DynamicTrieKeyMapper"));
+ EXPECT_THAT(DynamicTrieKeyMapper<DocumentId>::Create(
+ filesystem_, "/dev/null", kMaxDynamicTrieKeyMapperSize),
+ StatusIs(libtextclassifier3::StatusCode::INTERNAL));
}
TEST_F(DynamicTrieKeyMapperTest, NegativeMaxKeyMapperSizeReturnsInternalError) {
- ASSERT_THAT(
+ EXPECT_THAT(
DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, -1),
StatusIs(libtextclassifier3::StatusCode::INTERNAL));
}
TEST_F(DynamicTrieKeyMapperTest, TooLargeMaxKeyMapperSizeReturnsInternalError) {
- ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create(
+ EXPECT_THAT(DynamicTrieKeyMapper<DocumentId>::Create(
filesystem_, base_dir_, std::numeric_limits<int>::max()),
StatusIs(libtextclassifier3::StatusCode::INTERNAL));
}
-TEST_F(DynamicTrieKeyMapperTest, CreateNewKeyMapper) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
- EXPECT_THAT(key_mapper->num_keys(), 0);
-}
-
-TEST_F(DynamicTrieKeyMapperTest, CanUpdateSameKeyMultipleTimes) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
-
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
- ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 50));
-
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
-
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 200));
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(200));
- EXPECT_THAT(key_mapper->num_keys(), 2);
-
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
- EXPECT_THAT(key_mapper->num_keys(), 2);
-}
-
-TEST_F(DynamicTrieKeyMapperTest, GetOrPutOk) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
-
- EXPECT_THAT(key_mapper->Get("foo"),
- StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
- EXPECT_THAT(key_mapper->GetOrPut("foo", 1), IsOkAndHolds(1));
- EXPECT_THAT(key_mapper->Get("foo"), IsOkAndHolds(1));
-}
-
-TEST_F(DynamicTrieKeyMapperTest, CanPersistToDiskRegularly) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
- // Can persist an empty DynamicTrieKeyMapper.
- ICING_EXPECT_OK(key_mapper->PersistToDisk());
- EXPECT_THAT(key_mapper->num_keys(), 0);
-
- // Can persist the smallest DynamicTrieKeyMapper.
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
- ICING_EXPECT_OK(key_mapper->PersistToDisk());
- EXPECT_THAT(key_mapper->num_keys(), 1);
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
-
- // Can continue to add keys after PersistToDisk().
- ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200));
- EXPECT_THAT(key_mapper->num_keys(), 2);
- EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200));
-
- // Can continue to update the same key after PersistToDisk().
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
- EXPECT_THAT(key_mapper->num_keys(), 2);
-}
-
-TEST_F(DynamicTrieKeyMapperTest, CanUseAcrossMultipleInstances) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
- ICING_EXPECT_OK(key_mapper->PersistToDisk());
-
- key_mapper.reset();
- ICING_ASSERT_OK_AND_ASSIGN(
- key_mapper, DynamicTrieKeyMapper<DocumentId>::Create(
- filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize));
- EXPECT_THAT(key_mapper->num_keys(), 1);
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
-
- // Can continue to read/write to the KeyMapper.
- ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200));
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
- EXPECT_THAT(key_mapper->num_keys(), 2);
- EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200));
- EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
-}
-
-TEST_F(DynamicTrieKeyMapperTest, CanDeleteAndRestartKeyMapping) {
- // Can delete even if there's nothing there
- ICING_EXPECT_OK(
- DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_));
-
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
- ICING_EXPECT_OK(key_mapper->PersistToDisk());
- ICING_EXPECT_OK(
- DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_));
-
- key_mapper.reset();
- ICING_ASSERT_OK_AND_ASSIGN(
- key_mapper, DynamicTrieKeyMapper<DocumentId>::Create(
- filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize));
- EXPECT_THAT(key_mapper->num_keys(), 0);
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
- EXPECT_THAT(key_mapper->num_keys(), 1);
-}
-
-TEST_F(DynamicTrieKeyMapperTest, Iterator) {
- ICING_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper,
- DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_,
- kMaxDynamicTrieKeyMapperSize));
- EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), IsEmpty());
-
- ICING_EXPECT_OK(key_mapper->Put("foo", /*value=*/1));
- ICING_EXPECT_OK(key_mapper->Put("bar", /*value=*/2));
- EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()),
- UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2)));
-
- ICING_EXPECT_OK(key_mapper->Put("baz", /*value=*/3));
- EXPECT_THAT(
- GetAllKeyValuePairs(key_mapper.get()),
- UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2), Pair("baz", 3)));
-}
-
} // namespace
+
} // namespace lib
} // namespace icing
diff --git a/icing/store/key-mapper_benchmark.cc b/icing/store/key-mapper_benchmark.cc
new file mode 100644
index 0000000..b649bc7
--- /dev/null
+++ b/icing/store/key-mapper_benchmark.cc
@@ -0,0 +1,316 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <random>
+#include <string>
+#include <unordered_map>
+
+#include "testing/base/public/benchmark.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/str_cat.h"
+#include "icing/file/destructible-directory.h"
+#include "icing/file/filesystem.h"
+#include "icing/store/dynamic-trie-key-mapper.h"
+#include "icing/store/key-mapper.h"
+#include "icing/store/persistent-hash-map-key-mapper.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/random-string.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::Eq;
+using ::testing::Not;
+
+class KeyMapperBenchmark {
+ public:
+ static constexpr int kKeyLength = 20;
+
+ explicit KeyMapperBenchmark()
+ : clock(std::make_unique<Clock>()),
+ base_dir(GetTestTempDir() + "/key_mapper_benchmark"),
+ random_engine(/*seed=*/12345) {}
+
+ std::string GenerateUniqueRandomKeyValuePair(int val,
+ std::string_view prefix = "") {
+ std::string rand_str = absl_ports::StrCat(
+ prefix, RandomString(kAlNumAlphabet, kKeyLength, &random_engine));
+ while (random_kvps_map.find(rand_str) != random_kvps_map.end()) {
+ rand_str = absl_ports::StrCat(
+ std::string(prefix),
+ RandomString(kAlNumAlphabet, kKeyLength, &random_engine));
+ }
+ std::pair<std::string, int> entry(rand_str, val);
+ random_kvps.push_back(entry);
+ random_kvps_map.insert(entry);
+ return rand_str;
+ }
+
+ template <typename UnknownKeyMapperType>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>> CreateKeyMapper(
+ int max_num_entries) {
+ return absl_ports::InvalidArgumentError("Unknown type");
+ }
+
+ template <>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>>
+ CreateKeyMapper<DynamicTrieKeyMapper<int>>(int max_num_entries) {
+ return DynamicTrieKeyMapper<int>::Create(
+ filesystem, base_dir,
+ /*maximum_size_bytes=*/128 * 1024 * 1024);
+ }
+
+ template <>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>>
+ CreateKeyMapper<PersistentHashMapKeyMapper<int>>(int max_num_entries) {
+ return PersistentHashMapKeyMapper<int>::Create(
+ filesystem, base_dir, max_num_entries,
+ /*average_kv_byte_size=*/kKeyLength + 1 + sizeof(int),
+ /*max_load_factor_percent=*/100);
+ }
+
+ std::unique_ptr<Clock> clock;
+
+ Filesystem filesystem;
+ std::string base_dir;
+
+ std::default_random_engine random_engine;
+ std::vector<std::pair<std::string, int>> random_kvps;
+ std::unordered_map<std::string, int> random_kvps_map;
+};
+
+// Benchmark the total time of putting num_keys (specified by Arg) unique random
+// key value pairs.
+template <typename KeyMapperType>
+void BM_PutMany(benchmark::State& state) {
+ int num_keys = state.range(0);
+
+ KeyMapperBenchmark benchmark;
+ for (int i = 0; i < num_keys; ++i) {
+ benchmark.GenerateUniqueRandomKeyValuePair(i);
+ }
+
+ for (auto _ : state) {
+ state.PauseTiming();
+ benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str());
+ DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<KeyMapper<int>> key_mapper,
+ benchmark.CreateKeyMapper<KeyMapperType>(num_keys));
+ ASSERT_THAT(key_mapper->num_keys(), Eq(0));
+ state.ResumeTiming();
+
+ for (int i = 0; i < num_keys; ++i) {
+ ICING_ASSERT_OK(key_mapper->Put(benchmark.random_kvps[i].first,
+ benchmark.random_kvps[i].second));
+ }
+
+    // Explicitly call PersistToDisk.
+ ICING_ASSERT_OK(key_mapper->PersistToDisk());
+
+ state.PauseTiming();
+ ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys));
+    // The destructor of IcingDynamicTrie doesn't implicitly call PersistToDisk,
+    // while PersistentHashMap's does. Thus, we reset the unique pointer to
+    // invoke the destructor inside the paused-timing block, so in either case
+    // PersistToDisk is included in the benchmark only once.
+ key_mapper.reset();
+ state.ResumeTiming();
+ }
+}
+BENCHMARK(BM_PutMany<DynamicTrieKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+BENCHMARK(BM_PutMany<PersistentHashMapKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+
+// Benchmark the average time of putting 1 unique random key value pair. The
+// result is affected by the iteration count, so use --benchmark_max_iters=k
+// and --benchmark_min_iters=k to pin the number of iterations.
+template <typename KeyMapperType>
+void BM_Put(benchmark::State& state) {
+ KeyMapperBenchmark benchmark;
+ benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str());
+ DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir);
+
+  // The overhead of state.PauseTiming is large enough to skew the benchmark
+  // result, so pre-generate enough kvps to avoid calling state.PauseTiming
+  // for GenerateUniqueRandomKeyValuePair too many times in the benchmark
+  // for-loop.
+ int MAX_PREGEN_KVPS = 1 << 22;
+ for (int i = 0; i < MAX_PREGEN_KVPS; ++i) {
+ benchmark.GenerateUniqueRandomKeyValuePair(i);
+ }
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<KeyMapper<int>> key_mapper,
+ benchmark.CreateKeyMapper<KeyMapperType>(/*max_num_entries=*/1 << 22));
+ ASSERT_THAT(key_mapper->num_keys(), Eq(0));
+
+ int cnt = 0;
+ for (auto _ : state) {
+ if (cnt >= MAX_PREGEN_KVPS) {
+ state.PauseTiming();
+ benchmark.GenerateUniqueRandomKeyValuePair(cnt);
+ state.ResumeTiming();
+ }
+
+ ICING_ASSERT_OK(key_mapper->Put(benchmark.random_kvps[cnt].first,
+ benchmark.random_kvps[cnt].second));
+ ++cnt;
+ }
+}
+BENCHMARK(BM_Put<DynamicTrieKeyMapper<int>>);
+BENCHMARK(BM_Put<PersistentHashMapKeyMapper<int>>);
+
+// Benchmark the average time of getting 1 existing key value pair from the key
+// mapper with size num_keys (specified by Arg).
+template <typename KeyMapperType>
+void BM_Get(benchmark::State& state) {
+ int num_keys = state.range(0);
+
+ KeyMapperBenchmark benchmark;
+ benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str());
+ DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir);
+
+ // Create a key mapper with num_keys entries.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<KeyMapper<int>> key_mapper,
+ benchmark.CreateKeyMapper<KeyMapperType>(num_keys));
+ for (int i = 0; i < num_keys; ++i) {
+ ICING_ASSERT_OK(
+ key_mapper->Put(benchmark.GenerateUniqueRandomKeyValuePair(i), i));
+ }
+ ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys));
+
+ std::uniform_int_distribution<> distrib(0, num_keys - 1);
+ std::default_random_engine e(/*seed=*/12345);
+ for (auto _ : state) {
+ int idx = distrib(e);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ int val, key_mapper->Get(benchmark.random_kvps[idx].first));
+ ASSERT_THAT(val, Eq(benchmark.random_kvps[idx].second));
+ }
+}
+BENCHMARK(BM_Get<DynamicTrieKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+BENCHMARK(BM_Get<PersistentHashMapKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+
+// Benchmark the total time of iterating through all key value pairs of the key
+// mapper with size num_keys (specified by Arg).
+template <typename KeyMapperType>
+void BM_Iterator(benchmark::State& state) {
+ int num_keys = state.range(0);
+
+ KeyMapperBenchmark benchmark;
+ benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str());
+ DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir);
+
+ // Create a key mapper with num_keys entries.
+ ICING_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<KeyMapper<int>> key_mapper,
+ benchmark.CreateKeyMapper<KeyMapperType>(num_keys));
+ for (int i = 0; i < num_keys; ++i) {
+ ICING_ASSERT_OK(
+ key_mapper->Put(benchmark.GenerateUniqueRandomKeyValuePair(i), i));
+ }
+ ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys));
+
+ for (auto _ : state) {
+ auto iter = key_mapper->GetIterator();
+ int cnt = 0;
+ while (iter->Advance()) {
+ ++cnt;
+ std::string key(iter->GetKey());
+ int value = iter->GetValue();
+ auto it = benchmark.random_kvps_map.find(key);
+ ASSERT_THAT(it, Not(Eq(benchmark.random_kvps_map.end())));
+ ASSERT_THAT(it->second, Eq(value));
+ }
+ ASSERT_THAT(cnt, Eq(num_keys));
+ }
+}
+BENCHMARK(BM_Iterator<DynamicTrieKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+BENCHMARK(BM_Iterator<PersistentHashMapKeyMapper<int>>)
+ ->Arg(1 << 10)
+ ->Arg(1 << 11)
+ ->Arg(1 << 12)
+ ->Arg(1 << 13)
+ ->Arg(1 << 14)
+ ->Arg(1 << 15)
+ ->Arg(1 << 16)
+ ->Arg(1 << 17)
+ ->Arg(1 << 18)
+ ->Arg(1 << 19)
+ ->Arg(1 << 20);
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/store/key-mapper_test.cc b/icing/store/key-mapper_test.cc
new file mode 100644
index 0000000..682888d
--- /dev/null
+++ b/icing/store/key-mapper_test.cc
@@ -0,0 +1,215 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/store/key-mapper.h"
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/absl_ports/canonical_errors.h"
+#include "icing/file/filesystem.h"
+#include "icing/store/document-id.h"
+#include "icing/store/dynamic-trie-key-mapper.h"
+#include "icing/store/persistent-hash-map-key-mapper.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/tmp-directory.h"
+
+using ::testing::IsEmpty;
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+constexpr int kMaxDynamicTrieKeyMapperSize = 3 * 1024 * 1024; // 3 MiB
+
+template <typename T>
+class KeyMapperTest : public ::testing::Test {
+ protected:
+ using KeyMapperType = T;
+
+ void SetUp() override { base_dir_ = GetTestTempDir() + "/key_mapper"; }
+
+ void TearDown() override {
+ filesystem_.DeleteDirectoryRecursively(base_dir_.c_str());
+ }
+
+ template <typename UnknownKeyMapperType>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>>
+ CreateKeyMapper() {
+ return absl_ports::InvalidArgumentError("Unknown type");
+ }
+
+ template <>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>>
+ CreateKeyMapper<DynamicTrieKeyMapper<DocumentId>>() {
+ return DynamicTrieKeyMapper<DocumentId>::Create(
+ filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize);
+ }
+
+ template <>
+ libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>>
+ CreateKeyMapper<PersistentHashMapKeyMapper<DocumentId>>() {
+ return PersistentHashMapKeyMapper<DocumentId>::Create(filesystem_,
+ base_dir_);
+ }
+
+ std::string base_dir_;
+ Filesystem filesystem_;
+};
+
+using TestTypes = ::testing::Types<DynamicTrieKeyMapper<DocumentId>,
+ PersistentHashMapKeyMapper<DocumentId>>;
+TYPED_TEST_SUITE(KeyMapperTest, TestTypes);
+
+std::unordered_map<std::string, DocumentId> GetAllKeyValuePairs(
+ const KeyMapper<DocumentId>* key_mapper) {
+ std::unordered_map<std::string, DocumentId> ret;
+
+ std::unique_ptr<typename KeyMapper<DocumentId>::Iterator> itr =
+ key_mapper->GetIterator();
+ while (itr->Advance()) {
+ ret.emplace(itr->GetKey(), itr->GetValue());
+ }
+ return ret;
+}
+
+TYPED_TEST(KeyMapperTest, CreateNewKeyMapper) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ EXPECT_THAT(key_mapper->num_keys(), 0);
+}
+
+TYPED_TEST(KeyMapperTest, CanUpdateSameKeyMultipleTimes) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
+ ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 50));
+
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
+
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 200));
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(200));
+ EXPECT_THAT(key_mapper->num_keys(), 2);
+
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
+ EXPECT_THAT(key_mapper->num_keys(), 2);
+}
+
+TYPED_TEST(KeyMapperTest, GetOrPutOk) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+
+ EXPECT_THAT(key_mapper->Get("foo"),
+ StatusIs(libtextclassifier3::StatusCode::NOT_FOUND));
+ EXPECT_THAT(key_mapper->GetOrPut("foo", 1), IsOkAndHolds(1));
+ EXPECT_THAT(key_mapper->Get("foo"), IsOkAndHolds(1));
+}
+
+TYPED_TEST(KeyMapperTest, CanPersistToDiskRegularly) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+
+ // Can persist an empty key mapper.
+ ICING_EXPECT_OK(key_mapper->PersistToDisk());
+ EXPECT_THAT(key_mapper->num_keys(), 0);
+
+ // Can persist the smallest non-empty key mapper.
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
+ ICING_EXPECT_OK(key_mapper->PersistToDisk());
+ EXPECT_THAT(key_mapper->num_keys(), 1);
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
+
+ // Can continue to add keys after PersistToDisk().
+ ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200));
+ EXPECT_THAT(key_mapper->num_keys(), 2);
+ EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200));
+
+ // Can continue to update the same key after PersistToDisk().
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
+ EXPECT_THAT(key_mapper->num_keys(), 2);
+}
+
+TYPED_TEST(KeyMapperTest, CanUseAcrossMultipleInstances) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
+ ICING_EXPECT_OK(key_mapper->PersistToDisk());
+
+ key_mapper.reset();
+
+ ICING_ASSERT_OK_AND_ASSIGN(key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ EXPECT_THAT(key_mapper->num_keys(), 1);
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100));
+
+ // Can continue to read/write to the KeyMapper.
+ ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200));
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300));
+ EXPECT_THAT(key_mapper->num_keys(), 2);
+ EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200));
+ EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300));
+}
+
+TYPED_TEST(KeyMapperTest, CanDeleteAndRestartKeyMapping) {
+ // Can delete even if there's nothing there
+ ICING_EXPECT_OK(
+ TestFixture::KeyMapperType::Delete(this->filesystem_, this->base_dir_));
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
+ ICING_EXPECT_OK(key_mapper->PersistToDisk());
+ ICING_EXPECT_OK(
+ TestFixture::KeyMapperType::Delete(this->filesystem_, this->base_dir_));
+
+ key_mapper.reset();
+ ICING_ASSERT_OK_AND_ASSIGN(key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ EXPECT_THAT(key_mapper->num_keys(), 0);
+ ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100));
+ EXPECT_THAT(key_mapper->num_keys(), 1);
+}
+
+TYPED_TEST(KeyMapperTest, Iterator) {
+ ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper,
+ this->template CreateKeyMapper<TypeParam>());
+ EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), IsEmpty());
+
+ ICING_EXPECT_OK(key_mapper->Put("foo", /*value=*/1));
+ ICING_EXPECT_OK(key_mapper->Put("bar", /*value=*/2));
+ EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()),
+ UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2)));
+
+ ICING_EXPECT_OK(key_mapper->Put("baz", /*value=*/3));
+ EXPECT_THAT(
+ GetAllKeyValuePairs(key_mapper.get()),
+ UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2), Pair("baz", 3)));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/store/persistent-hash-map-key-mapper.h b/icing/store/persistent-hash-map-key-mapper.h
new file mode 100644
index 0000000..a13ec11
--- /dev/null
+++ b/icing/store/persistent-hash-map-key-mapper.h
@@ -0,0 +1,209 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_
+#define ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <string_view>
+#include <type_traits>
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/text_classifier/lib3/utils/base/statusor.h"
+#include "icing/absl_ports/str_join.h"
+#include "icing/file/filesystem.h"
+#include "icing/file/persistent-hash-map.h"
+#include "icing/store/key-mapper.h"
+#include "icing/util/crc32.h"
+#include "icing/util/status-macros.h"
+
+namespace icing {
+namespace lib {
+
+// File-backed mapping between string keys and a trivially copyable value
+// type.
+template <typename T, typename Formatter = absl_ports::DefaultFormatter>
+class PersistentHashMapKeyMapper : public KeyMapper<T, Formatter> {
+ public:
+ // Returns an initialized instance of PersistentHashMapKeyMapper that can
+ // immediately handle read/write operations.
+ // Returns any encountered IO errors.
+ //
+ // filesystem: Object to make system level calls
+ // base_dir : Base directory used to save all the files required to persist
+ // PersistentHashMapKeyMapper. If this base_dir was previously used
+ // to create a PersistentHashMapKeyMapper, then the existing data will
+ // be loaded. Otherwise, an empty PersistentHashMapKeyMapper will be
+ // created.
+ // max_num_entries: max # of key-value pairs. It will be used to compute the
+ // sizes of the 3 internal storages.
+ // average_kv_byte_size: average byte size of a single key + serialized value.
+ // It will be used to compute the size of kv_storage.
+ // max_load_factor_percent: maximum allowed load factor percentage for the
+ // hash map, where
+ // load_factor_percent = 100 * num_keys / num_buckets
+ // If load_factor_percent exceeds
+ // max_load_factor_percent, then rehash will be
+ // invoked (and # of buckets will be doubled).
+ // Note that load_factor_percent exceeding 100 is
+ // considered valid.
+ static libtextclassifier3::StatusOr<
+ std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>>
+ Create(const Filesystem& filesystem, std::string_view base_dir,
+ int32_t max_num_entries = PersistentHashMap::Entry::kMaxNumEntries,
+ int32_t average_kv_byte_size =
+ PersistentHashMap::Options::kDefaultAverageKVByteSize,
+ int32_t max_load_factor_percent =
+ PersistentHashMap::Options::kDefaultMaxLoadFactorPercent);
+
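+ // A minimal usage sketch (the directory path and sizing values below are
+ // illustrative assumptions, not requirements):
+ //
+ //   Filesystem filesystem;
+ //   ICING_ASSIGN_OR_RETURN(
+ //       std::unique_ptr<PersistentHashMapKeyMapper<DocumentId>> key_mapper,
+ //       PersistentHashMapKeyMapper<DocumentId>::Create(
+ //           filesystem, "/tmp/icing/key_mapper",
+ //           /*max_num_entries=*/1 << 16,
+ //           /*average_kv_byte_size=*/32,
+ //           /*max_load_factor_percent=*/100));
+ //
+ // With max_load_factor_percent = 100 and 8 buckets, adding the 9th key would
+ // make load_factor_percent = 100 * 9 / 8 = 112 > 100, triggering a rehash to
+ // 16 buckets.
+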
+ // Deletes all the files associated with the PersistentHashMapKeyMapper.
+ //
+ // base_dir : Base directory used to save all the files required to persist
+ // PersistentHashMapKeyMapper. Should be the same as passed into
+ // Create().
+ //
+ // Returns:
+ // OK on success
+ // INTERNAL_ERROR on I/O error
+ static libtextclassifier3::Status Delete(const Filesystem& filesystem,
+ std::string_view base_dir);
+
+ ~PersistentHashMapKeyMapper() override = default;
+
+ libtextclassifier3::Status Put(std::string_view key, T value) override {
+ return persistent_hash_map_->Put(key, &value);
+ }
+
+ libtextclassifier3::StatusOr<T> GetOrPut(std::string_view key,
+ T next_value) override {
+ ICING_RETURN_IF_ERROR(persistent_hash_map_->GetOrPut(key, &next_value));
+ return next_value;
+ }
+
+ libtextclassifier3::StatusOr<T> Get(std::string_view key) const override {
+ T value;
+ ICING_RETURN_IF_ERROR(persistent_hash_map_->Get(key, &value));
+ return value;
+ }
+
+ bool Delete(std::string_view key) override {
+ return persistent_hash_map_->Delete(key).ok();
+ }
+
+ std::unique_ptr<typename KeyMapper<T, Formatter>::Iterator> GetIterator()
+ const override {
+ return std::make_unique<PersistentHashMapKeyMapper<T, Formatter>::Iterator>(
+ persistent_hash_map_.get());
+ }
+
+ int32_t num_keys() const override { return persistent_hash_map_->size(); }
+
+ libtextclassifier3::Status PersistToDisk() override {
+ return persistent_hash_map_->PersistToDisk();
+ }
+
+ libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const override {
+ return persistent_hash_map_->GetDiskUsage();
+ }
+
+ libtextclassifier3::StatusOr<int64_t> GetElementsSize() const override {
+ return persistent_hash_map_->GetElementsSize();
+ }
+
+ libtextclassifier3::StatusOr<Crc32> ComputeChecksum() override {
+ return persistent_hash_map_->ComputeChecksum();
+ }
+
+ private:
+ class Iterator : public KeyMapper<T, Formatter>::Iterator {
+ public:
+ explicit Iterator(const PersistentHashMap* persistent_hash_map)
+ : itr_(persistent_hash_map->GetIterator()) {}
+
+ ~Iterator() override = default;
+
+ bool Advance() override { return itr_.Advance(); }
+
+ std::string_view GetKey() const override { return itr_.GetKey(); }
+
+ T GetValue() const override {
+ T value;
+ memcpy(&value, itr_.GetValue(), sizeof(T));
+ return value;
+ }
+
+ private:
+ PersistentHashMap::Iterator itr_;
+ };
+
+ static constexpr std::string_view kKeyMapperDir = "key_mapper_dir";
+
+ // Use PersistentHashMapKeyMapper::Create() to instantiate.
+ explicit PersistentHashMapKeyMapper(
+ std::unique_ptr<PersistentHashMap> persistent_hash_map)
+ : persistent_hash_map_(std::move(persistent_hash_map)) {}
+
+ std::unique_ptr<PersistentHashMap> persistent_hash_map_;
+
+ static_assert(std::is_trivially_copyable<T>::value,
+ "T must be trivially copyable");
+};
+
+template <typename T, typename Formatter>
+/* static */ libtextclassifier3::StatusOr<
+ std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>>
+PersistentHashMapKeyMapper<T, Formatter>::Create(
+ const Filesystem& filesystem, std::string_view base_dir,
+ int32_t max_num_entries, int32_t average_kv_byte_size,
+ int32_t max_load_factor_percent) {
+ const std::string key_mapper_dir =
+ absl_ports::StrCat(base_dir, "/", kKeyMapperDir);
+ if (!filesystem.CreateDirectoryRecursively(key_mapper_dir.c_str())) {
+ return absl_ports::InternalError(absl_ports::StrCat(
+ "Failed to create PersistentHashMapKeyMapper directory: ",
+ key_mapper_dir));
+ }
+
+ ICING_ASSIGN_OR_RETURN(
+ std::unique_ptr<PersistentHashMap> persistent_hash_map,
+ PersistentHashMap::Create(
+ filesystem, key_mapper_dir,
+ PersistentHashMap::Options(
+ /*value_type_size_in=*/sizeof(T),
+ /*max_num_entries_in=*/max_num_entries,
+ /*max_load_factor_percent_in=*/max_load_factor_percent,
+ /*average_kv_byte_size_in=*/average_kv_byte_size)));
+ return std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>(
+ new PersistentHashMapKeyMapper<T, Formatter>(
+ std::move(persistent_hash_map)));
+}
+
+template <typename T, typename Formatter>
+/* static */ libtextclassifier3::Status
+PersistentHashMapKeyMapper<T, Formatter>::Delete(const Filesystem& filesystem,
+ std::string_view base_dir) {
+ const std::string key_mapper_dir =
+ absl_ports::StrCat(base_dir, "/", kKeyMapperDir);
+ if (!filesystem.DeleteDirectoryRecursively(key_mapper_dir.c_str())) {
+ return absl_ports::InternalError(absl_ports::StrCat(
+ "Failed to delete PersistentHashMapKeyMapper directory: ",
+ key_mapper_dir));
+ }
+ return libtextclassifier3::Status::OK;
+}
+
+} // namespace lib
+} // namespace icing
+
+#endif // ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_
diff --git a/icing/store/persistent-hash-map-key-mapper_test.cc b/icing/store/persistent-hash-map-key-mapper_test.cc
new file mode 100644
index 0000000..c937c43
--- /dev/null
+++ b/icing/store/persistent-hash-map-key-mapper_test.cc
@@ -0,0 +1,52 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/store/persistent-hash-map-key-mapper.h"
+
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/file/filesystem.h"
+#include "icing/store/document-id.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/tmp-directory.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+class PersistentHashMapKeyMapperTest : public testing::Test {
+ protected:
+ void SetUp() override { base_dir_ = GetTestTempDir() + "/key_mapper"; }
+
+ void TearDown() override {
+ filesystem_.DeleteDirectoryRecursively(base_dir_.c_str());
+ }
+
+ std::string base_dir_;
+ Filesystem filesystem_;
+};
+
+TEST_F(PersistentHashMapKeyMapperTest, InvalidBaseDir) {
+ EXPECT_THAT(
+ PersistentHashMapKeyMapper<DocumentId>::Create(filesystem_, "/dev/null"),
+ StatusIs(libtextclassifier3::StatusCode::INTERNAL));
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h
index f2738e3..e090800 100644
--- a/icing/testing/common-matchers.h
+++ b/icing/testing/common-matchers.h
@@ -15,7 +15,10 @@
#ifndef ICING_TESTING_COMMON_MATCHERS_H_
#define ICING_TESTING_COMMON_MATCHERS_H_
+#include <algorithm>
+#include <cinttypes>
#include <cmath>
+#include <string>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/status_macros.h"
@@ -31,6 +34,7 @@
#include "icing/proto/status.pb.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -104,19 +108,73 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id,
term_frequency_as_expected;
}
+class ScoredDocumentHitFormatter {
+ public:
+ std::string operator()(const ScoredDocumentHit& scored_document_hit) {
+ return IcingStringUtil::StringPrintf(
+ "(document_id=%d, hit_section_id_mask=%" PRId64 ", score=%.2f)",
+ scored_document_hit.document_id(),
+ scored_document_hit.hit_section_id_mask(), scored_document_hit.score());
+ }
+};
+
+class ScoredDocumentHitEqualComparator {
+ public:
+ bool operator()(const ScoredDocumentHit& lhs,
+ const ScoredDocumentHit& rhs) const {
+ return lhs.document_id() == rhs.document_id() &&
+ lhs.hit_section_id_mask() == rhs.hit_section_id_mask() &&
+ std::fabs(lhs.score() - rhs.score()) < 1e-6;
+ }
+};
+
// Used to match a ScoredDocumentHit
MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") {
- if (arg.document_id() != expected_scored_document_hit.document_id() ||
- arg.hit_section_id_mask() !=
- expected_scored_document_hit.hit_section_id_mask() ||
- std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) {
+ ScoredDocumentHitEqualComparator equal_comparator;
+ if (!equal_comparator(arg, expected_scored_document_hit)) {
+ ScoredDocumentHitFormatter formatter;
+ *result_listener << "Expected: " << formatter(expected_scored_document_hit)
+ << ". Actual: " << formatter(arg);
+ return false;
+ }
+ return true;
+}
+
+// Used to match a JoinedScoredDocumentHit
+MATCHER_P(EqualsJoinedScoredDocumentHit, expected_joined_scored_document_hit,
+ "") {
+ ScoredDocumentHitEqualComparator equal_comparator;
+ if (std::fabs(arg.final_score() -
+ expected_joined_scored_document_hit.final_score()) > 1e-6 ||
+ !equal_comparator(
+ arg.parent_scored_document_hit(),
+ expected_joined_scored_document_hit.parent_scored_document_hit()) ||
+ arg.child_scored_document_hits().size() !=
+ expected_joined_scored_document_hit.child_scored_document_hits()
+ .size() ||
+ !std::equal(
+ arg.child_scored_document_hits().cbegin(),
+ arg.child_scored_document_hits().cend(),
+ expected_joined_scored_document_hit.child_scored_document_hits()
+ .cbegin(),
+ equal_comparator)) {
+ ScoredDocumentHitFormatter formatter;
+
*result_listener << IcingStringUtil::StringPrintf(
- "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f. Actual: "
- "document_id=%d, hit_section_id_mask=%d, score=%.2f",
- expected_scored_document_hit.document_id(),
- expected_scored_document_hit.hit_section_id_mask(),
- expected_scored_document_hit.score(), arg.document_id(),
- arg.hit_section_id_mask(), arg.score());
+ "Expected: final_score=%.2f, parent_scored_document_hit=%s, "
+ "child_scored_document_hits=[%s]. Actual: final_score=%.2f, "
+ "parent_scored_document_hit=%s, child_scored_document_hits=[%s]",
+ expected_joined_scored_document_hit.final_score(),
+ formatter(
+ expected_joined_scored_document_hit.parent_scored_document_hit())
+ .c_str(),
+ absl_ports::StrJoin(
+ expected_joined_scored_document_hit.child_scored_document_hits(),
+ ",", formatter)
+ .c_str(),
+ arg.final_score(), formatter(arg.parent_scored_document_hit()).c_str(),
+ absl_ports::StrJoin(arg.child_scored_document_hits(), ",", formatter)
+ .c_str());
return false;
}
return true;
@@ -435,6 +493,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
actual_copy.clear_debug_info();
for (SearchResultProto::ResultProto& result :
*actual_copy.mutable_results()) {
+ // Joined results
+ for (SearchResultProto::ResultProto& joined_result :
+ *result.mutable_joined_results()) {
+ joined_result.clear_score();
+ }
result.clear_score();
}
@@ -443,6 +506,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
expected_copy.clear_debug_info();
for (SearchResultProto::ResultProto& result :
*expected_copy.mutable_results()) {
+ // Joined results
+ for (SearchResultProto::ResultProto& joined_result :
+ *result.mutable_joined_results()) {
+ joined_result.clear_score();
+ }
result.clear_score();
}
return ExplainMatchResult(portable_equals_proto::EqualsProto(expected_copy),
diff --git a/icing/tokenization/combined-tokenizer_test.cc b/icing/tokenization/combined-tokenizer_test.cc
index 42c7743..8314e91 100644
--- a/icing/tokenization/combined-tokenizer_test.cc
+++ b/icing/tokenization/combined-tokenizer_test.cc
@@ -142,6 +142,7 @@ TEST_F(CombinedTokenizerTest, Negation) {
EXPECT_THAT(query_terms, ElementsAre("foo", "bar", "baz"));
}
+// TODO(b/254874614): Handle colon word breaks in ICU 72+
TEST_F(CombinedTokenizerTest, Colons) {
const std::string_view kText = ":foo: :bar baz:";
ICING_ASSERT_OK_AND_ASSIGN(
@@ -165,6 +166,7 @@ TEST_F(CombinedTokenizerTest, Colons) {
EXPECT_THAT(query_terms, ElementsAre("foo", "bar", "baz"));
}
+// TODO(b/254874614): Handle colon word breaks in ICU 72+
TEST_F(CombinedTokenizerTest, ColonsPropertyRestricts) {
ICING_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<Tokenizer> indexing_tokenizer,
@@ -176,33 +178,61 @@ TEST_F(CombinedTokenizerTest, ColonsPropertyRestricts) {
CreateQueryTokenizer(tokenizer_factory::QueryTokenizerType::RAW_QUERY,
lang_segmenter_.get()));
- // This is a difference between the two tokenizers. "foo:bar" is a single
- // token to the plain tokenizer because ':' is a word connector. But "foo:bar"
- // is a property restrict to the query tokenizer - so "foo" is the property
- // and "bar" is the only text term.
- constexpr std::string_view kText = "foo:bar";
- ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens,
- indexing_tokenizer->TokenizeAll(kText));
- std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens);
- EXPECT_THAT(indexing_terms, ElementsAre("foo:bar"));
-
- ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens,
- query_tokenizer->TokenizeAll(kText));
- std::vector<std::string> query_terms = GetTokenTerms(query_tokens);
- EXPECT_THAT(query_terms, ElementsAre("bar"));
-
- // This difference, however, should only apply to the first ':'. A
- // second ':' should be treated by both tokenizers as a word connector.
- constexpr std::string_view kText2 = "foo:bar:baz";
- ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens,
- indexing_tokenizer->TokenizeAll(kText2));
- indexing_terms = GetTokenTerms(indexing_tokens);
- EXPECT_THAT(indexing_terms, ElementsAre("foo:bar:baz"));
-
- ICING_ASSERT_OK_AND_ASSIGN(query_tokens,
- query_tokenizer->TokenizeAll(kText2));
- query_terms = GetTokenTerms(query_tokens);
- EXPECT_THAT(query_terms, ElementsAre("bar:baz"));
+ if (IsIcu72PlusTokenization()) {
+ // In ICU 72 and above, ':' is no longer considered a word connector. The
+ // query tokenizer should still treat it as a property restrict delimiter.
+ constexpr std::string_view kText = "foo:bar";
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens,
+ indexing_tokenizer->TokenizeAll(kText));
+ std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens);
+ EXPECT_THAT(indexing_terms, ElementsAre("foo", "bar"));
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens,
+ query_tokenizer->TokenizeAll(kText));
+ std::vector<std::string> query_terms = GetTokenTerms(query_tokens);
+ EXPECT_THAT(query_terms, ElementsAre("bar"));
+
+ // This difference, however, should only apply to the first ':'. Both
+ // tokenizers should consider a second ':' to be a word break.
+ constexpr std::string_view kText2 = "foo:bar:baz";
+ ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens,
+ indexing_tokenizer->TokenizeAll(kText2));
+ indexing_terms = GetTokenTerms(indexing_tokens);
+ EXPECT_THAT(indexing_terms, ElementsAre("foo", "bar", "baz"));
+
+ ICING_ASSERT_OK_AND_ASSIGN(query_tokens,
+ query_tokenizer->TokenizeAll(kText2));
+ query_terms = GetTokenTerms(query_tokens);
+ EXPECT_THAT(query_terms, ElementsAre("bar", "baz"));
+ } else {
+ // This is a difference between the two tokenizers. "foo:bar" is a single
+ // token to the plain tokenizer because ':' is a word connector. But
+ // "foo:bar" is a property restrict to the query tokenizer - so "foo" is the
+ // property and "bar" is the only text term.
+ constexpr std::string_view kText = "foo:bar";
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens,
+ indexing_tokenizer->TokenizeAll(kText));
+ std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens);
+ EXPECT_THAT(indexing_terms, ElementsAre("foo:bar"));
+
+ ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens,
+ query_tokenizer->TokenizeAll(kText));
+ std::vector<std::string> query_terms = GetTokenTerms(query_tokens);
+ EXPECT_THAT(query_terms, ElementsAre("bar"));
+
+ // This difference, however, should only apply to the first ':'. A
+ // second ':' should be treated by both tokenizers as a word connector.
+ constexpr std::string_view kText2 = "foo:bar:baz";
+ ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens,
+ indexing_tokenizer->TokenizeAll(kText2));
+ indexing_terms = GetTokenTerms(indexing_tokens);
+ EXPECT_THAT(indexing_terms, ElementsAre("foo:bar:baz"));
+
+ ICING_ASSERT_OK_AND_ASSIGN(query_tokens,
+ query_tokenizer->TokenizeAll(kText2));
+ query_terms = GetTokenTerms(query_tokens);
+ EXPECT_THAT(query_terms, ElementsAre("bar:baz"));
+ }
}
TEST_F(CombinedTokenizerTest, Punctuation) {
diff --git a/icing/tokenization/icu/icu-language-segmenter_test.cc b/icing/tokenization/icu/icu-language-segmenter_test.cc
index 71e04e2..6771050 100644
--- a/icing/tokenization/icu/icu-language-segmenter_test.cc
+++ b/icing/tokenization/icu/icu-language-segmenter_test.cc
@@ -21,6 +21,7 @@
#include "gtest/gtest.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/jni/jni-cache.h"
+#include "icing/portable/platform.h"
#include "icing/testing/common-matchers.h"
#include "icing/testing/icu-data-file-helper.h"
#include "icing/testing/icu-i18n-test-utils.h"
@@ -118,6 +119,9 @@ class IcuLanguageSegmenterAllLocalesTest
: public testing::TestWithParam<const char*> {
protected:
void SetUp() override {
+ if (!IsIcuTokenization()) {
+ GTEST_SKIP() << "ICU tokenization not enabled!";
+ }
ICING_ASSERT_OK(
// File generated via icu_data_file rule in //icing/BUILD.
icu_data_file_helper::SetUpICUDataFile(
@@ -223,46 +227,42 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) {
// Word connecters
EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android"),
IsOkAndHolds(ElementsAre("com.google.android")));
- EXPECT_THAT(language_segmenter->GetAllTerms("com:google:android"),
- IsOkAndHolds(ElementsAre("com:google:android")));
EXPECT_THAT(language_segmenter->GetAllTerms("com'google'android"),
IsOkAndHolds(ElementsAre("com'google'android")));
EXPECT_THAT(language_segmenter->GetAllTerms("com_google_android"),
IsOkAndHolds(ElementsAre("com_google_android")));
// Word connecters can be mixed
- EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android:icing"),
- IsOkAndHolds(ElementsAre("com.google.android:icing")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android_icing"),
+ IsOkAndHolds(ElementsAre("com.google.android_icing")));
// Connectors that don't have valid terms on both sides of it are not
// considered connectors.
- EXPECT_THAT(language_segmenter->GetAllTerms(":bar:baz"),
- IsOkAndHolds(ElementsAre(":", "bar:baz")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("'bar'baz"),
+ IsOkAndHolds(ElementsAre("'", "bar'baz")));
- EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz:"),
- IsOkAndHolds(ElementsAre("bar:baz", ":")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("bar.baz."),
+ IsOkAndHolds(ElementsAre("bar.baz", ".")));
// Connectors that don't have valid terms on both sides of it are not
// considered connectors.
- EXPECT_THAT(language_segmenter->GetAllTerms(" :bar:baz"),
- IsOkAndHolds(ElementsAre(" ", ":", "bar:baz")));
+ EXPECT_THAT(language_segmenter->GetAllTerms(" .bar.baz"),
+ IsOkAndHolds(ElementsAre(" ", ".", "bar.baz")));
- EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz: "),
- IsOkAndHolds(ElementsAre("bar:baz", ":", " ")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("bar'baz' "),
+ IsOkAndHolds(ElementsAre("bar'baz", "'", " ")));
// Connectors don't connect if one side is an invalid term (?)
- EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz:?"),
- IsOkAndHolds(ElementsAre("bar:baz", ":", "?")));
- EXPECT_THAT(language_segmenter->GetAllTerms("?:bar:baz"),
- IsOkAndHolds(ElementsAre("?", ":", "bar:baz")));
- EXPECT_THAT(language_segmenter->GetAllTerms("3:14"),
- IsOkAndHolds(ElementsAre("3", ":", "14")));
- EXPECT_THAT(language_segmenter->GetAllTerms("私:は"),
- IsOkAndHolds(ElementsAre("私", ":", "は")));
- EXPECT_THAT(language_segmenter->GetAllTerms("我:每"),
- IsOkAndHolds(ElementsAre("我", ":", "每")));
- EXPECT_THAT(language_segmenter->GetAllTerms("เดิน:ไป"),
- IsOkAndHolds(ElementsAre("เดิน:ไป")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("bar.baz.?"),
+ IsOkAndHolds(ElementsAre("bar.baz", ".", "?")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("?'bar'baz"),
+ IsOkAndHolds(ElementsAre("?", "'", "bar'baz")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("私'は"),
+ IsOkAndHolds(ElementsAre("私", "'", "は")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("我.每"),
+ IsOkAndHolds(ElementsAre("我", ".", "每")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("เดิน'ไป"),
+ IsOkAndHolds(ElementsAre("เดิน'ไป")));
// Any heading and trailing characters are not connecters
EXPECT_THAT(language_segmenter->GetAllTerms(".com.google.android."),
@@ -277,8 +277,6 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) {
IsOkAndHolds(ElementsAre("com", "+", "google", "+", "android")));
EXPECT_THAT(language_segmenter->GetAllTerms("com*google*android"),
IsOkAndHolds(ElementsAre("com", "*", "google", "*", "android")));
- EXPECT_THAT(language_segmenter->GetAllTerms("com@google@android"),
- IsOkAndHolds(ElementsAre("com", "@", "google", "@", "android")));
EXPECT_THAT(language_segmenter->GetAllTerms("com^google^android"),
IsOkAndHolds(ElementsAre("com", "^", "google", "^", "android")));
EXPECT_THAT(language_segmenter->GetAllTerms("com&google&android"),
@@ -292,6 +290,29 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) {
EXPECT_THAT(
language_segmenter->GetAllTerms("com\"google\"android"),
IsOkAndHolds(ElementsAre("com", "\"", "google", "\"", "android")));
+
+ // In ICU 72, there were a few changes:
+ // 1. ':' stopped being a word connector
+ // 2. '@' became a word connector
+ // 3. <numeric><word-connector><numeric> such as "3'14" is now considered
+ // a single token.
+ if (IsIcu72PlusTokenization()) {
+ EXPECT_THAT(
+ language_segmenter->GetAllTerms("com:google:android"),
+ IsOkAndHolds(ElementsAre("com", ":", "google", ":", "android")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("com@google@android"),
+ IsOkAndHolds(ElementsAre("com@google@android")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("3'14"),
+ IsOkAndHolds(ElementsAre("3'14")));
+ } else {
+ EXPECT_THAT(language_segmenter->GetAllTerms("com:google:android"),
+ IsOkAndHolds(ElementsAre("com:google:android")));
+ EXPECT_THAT(
+ language_segmenter->GetAllTerms("com@google@android"),
+ IsOkAndHolds(ElementsAre("com", "@", "google", "@", "android")));
+ EXPECT_THAT(language_segmenter->GetAllTerms("3'14"),
+ IsOkAndHolds(ElementsAre("3", "'", "14")));
+ }
}
TEST_P(IcuLanguageSegmenterAllLocalesTest, Apostrophes) {
@@ -494,17 +515,17 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToStartUtf32WordConnector) {
ICING_ASSERT_OK_AND_ASSIGN(
auto segmenter, language_segmenter_factory::Create(
GetSegmenterOptions(GetLocale(), jni_cache_.get())));
- constexpr std::string_view kText = "com:google:android is package";
+ constexpr std::string_view kText = "com.google.android is package";
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr,
segmenter->Segment(kText));
- // String: "com:google:android is package"
+ // String: "com.google.android is package"
// ^ ^^ ^^
// UTF-8 idx: 0 18 19 21 22
// UTF-32 idx: 0 18 19 21 22
auto position_or = itr->ResetToStartUtf32();
EXPECT_THAT(position_or, IsOk());
- ASSERT_THAT(itr->GetTerm(), Eq("com:google:android"));
+ ASSERT_THAT(itr->GetTerm(), Eq("com.google.android"));
}
TEST_P(IcuLanguageSegmenterAllLocalesTest, NewIteratorResetToStartUtf32) {
@@ -585,11 +606,11 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32WordConnector) {
ICING_ASSERT_OK_AND_ASSIGN(
auto segmenter, language_segmenter_factory::Create(
GetSegmenterOptions(GetLocale(), jni_cache_.get())));
- constexpr std::string_view kText = "package com:google:android name";
+ constexpr std::string_view kText = "package com.google.android name";
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr,
segmenter->Segment(kText));
- // String: "package com:google:android name"
+ // String: "package com.google.android name"
// ^ ^^ ^^
// UTF-8 idx: 0 7 8 26 27
// UTF-32 idx: 0 7 8 26 27
@@ -601,7 +622,7 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32WordConnector) {
position_or = itr->ResetToTermStartingAfterUtf32(7);
EXPECT_THAT(position_or, IsOk());
EXPECT_THAT(position_or.ValueOrDie(), Eq(8));
- ASSERT_THAT(itr->GetTerm(), Eq("com:google:android"));
+ ASSERT_THAT(itr->GetTerm(), Eq("com.google.android"));
}
TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32OutOfBounds) {
@@ -961,18 +982,18 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest,
ICING_ASSERT_OK_AND_ASSIGN(
auto segmenter, language_segmenter_factory::Create(
GetSegmenterOptions(GetLocale(), jni_cache_.get())));
- constexpr std::string_view kText = "package name com:google:android!";
+ constexpr std::string_view kText = "package name com.google.android!";
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr,
segmenter->Segment(kText));
- // String: "package name com:google:android!"
+ // String: "package name com.google.android!"
// ^ ^^ ^^ ^
// UTF-8 idx: 0 7 8 12 13 31
// UTF-32 idx: 0 7 8 12 13 31
auto position_or = itr->ResetToTermEndingBeforeUtf32(31);
EXPECT_THAT(position_or, IsOk());
EXPECT_THAT(position_or.ValueOrDie(), Eq(13));
- ASSERT_THAT(itr->GetTerm(), Eq("com:google:android"));
+ ASSERT_THAT(itr->GetTerm(), Eq("com.google.android"));
position_or = itr->ResetToTermEndingBeforeUtf32(21);
EXPECT_THAT(position_or, IsOk());
diff --git a/icing/tokenization/raw-query-tokenizer_test.cc b/icing/tokenization/raw-query-tokenizer_test.cc
index 2af4f18..a00f2f7 100644
--- a/icing/tokenization/raw-query-tokenizer_test.cc
+++ b/icing/tokenization/raw-query-tokenizer_test.cc
@@ -225,7 +225,7 @@ TEST_F(RawQueryTokenizerTest, Parentheses) {
HasSubstr("Too many right parentheses")));
}
-TEST_F(RawQueryTokenizerTest, Exclustion) {
+TEST_F(RawQueryTokenizerTest, Exclusion) {
language_segmenter_factory::SegmenterOptions options(ULOC_US);
ICING_ASSERT_OK_AND_ASSIGN(
auto language_segmenter,
@@ -344,13 +344,23 @@ TEST_F(RawQueryTokenizerTest, PropertyRestriction) {
EqualsToken(Token::Type::QUERY_PROPERTY, "email.title"),
EqualsToken(Token::Type::REGULAR, "hello"))));
- // The first colon ":" triggers property restriction, the second colon is used
- // as a word connector per ICU's rule
- // (https://unicode.org/reports/tr29/#Word_Boundaries).
- EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"),
- IsOkAndHolds(ElementsAre(
- EqualsToken(Token::Type::QUERY_PROPERTY, "property"),
- EqualsToken(Token::Type::REGULAR, "foo:bar"))));
+ // The first colon ":" triggers property restriction. Before ICU 72, ':' was
+ // considered a word connector, so the second ':' is interpreted as a
+ // connector. In ICU 72 and above, ':' is no longer considered a
+ // connector.
+ // TODO(b/254874614): Handle colon word breaks in ICU 72+
+ if (IsIcu72PlusTokenization()) {
+ EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"),
+ IsOkAndHolds(ElementsAre(
+ EqualsToken(Token::Type::QUERY_PROPERTY, "property"),
+ EqualsToken(Token::Type::REGULAR, "foo"),
+ EqualsToken(Token::Type::REGULAR, "bar"))));
+ } else {
+ EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"),
+ IsOkAndHolds(ElementsAre(
+ EqualsToken(Token::Type::QUERY_PROPERTY, "property"),
+ EqualsToken(Token::Type::REGULAR, "foo:bar"))));
+ }
// Property restriction only applies to the term right after it.
// Note: "term1:term2" is not a term but 2 terms because word connectors
diff --git a/icing/testing/snippet-helpers.cc b/icing/util/snippet-helpers.cc
index 3105073..6d6277f 100644
--- a/icing/testing/snippet-helpers.cc
+++ b/icing/util/snippet-helpers.cc
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "icing/testing/snippet-helpers.h"
+#include "icing/util/snippet-helpers.h"
#include <algorithm>
#include <string_view>
diff --git a/icing/testing/snippet-helpers.h b/icing/util/snippet-helpers.h
index 73b2ce2..73b2ce2 100644
--- a/icing/testing/snippet-helpers.h
+++ b/icing/util/snippet-helpers.h
diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc
index e741987..facb267 100644
--- a/icing/util/tokenized-document.cc
+++ b/icing/util/tokenized-document.cc
@@ -31,29 +31,14 @@
namespace icing {
namespace lib {
-libtextclassifier3::StatusOr<TokenizedDocument> TokenizedDocument::Create(
- const SchemaStore* schema_store,
- const LanguageSegmenter* language_segmenter, DocumentProto document) {
- TokenizedDocument tokenized_document(std::move(document));
- ICING_RETURN_IF_ERROR(
- tokenized_document.Tokenize(schema_store, language_segmenter));
- return tokenized_document;
-}
+namespace {
-TokenizedDocument::TokenizedDocument(DocumentProto document)
- : document_(std::move(document)) {}
-
-libtextclassifier3::Status TokenizedDocument::Tokenize(
+libtextclassifier3::StatusOr<std::vector<TokenizedSection>> Tokenize(
const SchemaStore* schema_store,
- const LanguageSegmenter* language_segmenter) {
- DocumentValidator validator(schema_store);
- ICING_RETURN_IF_ERROR(validator.Validate(document_));
-
- ICING_ASSIGN_OR_RETURN(SectionGroup section_group,
- schema_store->ExtractSections(document_));
- // string sections
- for (const Section<std::string_view>& section :
- section_group.string_sections) {
+ const LanguageSegmenter* language_segmenter,
+ const std::vector<Section<std::string_view>>& string_sections) {
+ std::vector<TokenizedSection> tokenized_string_sections;
+ for (const Section<std::string_view>& section : string_sections) {
ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> tokenizer,
tokenizer_factory::CreateIndexingTokenizer(
section.metadata.tokenizer, language_segmenter));
@@ -68,11 +53,34 @@ libtextclassifier3::Status TokenizedDocument::Tokenize(
}
}
}
- tokenized_sections_.emplace_back(SectionMetadata(section.metadata),
- std::move(token_sequence));
+ tokenized_string_sections.emplace_back(SectionMetadata(section.metadata),
+ std::move(token_sequence));
}
- return libtextclassifier3::Status::OK;
+ return tokenized_string_sections;
+}
+
+} // namespace
+
+/* static */ libtextclassifier3::StatusOr<TokenizedDocument>
+TokenizedDocument::Create(const SchemaStore* schema_store,
+ const LanguageSegmenter* language_segmenter,
+ DocumentProto document) {
+ DocumentValidator validator(schema_store);
+ ICING_RETURN_IF_ERROR(validator.Validate(document));
+
+ ICING_ASSIGN_OR_RETURN(SectionGroup section_group,
+ schema_store->ExtractSections(document));
+
+ // Tokenize string sections
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<TokenizedSection> tokenized_string_sections,
+ Tokenize(schema_store, language_segmenter,
+ section_group.string_sections));
+
+ return TokenizedDocument(std::move(document),
+ std::move(tokenized_string_sections),
+ std::move(section_group.integer_sections));
}
} // namespace lib
diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h
index 5d996d9..5729df2 100644
--- a/icing/util/tokenized-document.h
+++ b/icing/util/tokenized-document.h
@@ -46,28 +46,35 @@ class TokenizedDocument {
const DocumentProto& document() const { return document_; }
- int32_t num_tokens() const {
- int32_t num_tokens = 0;
- for (const TokenizedSection& section : tokenized_sections_) {
- num_tokens += section.token_sequence.size();
+ int32_t num_string_tokens() const {
+ int32_t num_string_tokens = 0;
+ for (const TokenizedSection& section : tokenized_string_sections_) {
+ num_string_tokens += section.token_sequence.size();
}
- return num_tokens;
+ return num_string_tokens;
}
- const std::vector<TokenizedSection>& sections() const {
- return tokenized_sections_;
+ const std::vector<TokenizedSection>& tokenized_string_sections() const {
+ return tokenized_string_sections_;
+ }
+
+ const std::vector<Section<int64_t>>& integer_sections() const {
+ return integer_sections_;
}
private:
// Use TokenizedDocument::Create() to instantiate.
- explicit TokenizedDocument(DocumentProto document);
+ explicit TokenizedDocument(
+ DocumentProto&& document,
+ std::vector<TokenizedSection>&& tokenized_string_sections,
+ std::vector<Section<int64_t>>&& integer_sections)
+ : document_(std::move(document)),
+ tokenized_string_sections_(std::move(tokenized_string_sections)),
+ integer_sections_(std::move(integer_sections)) {}
DocumentProto document_;
- std::vector<TokenizedSection> tokenized_sections_;
-
- libtextclassifier3::Status Tokenize(
- const SchemaStore* schema_store,
- const LanguageSegmenter* language_segmenter);
+ std::vector<TokenizedSection> tokenized_string_sections_;
+ std::vector<Section<int64_t>> integer_sections_;
};
} // namespace lib
diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc
new file mode 100644
index 0000000..3497bef
--- /dev/null
+++ b/icing/util/tokenized-document_test.cc
@@ -0,0 +1,335 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "icing/util/tokenized-document.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "icing/document-builder.h"
+#include "icing/file/filesystem.h"
+#include "icing/portable/platform.h"
+#include "icing/proto/document.pb.h"
+#include "icing/proto/schema.pb.h"
+#include "icing/proto/term.pb.h"
+#include "icing/schema-builder.h"
+#include "icing/schema/schema-store.h"
+#include "icing/schema/section.h"
+#include "icing/testing/common-matchers.h"
+#include "icing/testing/fake-clock.h"
+#include "icing/testing/icu-data-file-helper.h"
+#include "icing/testing/test-data.h"
+#include "icing/testing/tmp-directory.h"
+#include "icing/tokenization/language-segmenter-factory.h"
+#include "icing/tokenization/language-segmenter.h"
+#include "unicode/uloc.h"
+
+namespace icing {
+namespace lib {
+
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::EqualsProto;
+using ::testing::IsEmpty;
+using ::testing::SizeIs;
+
+// schema types
+constexpr std::string_view kFakeType = "FakeType";
+
+// Indexable properties and section Id. Section Id is determined by the
+// lexicographical order of the indexable property paths.
+constexpr std::string_view kIndexableIntegerProperty1 = "indexableInteger1";
+constexpr std::string_view kIndexableIntegerProperty2 = "indexableInteger2";
+constexpr std::string_view kStringExactProperty = "stringExact";
+constexpr std::string_view kStringPrefixProperty = "stringPrefix";
+
+constexpr SectionId kIndexableInteger1SectionId = 0;
+constexpr SectionId kIndexableInteger2SectionId = 1;
+constexpr SectionId kStringExactSectionId = 2;
+constexpr SectionId kStringPrefixSectionId = 3;
+
+const SectionMetadata kIndexableInteger1SectionMetadata(
+ kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1));
+
+const SectionMetadata kIndexableInteger2SectionMetadata(
+ kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN,
+ NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2));
+
+const SectionMetadata kStringExactSectionMetadata(
+ kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT,
+ NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty));
+
+const SectionMetadata kStringPrefixSectionMetadata(
+ kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX,
+ NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty));
+
+// Other non-indexable properties.
+constexpr std::string_view kUnindexedStringProperty = "unindexedString";
+constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger";
+
+class TokenizedDocumentTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ test_dir_ = GetTestTempDir() + "/icing";
+ schema_store_dir_ = test_dir_ + "/schema_store";
+ filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str());
+
+ if (!IsCfStringTokenization() && !IsReverseJniTokenization()) {
+ ICING_ASSERT_OK(
+ // File generated via icu_data_file rule in //icing/BUILD.
+ icu_data_file_helper::SetUpICUDataFile(
+ GetTestFilePath("icing/icu.dat")));
+ }
+
+ language_segmenter_factory::SegmenterOptions options(ULOC_US);
+ ICING_ASSERT_OK_AND_ASSIGN(
+ lang_segmenter_,
+ language_segmenter_factory::Create(std::move(options)));
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ schema_store_,
+ SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_));
+
+ SchemaProto schema =
+ SchemaBuilder()
+ .AddType(
+ SchemaTypeConfigBuilder()
+ .SetType(kFakeType)
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kUnindexedStringProperty)
+ .SetDataType(TYPE_STRING)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kUnindexedIntegerProperty)
+ .SetDataType(TYPE_INT64)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kIndexableIntegerProperty1)
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kIndexableIntegerProperty2)
+ .SetDataTypeInt64(NUMERIC_MATCH_RANGE)
+ .SetCardinality(CARDINALITY_OPTIONAL))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kStringExactProperty)
+ .SetDataTypeString(TERM_MATCH_EXACT,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_REPEATED))
+ .AddProperty(PropertyConfigBuilder()
+ .SetName(kStringPrefixProperty)
+ .SetDataTypeString(TERM_MATCH_PREFIX,
+ TOKENIZER_PLAIN)
+ .SetCardinality(CARDINALITY_OPTIONAL)))
+ .Build();
+ ICING_ASSERT_OK(schema_store_->SetSchema(schema));
+ }
+
+ void TearDown() override {
+ schema_store_.reset();
+
+ // Check that the schema store directory is the *only* directory in
+ // test_dir_. In other words, ensure that all temporary directories have
+ // been properly cleaned up.
+ std::vector<std::string> sub_dirs;
+ ASSERT_TRUE(filesystem_.ListDirectory(test_dir_.c_str(), &sub_dirs));
+ ASSERT_THAT(sub_dirs, ElementsAre("schema_store"));
+
+ // Finally, clean everything up.
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()));
+ }
+
+ Filesystem filesystem_;
+ FakeClock fake_clock_;
+ std::string test_dir_;
+ std::string schema_store_dir_;
+ std::unique_ptr<LanguageSegmenter> lang_segmenter_;
+ std::unique_ptr<SchemaStore> schema_store_;
+};
+
+TEST_F(TokenizedDocumentTest, CreateAll) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kUnindexedStringProperty),
+ "hello world unindexed")
+ .AddStringProperty(std::string(kStringExactProperty), "test foo",
+ "test bar", "test baz")
+ .AddStringProperty(std::string(kStringPrefixProperty), "foo bar baz")
+ .AddInt64Property(std::string(kUnindexedIntegerProperty), 789)
+ .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3)
+ .AddInt64Property(std::string(kIndexableIntegerProperty2), 456)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(9));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.tokenized_string_sections().at(0).metadata,
+ Eq(kStringExactSectionMetadata));
+ EXPECT_THAT(
+ tokenized_document.tokenized_string_sections().at(0).token_sequence,
+ ElementsAre("test", "foo", "test", "bar", "test", "baz"));
+ EXPECT_THAT(tokenized_document.tokenized_string_sections().at(1).metadata,
+ Eq(kStringPrefixSectionMetadata));
+ EXPECT_THAT(
+ tokenized_document.tokenized_string_sections().at(1).token_sequence,
+ ElementsAre("foo", "bar", "baz"));
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.integer_sections().at(0).metadata,
+ Eq(kIndexableInteger1SectionMetadata));
+ EXPECT_THAT(tokenized_document.integer_sections().at(0).content,
+ ElementsAre(1, 2, 3));
+ EXPECT_THAT(tokenized_document.integer_sections().at(1).metadata,
+ Eq(kIndexableInteger2SectionMetadata));
+ EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
+ ElementsAre(456));
+}
+
+TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddInt64Property(std::string(kUnindexedIntegerProperty), 789)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddInt64Property(std::string(kUnindexedIntegerProperty), 789)
+ .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3)
+ .AddInt64Property(std::string(kIndexableIntegerProperty2), 456)
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.integer_sections().at(0).metadata,
+ Eq(kIndexableInteger1SectionMetadata));
+ EXPECT_THAT(tokenized_document.integer_sections().at(0).content,
+ ElementsAre(1, 2, 3));
+ EXPECT_THAT(tokenized_document.integer_sections().at(1).metadata,
+ Eq(kIndexableInteger2SectionMetadata));
+ EXPECT_THAT(tokenized_document.integer_sections().at(1).content,
+ ElementsAre(456));
+}
+
+TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kUnindexedStringProperty),
+ "hello world unindexed")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty());
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+}
+
+TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) {
+ DocumentProto document =
+ DocumentBuilder()
+ .SetKey("icing", "fake_type/1")
+ .SetSchema(std::string(kFakeType))
+ .AddStringProperty(std::string(kUnindexedStringProperty),
+ "hello world unindexed")
+ .AddStringProperty(std::string(kStringExactProperty), "test foo",
+ "test bar", "test baz")
+ .AddStringProperty(std::string(kStringPrefixProperty), "foo bar baz")
+ .Build();
+
+ ICING_ASSERT_OK_AND_ASSIGN(
+ TokenizedDocument tokenized_document,
+ TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(),
+ document));
+
+ EXPECT_THAT(tokenized_document.document(), EqualsProto(document));
+ EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(9));
+
+ // string sections
+ EXPECT_THAT(tokenized_document.tokenized_string_sections(), SizeIs(2));
+ EXPECT_THAT(tokenized_document.tokenized_string_sections().at(0).metadata,
+ Eq(kStringExactSectionMetadata));
+ EXPECT_THAT(
+ tokenized_document.tokenized_string_sections().at(0).token_sequence,
+ ElementsAre("test", "foo", "test", "bar", "test", "baz"));
+ EXPECT_THAT(tokenized_document.tokenized_string_sections().at(1).metadata,
+ Eq(kStringPrefixSectionMetadata));
+ EXPECT_THAT(
+ tokenized_document.tokenized_string_sections().at(1).token_sequence,
+ ElementsAre("foo", "bar", "baz"));
+
+ // integer sections
+ EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty());
+}
+
+} // namespace
+
+} // namespace lib
+} // namespace icing
diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java
index 81223f2..47b94a5 100644
--- a/java/src/com/google/android/icing/IcingSearchEngine.java
+++ b/java/src/com/google/android/icing/IcingSearchEngine.java
@@ -14,7 +14,6 @@
package com.google.android.icing;
-import android.util.Log;
import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import com.google.android.icing.proto.DebugInfoResultProto;
@@ -45,17 +44,15 @@ import com.google.android.icing.proto.ScoringSpecProto;
import com.google.android.icing.proto.SearchResultProto;
import com.google.android.icing.proto.SearchSpecProto;
import com.google.android.icing.proto.SetSchemaResultProto;
-import com.google.android.icing.proto.StatusProto;
import com.google.android.icing.proto.StorageInfoResultProto;
import com.google.android.icing.proto.SuggestionResponse;
import com.google.android.icing.proto.SuggestionSpecProto;
import com.google.android.icing.proto.UsageReport;
-import com.google.protobuf.ExtensionRegistryLite;
-import com.google.protobuf.InvalidProtocolBufferException;
-import java.io.Closeable;
/**
- * Java wrapper to access native APIs in external/icing/icing/icing-search-engine.h
+ * Java wrapper to access {@link IcingSearchEngineImpl}.
+ *
+ * <p>It converts byte arrays from {@link IcingSearchEngineImpl} to the corresponding protos.
*
* <p>If this instance has been closed, the instance is no longer usable.
*
@@ -63,574 +60,197 @@ import java.io.Closeable;
*
* <p>NOTE: This class is NOT thread-safe.
*/
-public class IcingSearchEngine implements Closeable {
+public class IcingSearchEngine implements IcingSearchEngineInterface {
private static final String TAG = "IcingSearchEngine";
- private static final ExtensionRegistryLite EXTENSION_REGISTRY_LITE =
- ExtensionRegistryLite.getEmptyRegistry();
-
- private long nativePointer;
-
- private boolean closed = false;
-
- static {
- // NOTE: This can fail with an UnsatisfiedLinkError
- System.loadLibrary("icing");
- }
+ private final IcingSearchEngineImpl icingSearchEngineImpl;
/**
* @throws IllegalStateException if IcingSearchEngine fails to be created
*/
public IcingSearchEngine(@NonNull IcingSearchEngineOptions options) {
- nativePointer = nativeCreate(options.toByteArray());
- if (nativePointer == 0) {
- Log.e(TAG, "Failed to create IcingSearchEngine.");
- throw new IllegalStateException("Failed to create IcingSearchEngine.");
- }
- }
-
- private void throwIfClosed() {
- if (closed) {
- throw new IllegalStateException("Trying to use a closed IcingSearchEngine instance.");
- }
+ icingSearchEngineImpl = new IcingSearchEngineImpl(options.toByteArray());
}
@Override
public void close() {
- if (closed) {
- return;
- }
-
- if (nativePointer != 0) {
- nativeDestroy(this);
- }
- nativePointer = 0;
- closed = true;
+ icingSearchEngineImpl.close();
}
@Override
protected void finalize() throws Throwable {
- close();
+ icingSearchEngineImpl.close();
super.finalize();
}
@NonNull
+ @Override
public InitializeResultProto initialize() {
- throwIfClosed();
-
- byte[] initializeResultBytes = nativeInitialize(this);
- if (initializeResultBytes == null) {
- Log.e(TAG, "Received null InitializeResult from native.");
- return InitializeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return InitializeResultProto.parseFrom(initializeResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing InitializeResultProto.", e);
- return InitializeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToInitializeResultProto(
+ icingSearchEngineImpl.initialize());
}
@NonNull
+ @Override
public SetSchemaResultProto setSchema(@NonNull SchemaProto schema) {
return setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false);
}
@NonNull
+ @Override
public SetSchemaResultProto setSchema(
@NonNull SchemaProto schema, boolean ignoreErrorsAndDeleteDocuments) {
- throwIfClosed();
-
- byte[] setSchemaResultBytes =
- nativeSetSchema(this, schema.toByteArray(), ignoreErrorsAndDeleteDocuments);
- if (setSchemaResultBytes == null) {
- Log.e(TAG, "Received null SetSchemaResultProto from native.");
- return SetSchemaResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return SetSchemaResultProto.parseFrom(setSchemaResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing SetSchemaResultProto.", e);
- return SetSchemaResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToSetSchemaResultProto(
+ icingSearchEngineImpl.setSchema(schema.toByteArray(), ignoreErrorsAndDeleteDocuments));
}
@NonNull
+ @Override
public GetSchemaResultProto getSchema() {
- throwIfClosed();
-
- byte[] getSchemaResultBytes = nativeGetSchema(this);
- if (getSchemaResultBytes == null) {
- Log.e(TAG, "Received null GetSchemaResultProto from native.");
- return GetSchemaResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return GetSchemaResultProto.parseFrom(getSchemaResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetSchemaResultProto.", e);
- return GetSchemaResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToGetSchemaResultProto(
+ icingSearchEngineImpl.getSchema());
}
@NonNull
+ @Override
public GetSchemaTypeResultProto getSchemaType(@NonNull String schemaType) {
- throwIfClosed();
-
- byte[] getSchemaTypeResultBytes = nativeGetSchemaType(this, schemaType);
- if (getSchemaTypeResultBytes == null) {
- Log.e(TAG, "Received null GetSchemaTypeResultProto from native.");
- return GetSchemaTypeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return GetSchemaTypeResultProto.parseFrom(getSchemaTypeResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetSchemaTypeResultProto.", e);
- return GetSchemaTypeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToGetSchemaTypeResultProto(
+ icingSearchEngineImpl.getSchemaType(schemaType));
}
@NonNull
+ @Override
public PutResultProto put(@NonNull DocumentProto document) {
- throwIfClosed();
-
- byte[] putResultBytes = nativePut(this, document.toByteArray());
- if (putResultBytes == null) {
- Log.e(TAG, "Received null PutResultProto from native.");
- return PutResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return PutResultProto.parseFrom(putResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing PutResultProto.", e);
- return PutResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToPutResultProto(
+ icingSearchEngineImpl.put(document.toByteArray()));
}
@NonNull
+ @Override
public GetResultProto get(
@NonNull String namespace, @NonNull String uri, @NonNull GetResultSpecProto getResultSpec) {
- throwIfClosed();
-
- byte[] getResultBytes = nativeGet(this, namespace, uri, getResultSpec.toByteArray());
- if (getResultBytes == null) {
- Log.e(TAG, "Received null GetResultProto from native.");
- return GetResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return GetResultProto.parseFrom(getResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetResultProto.", e);
- return GetResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToGetResultProto(
+ icingSearchEngineImpl.get(namespace, uri, getResultSpec.toByteArray()));
}
@NonNull
+ @Override
public ReportUsageResultProto reportUsage(@NonNull UsageReport usageReport) {
- throwIfClosed();
-
- byte[] reportUsageResultBytes = nativeReportUsage(this, usageReport.toByteArray());
- if (reportUsageResultBytes == null) {
- Log.e(TAG, "Received null ReportUsageResultProto from native.");
- return ReportUsageResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return ReportUsageResultProto.parseFrom(reportUsageResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing ReportUsageResultProto.", e);
- return ReportUsageResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToReportUsageResultProto(
+ icingSearchEngineImpl.reportUsage(usageReport.toByteArray()));
}
@NonNull
+ @Override
public GetAllNamespacesResultProto getAllNamespaces() {
- throwIfClosed();
-
- byte[] getAllNamespacesResultBytes = nativeGetAllNamespaces(this);
- if (getAllNamespacesResultBytes == null) {
- Log.e(TAG, "Received null GetAllNamespacesResultProto from native.");
- return GetAllNamespacesResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return GetAllNamespacesResultProto.parseFrom(
- getAllNamespacesResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetAllNamespacesResultProto.", e);
- return GetAllNamespacesResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToGetAllNamespacesResultProto(
+ icingSearchEngineImpl.getAllNamespaces());
}
@NonNull
+ @Override
public SearchResultProto search(
@NonNull SearchSpecProto searchSpec,
@NonNull ScoringSpecProto scoringSpec,
@NonNull ResultSpecProto resultSpec) {
- throwIfClosed();
-
- // Note that on Android System.currentTimeMillis() is the standard "wall" clock and can be set
- // by the user or the phone network so the time may jump backwards or forwards unpredictably.
- // This could lead to inaccurate final JNI latency calculations or unexpected negative numbers
- // in the case where the phone time is changed while sending data across JNI layers.
- // However these occurrences should be very rare, so we will keep usage of
- // System.currentTimeMillis() due to the lack of better time functions that can provide a
- // consistent timestamp across all platforms.
- long javaToNativeStartTimestampMs = System.currentTimeMillis();
- byte[] searchResultBytes =
- nativeSearch(
- this,
- searchSpec.toByteArray(),
- scoringSpec.toByteArray(),
- resultSpec.toByteArray(),
- javaToNativeStartTimestampMs);
- if (searchResultBytes == null) {
- Log.e(TAG, "Received null SearchResultProto from native.");
- return SearchResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- SearchResultProto.Builder searchResultProtoBuilder =
- SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE);
- setNativeToJavaJniLatency(searchResultProtoBuilder);
- return searchResultProtoBuilder.build();
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing SearchResultProto.", e);
- return SearchResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToSearchResultProto(
+ icingSearchEngineImpl.search(
+ searchSpec.toByteArray(), scoringSpec.toByteArray(), resultSpec.toByteArray()));
}
@NonNull
+ @Override
public SearchResultProto getNextPage(long nextPageToken) {
- throwIfClosed();
-
- byte[] searchResultBytes = nativeGetNextPage(this, nextPageToken, System.currentTimeMillis());
- if (searchResultBytes == null) {
- Log.e(TAG, "Received null SearchResultProto from native.");
- return SearchResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- SearchResultProto.Builder searchResultProtoBuilder =
- SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE);
- setNativeToJavaJniLatency(searchResultProtoBuilder);
- return searchResultProtoBuilder.build();
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing SearchResultProto.", e);
- return SearchResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
- }
-
- private void setNativeToJavaJniLatency(SearchResultProto.Builder searchResultProtoBuilder) {
- int nativeToJavaLatencyMs =
- (int)
- (System.currentTimeMillis()
- - searchResultProtoBuilder.getQueryStats().getNativeToJavaStartTimestampMs());
- searchResultProtoBuilder.setQueryStats(
- searchResultProtoBuilder.getQueryStats().toBuilder()
- .setNativeToJavaJniLatencyMs(nativeToJavaLatencyMs));
+ return IcingSearchEngineUtils.byteArrayToSearchResultProto(
+ icingSearchEngineImpl.getNextPage(nextPageToken));
}
@NonNull
+ @Override
public void invalidateNextPageToken(long nextPageToken) {
- throwIfClosed();
-
- nativeInvalidateNextPageToken(this, nextPageToken);
+ icingSearchEngineImpl.invalidateNextPageToken(nextPageToken);
}
@NonNull
+ @Override
public DeleteResultProto delete(@NonNull String namespace, @NonNull String uri) {
- throwIfClosed();
-
- byte[] deleteResultBytes = nativeDelete(this, namespace, uri);
- if (deleteResultBytes == null) {
- Log.e(TAG, "Received null DeleteResultProto from native.");
- return DeleteResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return DeleteResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing DeleteResultProto.", e);
- return DeleteResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToDeleteResultProto(
+ icingSearchEngineImpl.delete(namespace, uri));
}
@NonNull
+ @Override
public SuggestionResponse searchSuggestions(@NonNull SuggestionSpecProto suggestionSpec) {
- byte[] suggestionResponseBytes = nativeSearchSuggestions(this, suggestionSpec.toByteArray());
- if (suggestionResponseBytes == null) {
- Log.e(TAG, "Received null suggestionResponseBytes from native.");
- return SuggestionResponse.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing suggestionResponseBytes.", e);
- return SuggestionResponse.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToSuggestionResponse(
+ icingSearchEngineImpl.searchSuggestions(suggestionSpec.toByteArray()));
}
@NonNull
+ @Override
public DeleteByNamespaceResultProto deleteByNamespace(@NonNull String namespace) {
- throwIfClosed();
-
- byte[] deleteByNamespaceResultBytes = nativeDeleteByNamespace(this, namespace);
- if (deleteByNamespaceResultBytes == null) {
- Log.e(TAG, "Received null DeleteByNamespaceResultProto from native.");
- return DeleteByNamespaceResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return DeleteByNamespaceResultProto.parseFrom(
- deleteByNamespaceResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing DeleteByNamespaceResultProto.", e);
- return DeleteByNamespaceResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToDeleteByNamespaceResultProto(
+ icingSearchEngineImpl.deleteByNamespace(namespace));
}
@NonNull
+ @Override
public DeleteBySchemaTypeResultProto deleteBySchemaType(@NonNull String schemaType) {
- throwIfClosed();
-
- byte[] deleteBySchemaTypeResultBytes = nativeDeleteBySchemaType(this, schemaType);
- if (deleteBySchemaTypeResultBytes == null) {
- Log.e(TAG, "Received null DeleteBySchemaTypeResultProto from native.");
- return DeleteBySchemaTypeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return DeleteBySchemaTypeResultProto.parseFrom(
- deleteBySchemaTypeResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing DeleteBySchemaTypeResultProto.", e);
- return DeleteBySchemaTypeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToDeleteBySchemaTypeResultProto(
+ icingSearchEngineImpl.deleteBySchemaType(schemaType));
}
@NonNull
+ @Override
public DeleteByQueryResultProto deleteByQuery(@NonNull SearchSpecProto searchSpec) {
return deleteByQuery(searchSpec, /*returnDeletedDocumentInfo=*/ false);
}
@NonNull
+ @Override
public DeleteByQueryResultProto deleteByQuery(
@NonNull SearchSpecProto searchSpec, boolean returnDeletedDocumentInfo) {
- throwIfClosed();
-
- byte[] deleteResultBytes =
- nativeDeleteByQuery(this, searchSpec.toByteArray(), returnDeletedDocumentInfo);
- if (deleteResultBytes == null) {
- Log.e(TAG, "Received null DeleteResultProto from native.");
- return DeleteByQueryResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return DeleteByQueryResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing DeleteResultProto.", e);
- return DeleteByQueryResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToDeleteByQueryResultProto(
+ icingSearchEngineImpl.deleteByQuery(searchSpec.toByteArray(), returnDeletedDocumentInfo));
}
@NonNull
+ @Override
public PersistToDiskResultProto persistToDisk(@NonNull PersistType.Code persistTypeCode) {
- throwIfClosed();
-
- byte[] persistToDiskResultBytes = nativePersistToDisk(this, persistTypeCode.getNumber());
- if (persistToDiskResultBytes == null) {
- Log.e(TAG, "Received null PersistToDiskResultProto from native.");
- return PersistToDiskResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return PersistToDiskResultProto.parseFrom(persistToDiskResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing PersistToDiskResultProto.", e);
- return PersistToDiskResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToPersistToDiskResultProto(
+ icingSearchEngineImpl.persistToDisk(persistTypeCode.getNumber()));
}
@NonNull
+ @Override
public OptimizeResultProto optimize() {
- throwIfClosed();
-
- byte[] optimizeResultBytes = nativeOptimize(this);
- if (optimizeResultBytes == null) {
- Log.e(TAG, "Received null OptimizeResultProto from native.");
- return OptimizeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return OptimizeResultProto.parseFrom(optimizeResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing OptimizeResultProto.", e);
- return OptimizeResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToOptimizeResultProto(icingSearchEngineImpl.optimize());
}
@NonNull
+ @Override
public GetOptimizeInfoResultProto getOptimizeInfo() {
- throwIfClosed();
-
- byte[] getOptimizeInfoResultBytes = nativeGetOptimizeInfo(this);
- if (getOptimizeInfoResultBytes == null) {
- Log.e(TAG, "Received null GetOptimizeInfoResultProto from native.");
- return GetOptimizeInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return GetOptimizeInfoResultProto.parseFrom(
- getOptimizeInfoResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e);
- return GetOptimizeInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToGetOptimizeInfoResultProto(
+ icingSearchEngineImpl.getOptimizeInfo());
}
@NonNull
+ @Override
public StorageInfoResultProto getStorageInfo() {
- throwIfClosed();
-
- byte[] storageInfoResultProtoBytes = nativeGetStorageInfo(this);
- if (storageInfoResultProtoBytes == null) {
- Log.e(TAG, "Received null StorageInfoResultProto from native.");
- return StorageInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return StorageInfoResultProto.parseFrom(storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e);
- return StorageInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToStorageInfoResultProto(
+ icingSearchEngineImpl.getStorageInfo());
}
@NonNull
+ @Override
public DebugInfoResultProto getDebugInfo(DebugInfoVerbosity.Code verbosity) {
- throwIfClosed();
-
- byte[] debugInfoResultProtoBytes = nativeGetDebugInfo(this, verbosity.getNumber());
- if (debugInfoResultProtoBytes == null) {
- Log.e(TAG, "Received null DebugInfoResultProto from native.");
- return DebugInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return DebugInfoResultProto.parseFrom(debugInfoResultProtoBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing DebugInfoResultProto.", e);
- return DebugInfoResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToDebugInfoResultProto(
+ icingSearchEngineImpl.getDebugInfo(verbosity.getNumber()));
}
@NonNull
+ @Override
public ResetResultProto reset() {
- throwIfClosed();
-
- byte[] resetResultBytes = nativeReset(this);
- if (resetResultBytes == null) {
- Log.e(TAG, "Received null ResetResultProto from native.");
- return ResetResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
-
- try {
- return ResetResultProto.parseFrom(resetResultBytes, EXTENSION_REGISTRY_LITE);
- } catch (InvalidProtocolBufferException e) {
- Log.e(TAG, "Error parsing ResetResultProto.", e);
- return ResetResultProto.newBuilder()
- .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
- .build();
- }
+ return IcingSearchEngineUtils.byteArrayToResetResultProto(icingSearchEngineImpl.reset());
}
public static boolean shouldLog(LogSeverity.Code severity) {
@@ -638,7 +258,7 @@ public class IcingSearchEngine implements Closeable {
}
public static boolean shouldLog(LogSeverity.Code severity, short verbosity) {
- return nativeShouldLog((short) severity.getNumber(), verbosity);
+ return IcingSearchEngineImpl.shouldLog((short) severity.getNumber(), verbosity);
}
public static boolean setLoggingLevel(LogSeverity.Code severity) {
@@ -646,84 +266,11 @@ public class IcingSearchEngine implements Closeable {
}
public static boolean setLoggingLevel(LogSeverity.Code severity, short verbosity) {
- return nativeSetLoggingLevel((short) severity.getNumber(), verbosity);
+ return IcingSearchEngineImpl.setLoggingLevel((short) severity.getNumber(), verbosity);
}
@Nullable
public static String getLoggingTag() {
- String tag = nativeGetLoggingTag();
- if (tag == null) {
- Log.e(TAG, "Received null logging tag from native.");
- }
- return tag;
+ return IcingSearchEngineImpl.getLoggingTag();
}
-
- private static native long nativeCreate(byte[] icingSearchEngineOptionsBytes);
-
- private static native void nativeDestroy(IcingSearchEngine instance);
-
- private static native byte[] nativeInitialize(IcingSearchEngine instance);
-
- private static native byte[] nativeSetSchema(
- IcingSearchEngine instance, byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments);
-
- private static native byte[] nativeGetSchema(IcingSearchEngine instance);
-
- private static native byte[] nativeGetSchemaType(IcingSearchEngine instance, String schemaType);
-
- private static native byte[] nativePut(IcingSearchEngine instance, byte[] documentBytes);
-
- private static native byte[] nativeGet(
- IcingSearchEngine instance, String namespace, String uri, byte[] getResultSpecBytes);
-
- private static native byte[] nativeReportUsage(
- IcingSearchEngine instance, byte[] usageReportBytes);
-
- private static native byte[] nativeGetAllNamespaces(IcingSearchEngine instance);
-
- private static native byte[] nativeSearch(
- IcingSearchEngine instance,
- byte[] searchSpecBytes,
- byte[] scoringSpecBytes,
- byte[] resultSpecBytes,
- long javaToNativeStartTimestampMs);
-
- private static native byte[] nativeGetNextPage(
- IcingSearchEngine instance, long nextPageToken, long javaToNativeStartTimestampMs);
-
- private static native void nativeInvalidateNextPageToken(
- IcingSearchEngine instance, long nextPageToken);
-
- private static native byte[] nativeDelete(
- IcingSearchEngine instance, String namespace, String uri);
-
- private static native byte[] nativeDeleteByNamespace(
- IcingSearchEngine instance, String namespace);
-
- private static native byte[] nativeDeleteBySchemaType(
- IcingSearchEngine instance, String schemaType);
-
- private static native byte[] nativeDeleteByQuery(
- IcingSearchEngine instance, byte[] searchSpecBytes, boolean returnDeletedDocumentInfo);
-
- private static native byte[] nativePersistToDisk(IcingSearchEngine instance, int persistType);
-
- private static native byte[] nativeOptimize(IcingSearchEngine instance);
-
- private static native byte[] nativeGetOptimizeInfo(IcingSearchEngine instance);
-
- private static native byte[] nativeGetStorageInfo(IcingSearchEngine instance);
-
- private static native byte[] nativeReset(IcingSearchEngine instance);
-
- private static native byte[] nativeSearchSuggestions(
- IcingSearchEngine instance, byte[] suggestionSpecBytes);
-
- private static native byte[] nativeGetDebugInfo(IcingSearchEngine instance, int verbosity);
-
- private static native boolean nativeShouldLog(short severity, short verbosity);
-
- private static native boolean nativeSetLoggingLevel(short severity, short verbosity);
-
- private static native String nativeGetLoggingTag();
}
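With this refactor, IcingSearchEngine keeps its proto-in/proto-out API but delegates all native
work to IcingSearchEngineImpl and all byte-array parsing to IcingSearchEngineUtils (both added
below). A minimal usage sketch; the base directory value is hypothetical, and setBaseDir is
assumed from IcingSearchEngineOptions:

    IcingSearchEngineOptions options =
        IcingSearchEngineOptions.newBuilder().setBaseDir("/tmp/icing").build();
    IcingSearchEngine icing = new IcingSearchEngine(options);
    InitializeResultProto initResult = icing.initialize();
    if (initResult.getStatus().getCode() != StatusProto.Code.OK) {
      // A null JNI result and an unparseable payload both surface here as INTERNAL.
      throw new IllegalStateException("Failed to initialize: " + initResult.getStatus());
    }
    // put(), search(), etc. behave exactly as before; only the plumbing moved.
    icing.close();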
diff --git a/java/src/com/google/android/icing/IcingSearchEngineImpl.java b/java/src/com/google/android/icing/IcingSearchEngineImpl.java
new file mode 100644
index 0000000..8e79a88
--- /dev/null
+++ b/java/src/com/google/android/icing/IcingSearchEngineImpl.java
@@ -0,0 +1,330 @@
+// Copyright (C) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.android.icing;
+
+import android.util.Log;
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import java.io.Closeable;
+
+/**
+ * Java wrapper to access native APIs in external/icing/icing/icing-search-engine.h
+ *
+ * <p>If this instance has been closed, the instance is no longer usable.
+ *
+ * <p>Keep this class non-final so that it can be mocked in AppSearch.
+ *
+ * <p>NOTE: This class is NOT thread-safe.
+ */
+public class IcingSearchEngineImpl implements Closeable {
+
+ private static final String TAG = "IcingSearchEngineImpl";
+
+ private long nativePointer;
+
+ private boolean closed = false;
+
+ static {
+ // NOTE: This can fail with an UnsatisfiedLinkError
+ System.loadLibrary("icing");
+ }
+
+ /**
+ * @throws IllegalStateException if IcingSearchEngineImpl fails to be created
+ */
+ public IcingSearchEngineImpl(@NonNull byte[] optionsBytes) {
+ nativePointer = nativeCreate(optionsBytes);
+ if (nativePointer == 0) {
+ Log.e(TAG, "Failed to create IcingSearchEngineImpl.");
+ throw new IllegalStateException("Failed to create IcingSearchEngineImpl.");
+ }
+ }
+
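+ /** Rejects any further use of this wrapper once {@link #close()} has been called. */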
+ private void throwIfClosed() {
+ if (closed) {
+ throw new IllegalStateException("Trying to use a closed IcingSearchEngineImpl instance.");
+ }
+ }
+
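+ /**
+ * Destroys the underlying native instance, if one exists, and marks this wrapper closed.
+ * Subsequent calls to close() are no-ops.
+ */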
+ @Override
+ public void close() {
+ if (closed) {
+ return;
+ }
+
+ if (nativePointer != 0) {
+ nativeDestroy(this);
+ }
+ nativePointer = 0;
+ closed = true;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ close();
+ super.finalize();
+ }
+
+ @Nullable
+ public byte[] initialize() {
+ throwIfClosed();
+ return nativeInitialize(this);
+ }
+
+ @Nullable
+ public byte[] setSchema(@NonNull byte[] schemaBytes) {
+ return setSchema(schemaBytes, /* ignoreErrorsAndDeleteDocuments= */ false);
+ }
+
+ @Nullable
+ public byte[] setSchema(@NonNull byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments) {
+ throwIfClosed();
+ return nativeSetSchema(this, schemaBytes, ignoreErrorsAndDeleteDocuments);
+ }
+
+ @Nullable
+ public byte[] getSchema() {
+ throwIfClosed();
+ return nativeGetSchema(this);
+ }
+
+ @Nullable
+ public byte[] getSchemaType(@NonNull String schemaType) {
+ throwIfClosed();
+ return nativeGetSchemaType(this, schemaType);
+ }
+
+ @Nullable
+ public byte[] put(@NonNull byte[] documentBytes) {
+ throwIfClosed();
+ return nativePut(this, documentBytes);
+ }
+
+ @Nullable
+ public byte[] get(
+ @NonNull String namespace, @NonNull String uri, @NonNull byte[] getResultSpecBytes) {
+ throwIfClosed();
+ return nativeGet(this, namespace, uri, getResultSpecBytes);
+ }
+
+ @Nullable
+ public byte[] reportUsage(@NonNull byte[] usageReportBytes) {
+ throwIfClosed();
+ return nativeReportUsage(this, usageReportBytes);
+ }
+
+ @Nullable
+ public byte[] getAllNamespaces() {
+ throwIfClosed();
+ return nativeGetAllNamespaces(this);
+ }
+
+ @Nullable
+ public byte[] search(
+ @NonNull byte[] searchSpecBytes,
+ @NonNull byte[] scoringSpecBytes,
+ @NonNull byte[] resultSpecBytes) {
+ throwIfClosed();
+
+ // Note that on Android System.currentTimeMillis() is the standard "wall" clock and can be set
+ // by the user or the phone network so the time may jump backwards or forwards unpredictably.
+ // This could lead to inaccurate final JNI latency calculations or unexpected negative numbers
+ // in the case where the phone time is changed while sending data across JNI layers.
+ // However these occurrences should be very rare, so we will keep usage of
+ // System.currentTimeMillis() due to the lack of better time functions that can provide a
+ // consistent timestamp across all platforms.
+ long javaToNativeStartTimestampMs = System.currentTimeMillis();
+ return nativeSearch(
+ this, searchSpecBytes, scoringSpecBytes, resultSpecBytes, javaToNativeStartTimestampMs);
+ }
+
+ @Nullable
+ public byte[] getNextPage(long nextPageToken) {
+ throwIfClosed();
+ return nativeGetNextPage(this, nextPageToken, System.currentTimeMillis());
+ }
+
+ public void invalidateNextPageToken(long nextPageToken) {
+ throwIfClosed();
+ nativeInvalidateNextPageToken(this, nextPageToken);
+ }
+
+ @Nullable
+ public byte[] delete(@NonNull String namespace, @NonNull String uri) {
+ throwIfClosed();
+ return nativeDelete(this, namespace, uri);
+ }
+
+ @Nullable
+ public byte[] searchSuggestions(@NonNull byte[] suggestionSpecBytes) {
+ throwIfClosed();
+ return nativeSearchSuggestions(this, suggestionSpecBytes);
+ }
+
+ @Nullable
+ public byte[] deleteByNamespace(@NonNull String namespace) {
+ throwIfClosed();
+ return nativeDeleteByNamespace(this, namespace);
+ }
+
+ @Nullable
+ public byte[] deleteBySchemaType(@NonNull String schemaType) {
+ throwIfClosed();
+ return nativeDeleteBySchemaType(this, schemaType);
+ }
+
+ @Nullable
+ public byte[] deleteByQuery(@NonNull byte[] searchSpecBytes) {
+ return deleteByQuery(searchSpecBytes, /* returnDeletedDocumentInfo= */ false);
+ }
+
+ @Nullable
+ public byte[] deleteByQuery(@NonNull byte[] searchSpecBytes, boolean returnDeletedDocumentInfo) {
+ throwIfClosed();
+ return nativeDeleteByQuery(this, searchSpecBytes, returnDeletedDocumentInfo);
+ }
+
+ @Nullable
+ public byte[] persistToDisk(int persistTypeCode) {
+ throwIfClosed();
+ return nativePersistToDisk(this, persistTypeCode);
+ }
+
+ @Nullable
+ public byte[] optimize() {
+ throwIfClosed();
+ return nativeOptimize(this);
+ }
+
+ @Nullable
+ public byte[] getOptimizeInfo() {
+ throwIfClosed();
+ return nativeGetOptimizeInfo(this);
+ }
+
+ @Nullable
+ public byte[] getStorageInfo() {
+ throwIfClosed();
+ return nativeGetStorageInfo(this);
+ }
+
+ @Nullable
+ public byte[] getDebugInfo(int verbosityCode) {
+ throwIfClosed();
+ return nativeGetDebugInfo(this, verbosityCode);
+ }
+
+ @Nullable
+ public byte[] reset() {
+ throwIfClosed();
+ return nativeReset(this);
+ }
+
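+ /** Convenience overload of {@link #shouldLog(short, short)} that uses verbosity 0. */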
+ public static boolean shouldLog(short severity) {
+ return shouldLog(severity, (short) 0);
+ }
+
+ public static boolean shouldLog(short severity, short verbosity) {
+ return nativeShouldLog(severity, verbosity);
+ }
+
+ public static boolean setLoggingLevel(short severity) {
+ return setLoggingLevel(severity, (short) 0);
+ }
+
+ public static boolean setLoggingLevel(short severity, short verbosity) {
+ return nativeSetLoggingLevel(severity, verbosity);
+ }
+
+ @Nullable
+ public static String getLoggingTag() {
+ String tag = nativeGetLoggingTag();
+ if (tag == null) {
+ Log.e(TAG, "Received null logging tag from native.");
+ }
+ return tag;
+ }
+
+ private static native long nativeCreate(byte[] icingSearchEngineOptionsBytes);
+
+ private static native void nativeDestroy(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeInitialize(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeSetSchema(
+ IcingSearchEngineImpl instance, byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments);
+
+ private static native byte[] nativeGetSchema(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeGetSchemaType(
+ IcingSearchEngineImpl instance, String schemaType);
+
+ private static native byte[] nativePut(IcingSearchEngineImpl instance, byte[] documentBytes);
+
+ private static native byte[] nativeGet(
+ IcingSearchEngineImpl instance, String namespace, String uri, byte[] getResultSpecBytes);
+
+ private static native byte[] nativeReportUsage(
+ IcingSearchEngineImpl instance, byte[] usageReportBytes);
+
+ private static native byte[] nativeGetAllNamespaces(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeSearch(
+ IcingSearchEngineImpl instance,
+ byte[] searchSpecBytes,
+ byte[] scoringSpecBytes,
+ byte[] resultSpecBytes,
+ long javaToNativeStartTimestampMs);
+
+ private static native byte[] nativeGetNextPage(
+ IcingSearchEngineImpl instance, long nextPageToken, long javaToNativeStartTimestampMs);
+
+ private static native void nativeInvalidateNextPageToken(
+ IcingSearchEngineImpl instance, long nextPageToken);
+
+ private static native byte[] nativeDelete(
+ IcingSearchEngineImpl instance, String namespace, String uri);
+
+ private static native byte[] nativeDeleteByNamespace(
+ IcingSearchEngineImpl instance, String namespace);
+
+ private static native byte[] nativeDeleteBySchemaType(
+ IcingSearchEngineImpl instance, String schemaType);
+
+ private static native byte[] nativeDeleteByQuery(
+ IcingSearchEngineImpl instance, byte[] searchSpecBytes, boolean returnDeletedDocumentInfo);
+
+ private static native byte[] nativePersistToDisk(IcingSearchEngineImpl instance, int persistType);
+
+ private static native byte[] nativeOptimize(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeGetOptimizeInfo(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeGetStorageInfo(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeReset(IcingSearchEngineImpl instance);
+
+ private static native byte[] nativeSearchSuggestions(
+ IcingSearchEngineImpl instance, byte[] suggestionSpecBytes);
+
+ private static native byte[] nativeGetDebugInfo(IcingSearchEngineImpl instance, int verbosity);
+
+ private static native boolean nativeShouldLog(short severity, short verbosity);
+
+ private static native boolean nativeSetLoggingLevel(short severity, short verbosity);
+
+ private static native String nativeGetLoggingTag();
+}
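IcingSearchEngineImpl deliberately traffics in nullable byte arrays, leaving proto parsing to its
callers. A short sketch of that contract, paired with the IcingSearchEngineUtils helpers
introduced below:

    IcingSearchEngineImpl impl = new IcingSearchEngineImpl(options.toByteArray());
    byte[] initBytes = impl.initialize(); // may be null if the JNI call fails
    InitializeResultProto initResult =
        IcingSearchEngineUtils.byteArrayToInitializeResultProto(initBytes);
    // The helper maps both a null array and unparseable bytes to an INTERNAL
    // status instead of throwing, so callers always see a well-formed proto.
    impl.close();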
diff --git a/java/src/com/google/android/icing/IcingSearchEngineInterface.java b/java/src/com/google/android/icing/IcingSearchEngineInterface.java
new file mode 100644
index 0000000..9d567f9
--- /dev/null
+++ b/java/src/com/google/android/icing/IcingSearchEngineInterface.java
@@ -0,0 +1,153 @@
+package com.google.android.icing;
+
+import android.os.RemoteException;
+import com.google.android.icing.proto.DebugInfoResultProto;
+import com.google.android.icing.proto.DebugInfoVerbosity;
+import com.google.android.icing.proto.DeleteByNamespaceResultProto;
+import com.google.android.icing.proto.DeleteByQueryResultProto;
+import com.google.android.icing.proto.DeleteBySchemaTypeResultProto;
+import com.google.android.icing.proto.DeleteResultProto;
+import com.google.android.icing.proto.DocumentProto;
+import com.google.android.icing.proto.GetAllNamespacesResultProto;
+import com.google.android.icing.proto.GetOptimizeInfoResultProto;
+import com.google.android.icing.proto.GetResultProto;
+import com.google.android.icing.proto.GetResultSpecProto;
+import com.google.android.icing.proto.GetSchemaResultProto;
+import com.google.android.icing.proto.GetSchemaTypeResultProto;
+import com.google.android.icing.proto.InitializeResultProto;
+import com.google.android.icing.proto.OptimizeResultProto;
+import com.google.android.icing.proto.PersistToDiskResultProto;
+import com.google.android.icing.proto.PersistType;
+import com.google.android.icing.proto.PutResultProto;
+import com.google.android.icing.proto.ReportUsageResultProto;
+import com.google.android.icing.proto.ResetResultProto;
+import com.google.android.icing.proto.ResultSpecProto;
+import com.google.android.icing.proto.SchemaProto;
+import com.google.android.icing.proto.ScoringSpecProto;
+import com.google.android.icing.proto.SearchResultProto;
+import com.google.android.icing.proto.SearchSpecProto;
+import com.google.android.icing.proto.SetSchemaResultProto;
+import com.google.android.icing.proto.StorageInfoResultProto;
+import com.google.android.icing.proto.SuggestionResponse;
+import com.google.android.icing.proto.SuggestionSpecProto;
+import com.google.android.icing.proto.UsageReport;
+
+/**
+ * A common user-facing interface to expose the functionalities provided by the Icing Library.
+ *
+ * <p>All the methods here throw {@link RemoteException} because the implementation for
+ * gmscore-appsearch-dynamite will throw it.
+ */
+public interface IcingSearchEngineInterface extends AutoCloseable {
+ /**
+ * Initializes the current IcingSearchEngine implementation.
+ *
+ * <p>Internally the icing instance will be initialized.
+ */
+ InitializeResultProto initialize();
+
+ /** Sets the schema for the icing instance. */
+ SetSchemaResultProto setSchema(SchemaProto schema);
+
+ /**
+ * Sets the schema for the icing instance.
+ *
+ * @param ignoreErrorsAndDeleteDocuments forces the schema to be set, deleting any documents
+ * that are incompatible with the new schema.
+ */
+ SetSchemaResultProto setSchema(SchemaProto schema, boolean ignoreErrorsAndDeleteDocuments);
+
+ /** Gets the schema for the icing instance. */
+ GetSchemaResultProto getSchema();
+
+ /**
+ * Gets the schema config for the given schema type.
+ *
+ * @param schemaType type of the schema.
+ */
+ GetSchemaTypeResultProto getSchemaType(String schemaType);
+
+ /** Puts the document. */
+ PutResultProto put(DocumentProto document);
+
+ /**
+ * Gets the document.
+ *
+ * @param namespace namespace of the document.
+ * @param uri uri of the document.
+ * @param getResultSpec the spec for getting the document.
+ */
+ GetResultProto get(String namespace, String uri, GetResultSpecProto getResultSpec);
+
+ /** Reports usage. */
+ ReportUsageResultProto reportUsage(UsageReport usageReport);
+
+ /** Gets all namespaces. */
+ GetAllNamespacesResultProto getAllNamespaces();
+
+ /**
+ * Searches over the documents.
+ *
+ * <p>Remaining documents are retrieved through subsequent {@link #getNextPage} calls, using the
+ * next-page token carried in the returned {@link SearchResultProto}.
+ */
+ SearchResultProto search(
+ SearchSpecProto searchSpec, ScoringSpecProto scoringSpec, ResultSpecProto resultSpec);
+
+ /** Gets the next page. */
+ SearchResultProto getNextPage(long nextPageToken);
+
+ /** Invalidates the next page token. */
+ void invalidateNextPageToken(long nextPageToken);
+
+ /**
+ * Deletes the document.
+ *
+ * @param namespace the namespace the document to be deleted belongs to.
+ * @param uri the uri for the document to be deleted.
+ */
+ DeleteResultProto delete(String namespace, String uri);
+
+ /** Returns the suggestions for the search query. */
+ SuggestionResponse searchSuggestions(SuggestionSpecProto suggestionSpec);
+
+ /** Deletes documents by the namespace. */
+ DeleteByNamespaceResultProto deleteByNamespace(String namespace);
+
+ /** Deletes documents by the schema type. */
+ DeleteBySchemaTypeResultProto deleteBySchemaType(String schemaType);
+
+ /** Deletes documents by the search query. */
+ DeleteByQueryResultProto deleteByQuery(SearchSpecProto searchSpec);
+
+ /**
+ * Deletes documents by the search query.
+ *
+ * @param returnDeletedDocumentInfo whether additional information about deleted documents will be
+ * included in {@link DeleteByQueryResultProto}.
+ */
+ DeleteByQueryResultProto deleteByQuery(
+ SearchSpecProto searchSpec, boolean returnDeletedDocumentInfo);
+
+ /** Makes sure every update/delete received up to this point is flushed to disk. */
+ PersistToDiskResultProto persistToDisk(PersistType.Code persistTypeCode);
+
+ /** Makes the icing instance run tasks that are too expensive to be run in real-time. */
+ OptimizeResultProto optimize();
+
+ /** Gets information about the optimization. */
+ GetOptimizeInfoResultProto getOptimizeInfo();
+
+ /** Gets information about the storage. */
+ StorageInfoResultProto getStorageInfo();
+
+ /** Gets the debug information for the current icing instance. */
+ DebugInfoResultProto getDebugInfo(DebugInfoVerbosity.Code verbosity);
+
+ /** Clears all data from the current icing instance, and reinitializes it. */
+ ResetResultProto reset();
+
+ /** Closes the current icing instance. */
+ @Override
+ void close();
+}
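Programming against IcingSearchEngineInterface lets callers swap the JNI-backed engine for another
transport, such as the gmscore-appsearch-dynamite implementation mentioned above. A hypothetical
sketch; useLocalEngine and createDynamiteEngine are made-up stand-ins for whatever selection logic
and alternative implementation a caller provides:

    IcingSearchEngineInterface engine =
        useLocalEngine ? new IcingSearchEngine(options) : createDynamiteEngine();
    SearchResultProto firstPage = engine.search(searchSpec, scoringSpec, resultSpec);
    SearchResultProto nextPage = engine.getNextPage(firstPage.getNextPageToken());
    engine.close();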
diff --git a/java/src/com/google/android/icing/IcingSearchEngineUtils.java b/java/src/com/google/android/icing/IcingSearchEngineUtils.java
new file mode 100644
index 0000000..0913216
--- /dev/null
+++ b/java/src/com/google/android/icing/IcingSearchEngineUtils.java
@@ -0,0 +1,471 @@
+// Copyright (C) 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package com.google.android.icing;
+
+import android.util.Log;
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+import com.google.android.icing.proto.DebugInfoResultProto;
+import com.google.android.icing.proto.DeleteByNamespaceResultProto;
+import com.google.android.icing.proto.DeleteByQueryResultProto;
+import com.google.android.icing.proto.DeleteBySchemaTypeResultProto;
+import com.google.android.icing.proto.DeleteResultProto;
+import com.google.android.icing.proto.GetAllNamespacesResultProto;
+import com.google.android.icing.proto.GetOptimizeInfoResultProto;
+import com.google.android.icing.proto.GetResultProto;
+import com.google.android.icing.proto.GetSchemaResultProto;
+import com.google.android.icing.proto.GetSchemaTypeResultProto;
+import com.google.android.icing.proto.InitializeResultProto;
+import com.google.android.icing.proto.OptimizeResultProto;
+import com.google.android.icing.proto.PersistToDiskResultProto;
+import com.google.android.icing.proto.PutResultProto;
+import com.google.android.icing.proto.ReportUsageResultProto;
+import com.google.android.icing.proto.ResetResultProto;
+import com.google.android.icing.proto.SearchResultProto;
+import com.google.android.icing.proto.SetSchemaResultProto;
+import com.google.android.icing.proto.StatusProto;
+import com.google.android.icing.proto.StorageInfoResultProto;
+import com.google.android.icing.proto.SuggestionResponse;
+import com.google.protobuf.ExtensionRegistryLite;
+import com.google.protobuf.InvalidProtocolBufferException;
+
+/**
+ * Contains utility methods for IcingSearchEngine to convert byte arrays to the corresponding
+ * protos.
+ *
+ * <p>It is also used by the AppSearch dynamite 0p client APIs to convert byte arrays to the
+ * corresponding protos.
+ */
+public final class IcingSearchEngineUtils {
+ private static final String TAG = "IcingSearchEngineUtils";
+ private static final ExtensionRegistryLite EXTENSION_REGISTRY_LITE =
+ ExtensionRegistryLite.getEmptyRegistry();
+
+ private IcingSearchEngineUtils() {}
+
+ // TODO(b/240333360) Check to see if we can use one template function to replace these methods.
+ @NonNull
+ public static InitializeResultProto byteArrayToInitializeResultProto(
+ @Nullable byte[] initializeResultBytes) {
+ if (initializeResultBytes == null) {
+ Log.e(TAG, "Received null InitializeResult from native.");
+ return InitializeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return InitializeResultProto.parseFrom(initializeResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing InitializeResultProto.", e);
+ return InitializeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static SetSchemaResultProto byteArrayToSetSchemaResultProto(
+ @Nullable byte[] setSchemaResultBytes) {
+ if (setSchemaResultBytes == null) {
+ Log.e(TAG, "Received null SetSchemaResultProto from native.");
+ return SetSchemaResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return SetSchemaResultProto.parseFrom(setSchemaResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing SetSchemaResultProto.", e);
+ return SetSchemaResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static GetSchemaResultProto byteArrayToGetSchemaResultProto(
+ @Nullable byte[] getSchemaResultBytes) {
+ if (getSchemaResultBytes == null) {
+ Log.e(TAG, "Received null GetSchemaResultProto from native.");
+ return GetSchemaResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return GetSchemaResultProto.parseFrom(getSchemaResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetSchemaResultProto.", e);
+ return GetSchemaResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static GetSchemaTypeResultProto byteArrayToGetSchemaTypeResultProto(
+ @Nullable byte[] getSchemaTypeResultBytes) {
+ if (getSchemaTypeResultBytes == null) {
+ Log.e(TAG, "Received null GetSchemaTypeResultProto from native.");
+ return GetSchemaTypeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return GetSchemaTypeResultProto.parseFrom(getSchemaTypeResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetSchemaTypeResultProto.", e);
+ return GetSchemaTypeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static PutResultProto byteArrayToPutResultProto(@Nullable byte[] putResultBytes) {
+ if (putResultBytes == null) {
+ Log.e(TAG, "Received null PutResultProto from native.");
+ return PutResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return PutResultProto.parseFrom(putResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing PutResultProto.", e);
+ return PutResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static GetResultProto byteArrayToGetResultProto(@Nullable byte[] getResultBytes) {
+ if (getResultBytes == null) {
+ Log.e(TAG, "Received null GetResultProto from native.");
+ return GetResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return GetResultProto.parseFrom(getResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetResultProto.", e);
+ return GetResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static ReportUsageResultProto byteArrayToReportUsageResultProto(
+ @Nullable byte[] reportUsageResultBytes) {
+ if (reportUsageResultBytes == null) {
+ Log.e(TAG, "Received null ReportUsageResultProto from native.");
+ return ReportUsageResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return ReportUsageResultProto.parseFrom(reportUsageResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing ReportUsageResultProto.", e);
+ return ReportUsageResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static GetAllNamespacesResultProto byteArrayToGetAllNamespacesResultProto(
+ @Nullable byte[] getAllNamespacesResultBytes) {
+ if (getAllNamespacesResultBytes == null) {
+ Log.e(TAG, "Received null GetAllNamespacesResultProto from native.");
+ return GetAllNamespacesResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return GetAllNamespacesResultProto.parseFrom(
+ getAllNamespacesResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetAllNamespacesResultProto.", e);
+ return GetAllNamespacesResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static SearchResultProto byteArrayToSearchResultProto(@Nullable byte[] searchResultBytes) {
+ if (searchResultBytes == null) {
+ Log.e(TAG, "Received null SearchResultProto from native.");
+ return SearchResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ SearchResultProto.Builder searchResultProtoBuilder =
+ SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE);
+ setNativeToJavaJniLatency(searchResultProtoBuilder);
+ return searchResultProtoBuilder.build();
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing SearchResultProto.", e);
+ return SearchResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
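+ /**
+ * Computes the native-to-Java JNI latency as the difference between the current wall-clock time
+ * and the native-to-Java start timestamp recorded in QueryStats, and writes the result back into
+ * the builder. System.currentTimeMillis() is the wall clock and may jump if the device time
+ * changes, so the computed latency can occasionally be negative or inflated.
+ */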
+ private static void setNativeToJavaJniLatency(
+ SearchResultProto.Builder searchResultProtoBuilder) {
+ int nativeToJavaLatencyMs =
+ (int)
+ (System.currentTimeMillis()
+ - searchResultProtoBuilder.getQueryStats().getNativeToJavaStartTimestampMs());
+ searchResultProtoBuilder.setQueryStats(
+ searchResultProtoBuilder.getQueryStats().toBuilder()
+ .setNativeToJavaJniLatencyMs(nativeToJavaLatencyMs));
+ }
+
+ @NonNull
+ public static DeleteResultProto byteArrayToDeleteResultProto(@Nullable byte[] deleteResultBytes) {
+ if (deleteResultBytes == null) {
+ Log.e(TAG, "Received null DeleteResultProto from native.");
+ return DeleteResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return DeleteResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing DeleteResultProto.", e);
+ return DeleteResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static SuggestionResponse byteArrayToSuggestionResponse(
+ @Nullable byte[] suggestionResponseBytes) {
+ if (suggestionResponseBytes == null) {
+ Log.e(TAG, "Received null suggestionResponseBytes from native.");
+ return SuggestionResponse.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing suggestionResponseBytes.", e);
+ return SuggestionResponse.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static DeleteByNamespaceResultProto byteArrayToDeleteByNamespaceResultProto(
+ @Nullable byte[] deleteByNamespaceResultBytes) {
+ if (deleteByNamespaceResultBytes == null) {
+ Log.e(TAG, "Received null DeleteByNamespaceResultProto from native.");
+ return DeleteByNamespaceResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return DeleteByNamespaceResultProto.parseFrom(
+ deleteByNamespaceResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing DeleteByNamespaceResultProto.", e);
+ return DeleteByNamespaceResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static DeleteBySchemaTypeResultProto byteArrayToDeleteBySchemaTypeResultProto(
+ @Nullable byte[] deleteBySchemaTypeResultBytes) {
+ if (deleteBySchemaTypeResultBytes == null) {
+ Log.e(TAG, "Received null DeleteBySchemaTypeResultProto from native.");
+ return DeleteBySchemaTypeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return DeleteBySchemaTypeResultProto.parseFrom(
+ deleteBySchemaTypeResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing DeleteBySchemaTypeResultProto.", e);
+ return DeleteBySchemaTypeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static DeleteByQueryResultProto byteArrayToDeleteByQueryResultProto(
+ @Nullable byte[] deleteResultBytes) {
+ if (deleteResultBytes == null) {
+ Log.e(TAG, "Received null DeleteResultProto from native.");
+ return DeleteByQueryResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return DeleteByQueryResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing DeleteResultProto.", e);
+ return DeleteByQueryResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static PersistToDiskResultProto byteArrayToPersistToDiskResultProto(
+ @Nullable byte[] persistToDiskResultBytes) {
+ if (persistToDiskResultBytes == null) {
+ Log.e(TAG, "Received null PersistToDiskResultProto from native.");
+ return PersistToDiskResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return PersistToDiskResultProto.parseFrom(persistToDiskResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing PersistToDiskResultProto.", e);
+ return PersistToDiskResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static OptimizeResultProto byteArrayToOptimizeResultProto(
+ @Nullable byte[] optimizeResultBytes) {
+ if (optimizeResultBytes == null) {
+ Log.e(TAG, "Received null OptimizeResultProto from native.");
+ return OptimizeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return OptimizeResultProto.parseFrom(optimizeResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing OptimizeResultProto.", e);
+ return OptimizeResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static GetOptimizeInfoResultProto byteArrayToGetOptimizeInfoResultProto(
+ @Nullable byte[] getOptimizeInfoResultBytes) {
+ if (getOptimizeInfoResultBytes == null) {
+ Log.e(TAG, "Received null GetOptimizeInfoResultProto from native.");
+ return GetOptimizeInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return GetOptimizeInfoResultProto.parseFrom(
+ getOptimizeInfoResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e);
+ return GetOptimizeInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static StorageInfoResultProto byteArrayToStorageInfoResultProto(
+ @Nullable byte[] storageInfoResultProtoBytes) {
+ if (storageInfoResultProtoBytes == null) {
+ Log.e(TAG, "Received null StorageInfoResultProto from native.");
+ return StorageInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return StorageInfoResultProto.parseFrom(storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e);
+ return StorageInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static DebugInfoResultProto byteArrayToDebugInfoResultProto(
+ @Nullable byte[] debugInfoResultProtoBytes) {
+ if (debugInfoResultProtoBytes == null) {
+ Log.e(TAG, "Received null DebugInfoResultProto from native.");
+ return DebugInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return DebugInfoResultProto.parseFrom(debugInfoResultProtoBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing DebugInfoResultProto.", e);
+ return DebugInfoResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+
+ @NonNull
+ public static ResetResultProto byteArrayToResetResultProto(@Nullable byte[] resetResultBytes) {
+ if (resetResultBytes == null) {
+ Log.e(TAG, "Received null ResetResultProto from native.");
+ return ResetResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+
+ try {
+ return ResetResultProto.parseFrom(resetResultBytes, EXTENSION_REGISTRY_LITE);
+ } catch (InvalidProtocolBufferException e) {
+ Log.e(TAG, "Error parsing ResetResultProto.", e);
+ return ResetResultProto.newBuilder()
+ .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
+ .build();
+ }
+ }
+}
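
All of the byteArrayTo*Proto converters above share the same shape: null check, parse with EXTENSION_REGISTRY_LITE, and fall back to an INTERNAL status on failure. A minimal sketch of how that duplication could be factored into one generic helper; the ProtoParser interface, parseOrInternalError name, and the usage comment are hypothetical illustrations, not part of this change:

    // Hypothetical sketch: a functional interface so the checked
    // InvalidProtocolBufferException can flow out of the lambda.
    private interface ProtoParser<T> {
      T parse(byte[] bytes) throws InvalidProtocolBufferException;
    }

    // Shared null-check / parse / INTERNAL-fallback logic for all converters.
    private static <T> T parseOrInternalError(
        @Nullable byte[] bytes, String protoName, ProtoParser<T> parser, T internalError) {
      if (bytes == null) {
        Log.e(TAG, "Received null " + protoName + " from native.");
        return internalError;
      }
      try {
        return parser.parse(bytes);
      } catch (InvalidProtocolBufferException e) {
        Log.e(TAG, "Error parsing " + protoName + ".", e);
        return internalError;
      }
    }

    // Equivalent to byteArrayToResetResultProto above:
    //   return parseOrInternalError(
    //       resetResultBytes, "ResetResultProto",
    //       bytes -> ResetResultProto.parseFrom(bytes, EXTENSION_REGISTRY_LITE),
    //       ResetResultProto.newBuilder()
    //           .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL))
    //           .build());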
diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto
index 13861c9..a8040a1 100644
--- a/proto/icing/proto/scoring.proto
+++ b/proto/icing/proto/scoring.proto
@@ -25,7 +25,7 @@ option objc_class_prefix = "ICNG";
// Encapsulates the configurations on how Icing should score and rank the search
// results.
// TODO(b/170347684): Change all timestamps to seconds.
-// Next tag: 4
+// Next tag: 12
message ScoringSpecProto {
// OPTIONAL: Indicates how the search results will be ranked.
message RankingStrategy {
@@ -71,6 +71,9 @@ message ScoringSpecProto {
// Ranked by the aggregated score of the joined documents.
JOIN_AGGREGATE_SCORE = 10;
+
+ // Ranked by the advanced scoring expression provided.
+ ADVANCED_SCORING_EXPRESSION = 11;
}
}
optional RankingStrategy.Code rank_by = 1;
@@ -99,6 +102,10 @@ message ScoringSpecProto {
// all properties that are not specified are given a raw, pre-normalized
// weight of 1.0 when scoring.
repeated TypePropertyWeights type_property_weights = 3;
+
+ // OPTIONAL: Specifies the scoring expression for the
+ // ADVANCED_SCORING_EXPRESSION RankingStrategy.
+ optional string advanced_scoring_expression = 4;
}
// Next tag: 3
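
A usage sketch for the new strategy and field; the expression string below is an assumption for illustration, since the accepted grammar is defined by the advanced-scoring implementation rather than by this proto change:

    ScoringSpecProto scoringSpec =
        ScoringSpecProto.newBuilder()
            // Rank by the caller-provided expression instead of a built-in strategy.
            .setRankBy(ScoringSpecProto.RankingStrategy.Code.ADVANCED_SCORING_EXPRESSION)
            // Hypothetical expression; the real grammar lives in the scoring engine.
            .setAdvancedScoringExpression("this.documentScore() * 2")
            .build();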
diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto
index 181c63c..e7e0208 100644
--- a/proto/icing/proto/search.proto
+++ b/proto/icing/proto/search.proto
@@ -299,7 +299,7 @@ message SearchResultProto {
// determined by ScoringSpecProto.rank_by.
optional double score = 3;
- // The documents that were joined to a parent document.
+ // The child documents that were joined to a parent document.
repeated ResultProto joined_results = 4;
}
repeated ResultProto results = 2;
@@ -430,37 +430,53 @@ message SuggestionResponse {
//
// Next tag: 7
message JoinSpecProto {
- // A nested SearchSpec that will be used to retrieve joined documents. If you
- // are only looking to join on Action type documents, you could set a schema
- // filter in this SearchSpec. This includes the nested search query. See
- // SearchSpecProto.
- optional SearchSpecProto nested_search_spec = 1;
+ // Collection of several specs that will be used for searching and joining
+ // child documents.
+ //
+ // Next tag: 4
+ message NestedSpecProto {
+ // A nested SearchSpec that will be used to retrieve child documents. If you
+ // are only looking to join on documents of a specific type, you could set a
+ // schema filter in this SearchSpec. This includes the nested search query.
+ // See SearchSpecProto.
+ optional SearchSpecProto search_spec = 1;
+
+ // A nested ScoringSpec that will be used to score child documents.
+ // See ScoringSpecProto.
+ optional ScoringSpecProto scoring_spec = 2;
+
+ // A nested ResultSpec that will be used to format the child documents in the
+ // joined results, e.g. snippeting, projection.
+ // See ResultSpecProto.
+ optional ResultSpecProto result_spec = 3;
+ }
+ optional NestedSpecProto nested_spec = 1;
// The equivalent of a primary key in SQL. This is an expression that will be
// used to match child documents from the nested search to this document. One
- // such expression is qualifiedId(). When used, it means the
- // child_property_expression in the joined documents must be equal to the
- // qualified id.
+ // such expression is qualifiedId(). When used, it means the contents of the
+ // child_property_expression property in the child documents must be equal to
+ // the qualified id.
// TODO(b/256022027) allow for parent_property_expression to be any property
// of the parent document.
optional string parent_property_expression = 2;
// The equivalent of a foreign key in SQL. This defines an equality constraint
// between a property in a child document and a property in the parent
- // document. For example, if you want to join Action documents which an
+ // document. For example, if you want to join child documents which have an
// entityId property containing a fully qualified document id,
// child_property_expression can be set to "entityId".
// TODO(b/256022027) figure out how to allow this to refer to documents
// outside of same pkg+db+ns.
optional string child_property_expression = 3;
- // The max amount of joined documents to join to a parent document.
- optional int32 max_joined_result_count = 4;
+ // The max number of child documents to join to a parent document.
+ optional int32 max_joined_child_count = 4;
- // The strategy by which to score the aggregation of joined documents. For
+ // The strategy by which to score the aggregation of child documents. For
// example, you might want to know which entity document has the most actions
// taken on it. If JOIN_AGGREGATE_SCORE is used in the base SearchSpecProto,
- // the COUNT value will rank entity documents based on the number of joined
+ // the COUNT value will rank entity documents based on the number of child
// documents.
enum AggregationScore {
UNDEFINED = 0;
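
Putting the restructured JoinSpecProto together, a caller might populate it as below. The message and field names come from the proto above; the nested query string and the child count are placeholder assumptions for illustration:

    JoinSpecProto joinSpec =
        JoinSpecProto.newBuilder()
            .setNestedSpec(
                JoinSpecProto.NestedSpecProto.newBuilder()
                    // Hypothetical nested query selecting the child documents to join.
                    .setSearchSpec(SearchSpecProto.newBuilder().setQuery("tag:action"))
                    .setScoringSpec(ScoringSpecProto.getDefaultInstance())
                    .setResultSpec(ResultSpecProto.getDefaultInstance()))
            // Join children whose child_property_expression value equals the
            // parent document's qualified id.
            .setParentPropertyExpression("qualifiedId()")
            .setChildPropertyExpression("entityId")
            .setMaxJoinedChildCount(10)
            .build();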
diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt
index 55403b4..654903b 100644
--- a/synced_AOSP_CL_number.txt
+++ b/synced_AOSP_CL_number.txt
@@ -1 +1 @@
-set(synced_AOSP_CL_number=487674301)
+set(synced_AOSP_CL_number=494856295)