diff options
124 files changed, 15644 insertions, 2919 deletions
diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index 183c091..1d99e24 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -76,6 +76,7 @@ #include "icing/file/filesystem.h" #include "icing/file/memory-mapped-file.h" #include "icing/legacy/core/icing-string-util.h" +#include "icing/portable/platform.h" #include "icing/util/crc32.h" #include "icing/util/logging.h" #include "icing/util/math-util.h" @@ -147,10 +148,17 @@ class FileBackedVector { } }; - // Absolute max file size for FileBackedVector. Note that Android has a - // (2^31-1)-byte single file size limit, so kMaxFileSize is 2^31-1. + // Absolute max file size for FileBackedVector. + // - We memory map the whole file, so file size ~= memory size. + // - On 32-bit platform, the virtual memory address space is 4GB. To avoid + // exhausting the memory, set smaller file size limit for 32-bit platform. +#ifdef ICING_ARCH_BIT_64 static constexpr int32_t kMaxFileSize = - std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB; + std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB +#else + static constexpr int32_t kMaxFileSize = + (1 << 28) + Header::kHeaderSize; // 2^28 + 12 Bytes, ~256 MiB +#endif // Size of element type T. The value is same as sizeof(T), while we should // avoid using sizeof(T) in our codebase to prevent unexpected unsigned @@ -461,7 +469,8 @@ class FileBackedVector { // Absolute max # of elements allowed. Since we are using int32_t to store // num_elements, max value is 2^31-1. Still the actual max # of elements are - // determined by max_file_size, kElementTypeSize, and Header::kHeaderSize. + // determined by max_file_size, kMaxFileSize, kElementTypeSize, and + // Header::kHeaderSize. 
static constexpr int32_t kMaxNumElements = std::numeric_limits<int32_t>::max(); @@ -887,7 +896,7 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( kElementTypeSize) { return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( "%d elements total size exceed maximum bytes of elements allowed, " - "%d bytes", + "%" PRId64 " bytes", num_elements, mmapped_file_->max_file_size() - Header::kHeaderSize)); } diff --git a/icing/file/persistent-hash-map.cc b/icing/file/persistent-hash-map.cc index 0c9fd7f..0af5e2f 100644 --- a/icing/file/persistent-hash-map.cc +++ b/icing/file/persistent-hash-map.cc @@ -180,13 +180,66 @@ std::string GetKeyValueStorageFilePath(std::string_view base_dir) { ".k"); } +// Calculates how many buckets we need given num_entries and +// max_load_factor_percent. Round it up to 2's power. +// +// REQUIRES: 0 < num_entries <= Entry::kMaxNumEntries && +// max_load_factor_percent > 0 +int32_t CalculateNumBucketsRequired(int32_t num_entries, + int32_t max_load_factor_percent) { + // Calculate ceil(num_entries * 100 / max_load_factor_percent) + int32_t num_entries_100 = num_entries * 100; + int32_t num_buckets_required = + num_entries_100 / max_load_factor_percent + + (num_entries_100 % max_load_factor_percent == 0 ? 
0 : 1); + if ((num_buckets_required & (num_buckets_required - 1)) != 0) { + // not 2's power + return 1 << (32 - __builtin_clz(num_buckets_required)); + } + return num_buckets_required; +} + } // namespace +bool PersistentHashMap::Options::IsValid() const { + if (!(value_type_size > 0 && value_type_size <= kMaxValueTypeSize && + max_num_entries > 0 && max_num_entries <= Entry::kMaxNumEntries && + max_load_factor_percent > 0 && average_kv_byte_size > 0 && + init_num_buckets > 0 && init_num_buckets <= Bucket::kMaxNumBuckets)) { + return false; + } + + // We've ensured (static_assert) that storing kMaxNumBuckets buckets won't + // exceed FileBackedVector::kMaxFileSize, so only need to verify # of buckets + // required won't exceed kMaxNumBuckets. + if (CalculateNumBucketsRequired(max_num_entries, max_load_factor_percent) > + Bucket::kMaxNumBuckets) { + return false; + } + + // Verify # of key value pairs can fit into kv_storage. + if (average_kv_byte_size > kMaxKVTotalByteSize / max_num_entries) { + return false; + } + + // Verify init_num_buckets is 2's power. Requiring init_num_buckets to be 2^n + // guarantees that num_buckets will eventually grow to be exactly + // max_num_buckets since CalculateNumBucketsRequired rounds it up to 2^n. 
+ if ((init_num_buckets & (init_num_buckets - 1)) != 0) { + return false; + } + + return true; +} + /* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> PersistentHashMap::Create(const Filesystem& filesystem, - std::string_view base_dir, int32_t value_type_size, - int32_t max_load_factor_percent, - int32_t init_num_buckets) { + std::string_view base_dir, const Options& options) { + if (!options.IsValid()) { + return absl_ports::InvalidArgumentError( + "Invalid PersistentHashMap options"); + } + if (!filesystem.FileExists( GetMetadataFilePath(base_dir, kSubDirectory).c_str()) || !filesystem.FileExists( @@ -195,11 +248,9 @@ PersistentHashMap::Create(const Filesystem& filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory).c_str()) || !filesystem.FileExists(GetKeyValueStorageFilePath(base_dir).c_str())) { // TODO: erase all files if missing any. - return InitializeNewFiles(filesystem, base_dir, value_type_size, - max_load_factor_percent, init_num_buckets); + return InitializeNewFiles(filesystem, base_dir, options); } - return InitializeExistingFiles(filesystem, base_dir, value_type_size, - max_load_factor_percent); + return InitializeExistingFiles(filesystem, base_dir, options); } PersistentHashMap::~PersistentHashMap() { @@ -398,9 +449,7 @@ libtextclassifier3::StatusOr<Crc32> PersistentHashMap::ComputeChecksum() { /* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, std::string_view base_dir, - int32_t value_type_size, - int32_t max_load_factor_percent, - int32_t init_num_buckets) { + const Options& options) { // Create directory. 
const std::string dir_path = absl_ports::StrCat(base_dir, "/", kSubDirectory); if (!filesystem.CreateDirectoryRecursively(dir_path.c_str())) { @@ -408,32 +457,54 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, absl_ports::StrCat("Failed to create directory: ", dir_path)); } - // Initialize 3 storages + int32_t max_num_buckets_required = + std::max(options.init_num_buckets, + CalculateNumBucketsRequired(options.max_num_entries, + options.max_load_factor_percent)); + + // Initialize bucket_storage + int32_t pre_mapping_mmap_size = sizeof(Bucket) * max_num_buckets_required; + int32_t max_file_size = + pre_mapping_mmap_size + FileBackedVector<Bucket>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN( std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, FileBackedVector<Bucket>::Create( filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size, + pre_mapping_mmap_size)); + + // Initialize entry_storage + pre_mapping_mmap_size = sizeof(Entry) * options.max_num_entries; + max_file_size = + pre_mapping_mmap_size + FileBackedVector<Entry>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN( std::unique_ptr<FileBackedVector<Entry>> entry_storage, FileBackedVector<Entry>::Create( filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size, + pre_mapping_mmap_size)); + + // Initialize kv_storage + pre_mapping_mmap_size = + options.average_kv_byte_size * options.max_num_entries; + max_file_size = + pre_mapping_mmap_size + FileBackedVector<char>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, FileBackedVector<char>::Create( filesystem, GetKeyValueStorageFilePath(base_dir), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + 
MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + max_file_size, pre_mapping_mmap_size)); // Initialize buckets. - ICING_RETURN_IF_ERROR( - bucket_storage->Set(/*idx=*/0, /*len=*/init_num_buckets, Bucket())); + ICING_RETURN_IF_ERROR(bucket_storage->Set( + /*idx=*/0, /*len=*/options.init_num_buckets, Bucket())); ICING_RETURN_IF_ERROR(bucket_storage->PersistToDisk()); // Create and initialize new info Info new_info; new_info.version = kVersion; - new_info.value_type_size = value_type_size; - new_info.max_load_factor_percent = max_load_factor_percent; + new_info.value_type_size = options.value_type_size; + new_info.max_load_factor_percent = options.max_load_factor_percent; new_info.num_deleted_entries = 0; new_info.num_deleted_key_value_bytes = 0; @@ -458,7 +529,7 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); return std::unique_ptr<PersistentHashMap>(new PersistentHashMap( - filesystem, base_dir, std::move(metadata_mmapped_file), + filesystem, base_dir, options, std::move(metadata_mmapped_file), std::move(bucket_storage), std::move(entry_storage), std::move(kv_storage))); } @@ -466,8 +537,7 @@ PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, /* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, std::string_view base_dir, - int32_t value_type_size, - int32_t max_load_factor_percent) { + const Options& options) { // Mmap the content of the crcs and info. 
ICING_ASSIGN_OR_RETURN( MemoryMappedFile metadata_mmapped_file, @@ -477,21 +547,41 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, ICING_RETURN_IF_ERROR(metadata_mmapped_file.Remap( /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); - // Initialize 3 storages + int32_t max_num_buckets_required = CalculateNumBucketsRequired( + options.max_num_entries, options.max_load_factor_percent); + + // Initialize bucket_storage + int32_t pre_mapping_mmap_size = sizeof(Bucket) * max_num_buckets_required; + int32_t max_file_size = + pre_mapping_mmap_size + FileBackedVector<Bucket>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN( std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, FileBackedVector<Bucket>::Create( filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size, + pre_mapping_mmap_size)); + + // Initialize entry_storage + pre_mapping_mmap_size = sizeof(Entry) * options.max_num_entries; + max_file_size = + pre_mapping_mmap_size + FileBackedVector<Entry>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN( std::unique_ptr<FileBackedVector<Entry>> entry_storage, FileBackedVector<Entry>::Create( filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size, + pre_mapping_mmap_size)); + + // Initialize kv_storage + pre_mapping_mmap_size = + options.average_kv_byte_size * options.max_num_entries; + max_file_size = + pre_mapping_mmap_size + FileBackedVector<char>::Header::kHeaderSize; ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, FileBackedVector<char>::Create( filesystem, GetKeyValueStorageFilePath(base_dir), - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + max_file_size, pre_mapping_mmap_size)); Crcs* crcs_ptr = 
reinterpret_cast<Crcs*>( metadata_mmapped_file.mutable_region() + Crcs::kFileOffset); @@ -499,22 +589,38 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, metadata_mmapped_file.mutable_region() + Info::kFileOffset); // Value type size should be consistent. - if (value_type_size != info_ptr->value_type_size) { + if (options.value_type_size != info_ptr->value_type_size) { return absl_ports::FailedPreconditionError("Incorrect value type size"); } + // Current # of entries should not exceed options.max_num_entries + // We compute max_file_size of 3 storages by options.max_num_entries. Since we + // won't recycle space of deleted entries (and key-value bytes), they're still + // occupying space in storages. Even if # of "active" entries doesn't exceed + // options.max_num_entries, the new kvp to be inserted still potentially + // exceeds max_file_size. + // Therefore, we should use entry_storage->num_elements() instead of # of + // "active" entries + // (i.e. entry_storage->num_elements() - info_ptr->num_deleted_entries) to + // check. This feature avoids storages being grown extremely large when there + // are many Delete() and Put() operations. + if (entry_storage->num_elements() > options.max_num_entries) { + return absl_ports::FailedPreconditionError( + "Current # of entries exceeds max num entries"); + } + // Validate checksums of info and 3 storages. ICING_RETURN_IF_ERROR( ValidateChecksums(crcs_ptr, info_ptr, bucket_storage.get(), entry_storage.get(), kv_storage.get())); // Allow max_load_factor_percent_ change. 
- if (max_load_factor_percent != info_ptr->max_load_factor_percent) { + if (options.max_load_factor_percent != info_ptr->max_load_factor_percent) { ICING_VLOG(2) << "Changing max_load_factor_percent from " << info_ptr->max_load_factor_percent << " to " - << max_load_factor_percent; + << options.max_load_factor_percent; - info_ptr->max_load_factor_percent = max_load_factor_percent; + info_ptr->max_load_factor_percent = options.max_load_factor_percent; crcs_ptr->component_crcs.info_crc = info_ptr->ComputeChecksum().Get(); crcs_ptr->all_crc = crcs_ptr->component_crcs.ComputeChecksum().Get(); ICING_RETURN_IF_ERROR(metadata_mmapped_file.PersistToDisk()); @@ -522,7 +628,7 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, auto persistent_hash_map = std::unique_ptr<PersistentHashMap>(new PersistentHashMap( - filesystem, base_dir, std::move(metadata_mmapped_file), + filesystem, base_dir, options, std::move(metadata_mmapped_file), std::move(bucket_storage), std::move(entry_storage), std::move(kv_storage))); ICING_RETURN_IF_ERROR( @@ -576,8 +682,17 @@ libtextclassifier3::Status PersistentHashMap::CopyEntryValue( libtextclassifier3::Status PersistentHashMap::Insert(int32_t bucket_idx, std::string_view key, const void* value) { - // If size() + 1 exceeds Entry::kMaxNumEntries, then return error. - if (size() > Entry::kMaxNumEntries - 1) { + // If entry_storage_->num_elements() + 1 exceeds options_.max_num_entries, + // then return error. + // We compute max_file_size of 3 storages by options_.max_num_entries. Since + // we won't recycle space of deleted entries (and key-value bytes), they're + // still occupying space in storages. Even if # of "active" entries (i.e. + // size()) doesn't exceed options_.max_num_entries, the new kvp to be inserted + // still potentially exceeds max_file_size. + // Therefore, we should use entry_storage_->num_elements() instead of size() + // to check. 
This feature avoids storages being grown extremely large when + // there are many Delete() and Put() operations. + if (entry_storage_->num_elements() > options_.max_num_entries - 1) { return absl_ports::ResourceExhaustedError("Cannot insert new entry"); } diff --git a/icing/file/persistent-hash-map.h b/icing/file/persistent-hash-map.h index ef3995c..57fa070 100644 --- a/icing/file/persistent-hash-map.h +++ b/icing/file/persistent-hash-map.h @@ -134,12 +134,10 @@ class PersistentHashMap { // Bucket class Bucket { public: - // Absolute max # of buckets allowed. Since max file size on Android is - // 2^31-1, we can at most have ~2^29 buckets. To make it power of 2, round - // it down to 2^28. Also since we're using FileBackedVector to store - // buckets, add some static_asserts to ensure numbers here are compatible - // with FileBackedVector. - static constexpr int32_t kMaxNumBuckets = 1 << 28; + // Absolute max # of buckets allowed. Since we're using FileBackedVector to + // store buckets, add some static_asserts to ensure numbers here are + // compatible with FileBackedVector. + static constexpr int32_t kMaxNumBuckets = 1 << 24; explicit Bucket(int32_t head_entry_index = Entry::kInvalidIndex) : head_entry_index_(head_entry_index) {} @@ -170,16 +168,14 @@ class PersistentHashMap { // Entry class Entry { public: - // Absolute max # of entries allowed. Since max file size on Android is - // 2^31-1, we can at most have ~2^28 entries. To make it power of 2, round - // it down to 2^27. Also since we're using FileBackedVector to store - // entries, add some static_asserts to ensure numbers here are compatible - // with FileBackedVector. + // Absolute max # of entries allowed. Since we're using FileBackedVector to + // store entries, add some static_asserts to ensure numbers here are + // compatible with FileBackedVector. 
// // Still the actual max # of entries are determined by key-value storage, // since length of the key varies and affects # of actual key-value pairs // that can be stored. - static constexpr int32_t kMaxNumEntries = 1 << 27; + static constexpr int32_t kMaxNumEntries = 1 << 23; static constexpr int32_t kMaxIndex = kMaxNumEntries - 1; static constexpr int32_t kInvalidIndex = -1; @@ -217,19 +213,64 @@ class PersistentHashMap { "Max # of entries cannot fit into FileBackedVector"); // Key-value serialized type - static constexpr int32_t kMaxKVTotalByteSize = - (FileBackedVector<char>::kMaxFileSize - - FileBackedVector<char>::Header::kHeaderSize) / - FileBackedVector<char>::kElementTypeSize; + static constexpr int32_t kMaxKVTotalByteSize = 1 << 28; static constexpr int32_t kMaxKVIndex = kMaxKVTotalByteSize - 1; static constexpr int32_t kInvalidKVIndex = -1; static_assert(sizeof(char) == FileBackedVector<char>::kElementTypeSize, "Char type size is inconsistent with FileBackedVector element " "type size"); + static_assert(kMaxKVTotalByteSize <= + FileBackedVector<char>::kMaxFileSize - + FileBackedVector<char>::Header::kHeaderSize, + "Max total byte size of key value pairs cannot fit into " + "FileBackedVector"); + + static constexpr int32_t kMaxValueTypeSize = 1 << 10; + + struct Options { + static constexpr int32_t kDefaultMaxLoadFactorPercent = 100; + static constexpr int32_t kDefaultAverageKVByteSize = 32; + static constexpr int32_t kDefaultInitNumBuckets = 1 << 13; + + explicit Options( + int32_t value_type_size_in, + int32_t max_num_entries_in = Entry::kMaxNumEntries, + int32_t max_load_factor_percent_in = kDefaultMaxLoadFactorPercent, + int32_t average_kv_byte_size_in = kDefaultAverageKVByteSize, + int32_t init_num_buckets_in = kDefaultInitNumBuckets) + : value_type_size(value_type_size_in), + max_num_entries(max_num_entries_in), + max_load_factor_percent(max_load_factor_percent_in), + average_kv_byte_size(average_kv_byte_size_in), + 
init_num_buckets(init_num_buckets_in) {} + + bool IsValid() const; + + // (fixed) size of the serialized value type for hash map. + int32_t value_type_size; + + // Max # of entries, default Entry::kMaxNumEntries. + int32_t max_num_entries; + + // Percentage of the max loading for the hash map. If load_factor_percent + // exceeds max_load_factor_percent, then rehash will be invoked (and # of + // buckets will be doubled). + // load_factor_percent = 100 * num_keys / num_buckets + // + // Note that load_factor_percent exceeding 100 is considered valid. + int32_t max_load_factor_percent; + + // Average byte size of a key value pair. It is used to estimate kv_storage_ + // pre_mapping_mmap_size. + int32_t average_kv_byte_size; + + // Initial # of buckets for the persistent hash map. It should be 2's power. + // It is used when creating new persistent hash map and ignored when + // creating the instance from existing files. + int32_t init_num_buckets; + }; static constexpr int32_t kVersion = 1; - static constexpr int32_t kDefaultMaxLoadFactorPercent = 100; - static constexpr int32_t kDefaultInitNumBuckets = 8192; static constexpr std::string_view kFilePrefix = "persistent_hash_map"; // Only metadata, bucket, entry files are stored under this sub-directory, for @@ -243,28 +284,17 @@ class PersistentHashMap { // sub-directory and files to be stored. If base_dir doesn't exist, // then PersistentHashMap will automatically create it. If files // exist, then it will initialize the hash map from existing files. - // value_type_size: (fixed) size of the serialized value type for hash map. - // max_load_factor_percent: percentage of the max loading for the hash map. - // load_factor_percent = 100 * num_keys / num_buckets - // If load_factor_percent exceeds - // max_load_factor_percent, then rehash will be - // invoked (and # of buckets will be doubled). - // Note that load_factor_percent exceeding 100 is - // considered valid. 
- // init_num_buckets: initial # of buckets for the persistent hash map. It is - // used when creating new persistent hash map and ignored - // when creating the instance from existing files. + // options: Options instance. // // Returns: + // INVALID_ARGUMENT_ERROR if any value in options is invalid. // FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored // checksum. // INTERNAL_ERROR on I/O errors. // Any FileBackedVector errors. static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> Create(const Filesystem& filesystem, std::string_view base_dir, - int32_t value_type_size, - int32_t max_load_factor_percent = kDefaultMaxLoadFactorPercent, - int32_t init_num_buckets = kDefaultInitNumBuckets); + const Options& options); ~PersistentHashMap(); @@ -275,7 +305,7 @@ class PersistentHashMap { // // Returns: // OK on success - // RESOURCE_EXHAUSTED_ERROR if # of entries reach kMaxNumEntries + // RESOURCE_EXHAUSTED_ERROR if # of entries reach options_.max_num_entries // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. 
contains '\0') // INTERNAL_ERROR on I/O error or any data inconsistency // Any FileBackedVector errors @@ -373,12 +403,13 @@ class PersistentHashMap { explicit PersistentHashMap( const Filesystem& filesystem, std::string_view base_dir, - MemoryMappedFile&& metadata_mmapped_file, + const Options& options, MemoryMappedFile&& metadata_mmapped_file, std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, std::unique_ptr<FileBackedVector<Entry>> entry_storage, std::unique_ptr<FileBackedVector<char>> kv_storage) : filesystem_(&filesystem), base_dir_(base_dir), + options_(options), metadata_mmapped_file_(std::make_unique<MemoryMappedFile>( std::move(metadata_mmapped_file))), bucket_storage_(std::move(bucket_storage)), @@ -387,13 +418,11 @@ class PersistentHashMap { static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> InitializeNewFiles(const Filesystem& filesystem, std::string_view base_dir, - int32_t value_type_size, int32_t max_load_factor_percent, - int32_t init_num_buckets); + const Options& options); static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> InitializeExistingFiles(const Filesystem& filesystem, - std::string_view base_dir, int32_t value_type_size, - int32_t max_load_factor_percent); + std::string_view base_dir, const Options& options); // Find the index of the target entry (that contains the key) from a bucket // (specified by bucket index). 
Also return the previous entry index, since @@ -457,6 +486,8 @@ class PersistentHashMap { const Filesystem* filesystem_; std::string base_dir_; + Options options_; + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; // Storages diff --git a/icing/file/persistent-hash-map_test.cc b/icing/file/persistent-hash-map_test.cc index 8024388..8fde4a8 100644 --- a/icing/file/persistent-hash-map_test.cc +++ b/icing/file/persistent-hash-map_test.cc @@ -24,20 +24,11 @@ #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "icing/file/file-backed-vector.h" #include "icing/file/filesystem.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" #include "icing/util/crc32.h" -namespace icing { -namespace lib { - -namespace { - -static constexpr int32_t kCorruptedValueOffset = 3; -static constexpr int32_t kTestInitNumBuckets = 1; - using ::testing::Contains; using ::testing::Eq; using ::testing::Gt; @@ -51,10 +42,19 @@ using ::testing::Pointee; using ::testing::SizeIs; using ::testing::UnorderedElementsAre; +namespace icing { +namespace lib { + +namespace { + using Bucket = PersistentHashMap::Bucket; using Crcs = PersistentHashMap::Crcs; using Entry = PersistentHashMap::Entry; using Info = PersistentHashMap::Info; +using Options = PersistentHashMap::Options; + +static constexpr int32_t kCorruptedValueOffset = 3; +static constexpr int32_t kTestInitNumBuckets = 1; class PersistentHashMapTest : public ::testing::Test { protected: @@ -95,10 +95,110 @@ class PersistentHashMapTest : public ::testing::Test { std::string base_dir_; }; +TEST_F(PersistentHashMapTest, OptionsInvalidValueTypeSize) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.value_type_size = -1; + EXPECT_FALSE(options.IsValid()); + + options.value_type_size = 0; + EXPECT_FALSE(options.IsValid()); + + options.value_type_size = PersistentHashMap::kMaxValueTypeSize + 1; 
+ EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, OptionsInvalidMaxNumEntries) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.max_num_entries = -1; + EXPECT_FALSE(options.IsValid()); + + options.max_num_entries = 0; + EXPECT_FALSE(options.IsValid()); + + options.max_num_entries = Entry::kMaxNumEntries + 1; + EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, OptionsInvalidMaxLoadFactorPercent) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.max_load_factor_percent = -1; + EXPECT_FALSE(options.IsValid()); + + options.max_load_factor_percent = 0; + EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, OptionsInvalidAverageKVByteSize) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.average_kv_byte_size = -1; + EXPECT_FALSE(options.IsValid()); + + options.average_kv_byte_size = 0; + EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, OptionsInvalidInitNumBuckets) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.init_num_buckets = -1; + EXPECT_FALSE(options.IsValid()); + + options.init_num_buckets = 0; + EXPECT_FALSE(options.IsValid()); + + options.init_num_buckets = Bucket::kMaxNumBuckets + 1; + EXPECT_FALSE(options.IsValid()); + + // not 2's power + options.init_num_buckets = 3; + EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, OptionsNumBucketsRequiredExceedsMaxNumBuckets) { + Options options(/*value_type_size_in=*/sizeof(int)); + ASSERT_TRUE(options.IsValid()); + + options.max_num_entries = Entry::kMaxNumEntries; + options.max_load_factor_percent = 30; + EXPECT_FALSE(options.IsValid()); +} + +TEST_F(PersistentHashMapTest, + OptionsEstimatedNumKeyValuePairExceedsStorageMaxSize) { + Options options(/*value_type_size_in=*/sizeof(int)); + 
ASSERT_TRUE(options.IsValid()); + + options.max_num_entries = 1 << 20; + options.average_kv_byte_size = 1 << 20; + ASSERT_THAT(static_cast<int64_t>(options.max_num_entries) * + options.average_kv_byte_size, + Gt(PersistentHashMap::kMaxKVTotalByteSize)); + EXPECT_FALSE(options.IsValid()); +} + TEST_F(PersistentHashMapTest, InvalidBaseDir) { - EXPECT_THAT(PersistentHashMap::Create(filesystem_, "/dev/null", - /*value_type_size=*/sizeof(int)), - StatusIs(libtextclassifier3::StatusCode::INTERNAL)); + EXPECT_THAT( + PersistentHashMap::Create(filesystem_, "/dev/null", + Options(/*value_type_size_in=*/sizeof(int))), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); +} + +TEST_F(PersistentHashMapTest, CreateWithInvalidOptionsShouldFail) { + Options invalid_options(/*value_type_size_in=*/-1); + ASSERT_FALSE(invalid_options.IsValid()); + + EXPECT_THAT( + PersistentHashMap::Create(filesystem_, base_dir_, invalid_options), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } TEST_F(PersistentHashMapTest, InitializeNewFiles) { @@ -107,7 +207,7 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + Options(/*value_type_size_in=*/sizeof(int)))); EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -128,7 +228,7 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) { EXPECT_THAT(info.version, Eq(PersistentHashMap::kVersion)); EXPECT_THAT(info.value_type_size, Eq(sizeof(int))); EXPECT_THAT(info.max_load_factor_percent, - Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent)); + Eq(Options::kDefaultMaxLoadFactorPercent)); EXPECT_THAT(info.num_deleted_entries, Eq(0)); EXPECT_THAT(info.num_deleted_key_value_bytes, Eq(0)); @@ -153,52 +253,81 @@ TEST_F(PersistentHashMapTest, InitializeNewFiles) { .Get())); } -TEST_F(PersistentHashMapTest, 
InitializeNewFilesWithCustomInitBucketSize) { +TEST_F(PersistentHashMapTest, InitializeNewFilesWithCustomInitNumBuckets) { + int custom_init_num_buckets = 128; + // Create new persistent hash map - int custom_init_bucket_size = 123; ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - custom_init_bucket_size)); - EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(custom_init_bucket_size)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/custom_init_num_buckets))); + EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(custom_init_num_buckets)); +} + +TEST_F(PersistentHashMapTest, + InitializeNewFilesWithInitNumBucketsSmallerThanNumBucketsRequired) { + int init_num_buckets = 65536; + + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/1, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/init_num_buckets))); + EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_num_buckets)); } -TEST_F(PersistentHashMapTest, InitBucketSizeShouldNotAffectExistingFiles) { - int init_bucket_size1 = 4; +TEST_F(PersistentHashMapTest, InitNumBucketsShouldNotAffectExistingFiles) { + Options options(/*value_type_size_in=*/sizeof(int)); + + int original_init_num_buckets = 4; { + options.init_num_buckets = original_init_num_buckets; + 
ASSERT_TRUE(options.IsValid()); + // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create( - filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - init_bucket_size1)); - EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_bucket_size1)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); + EXPECT_THAT(persistent_hash_map->num_buckets(), + Eq(original_init_num_buckets)); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); } - int init_bucket_size2 = 8; + // Set new init_num_buckets. + options.init_num_buckets = 8; + ASSERT_TRUE(options.IsValid()); + ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - init_bucket_size2)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); // # of buckets should still be the original value. - EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(init_bucket_size1)); + EXPECT_THAT(persistent_hash_map->num_buckets(), + Eq(original_init_num_buckets)); } TEST_F(PersistentHashMapTest, - TestInitializationFailsWithoutPersistToDiskOrDestruction) { + InitializationShouldFailWithoutPersistToDiskOrDestruction) { + Options options(/*value_type_size_in=*/sizeof(int)); + // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); // Put some key value pairs. 
ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); @@ -214,17 +343,17 @@ TEST_F(PersistentHashMapTest, // Without calling PersistToDisk, checksums will not be recomputed or synced // to disk, so initializing another instance on the same files should fail. - EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int)), + EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } -TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) { +TEST_F(PersistentHashMapTest, InitializationShouldSucceedWithPersistToDisk) { + Options options(/*value_type_size_in=*/sizeof(int)); + // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map1, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); // Put some key value pairs. 
ICING_ASSERT_OK(persistent_hash_map1->Put("a", Serialize(1).data())); @@ -245,20 +374,20 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map2, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); EXPECT_THAT(persistent_hash_map2, Pointee(SizeIs(2))); EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "a"), IsOkAndHolds(1)); EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "b"), IsOkAndHolds(2)); } -TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) { +TEST_F(PersistentHashMapTest, InitializationShouldSucceedAfterDestruction) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data())); @@ -278,8 +407,7 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) { // we should be able to get the same contents. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); @@ -293,7 +421,7 @@ TEST_F(PersistentHashMapTest, ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + Options(/*value_type_size_in=*/sizeof(int)))); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -305,7 +433,8 @@ TEST_F(PersistentHashMapTest, ASSERT_THAT(sizeof(char), Not(Eq(sizeof(int)))); libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(char)); + filesystem_, base_dir_, + Options(/*value_type_size_in=*/sizeof(char))); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(persistent_hash_map_or.status().error_message(), @@ -313,13 +442,55 @@ TEST_F(PersistentHashMapTest, } } +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithMaxNumEntriesSmallerThanSizeShouldFail) { + Options options(/*value_type_size_in=*/sizeof(int)); + + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, options)); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + + { + // Attempt to create the persistent 
hash map with max num entries smaller + // than the current size. This should fail. + options.max_num_entries = 1; + ASSERT_TRUE(options.IsValid()); + + EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + } + + // Delete 1 kvp. + ICING_ASSERT_OK(persistent_hash_map->Delete("a")); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + + { + // Attempt to create the persistent hash map with max num entries: + // - Not smaller than current # of active kvps. + // - Smaller than # of all inserted kvps (regardless of activeness). + // This should fail. + options.max_num_entries = 1; + ASSERT_TRUE(options.IsValid()); + + EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, options), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + } +} + TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -345,8 +516,8 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) { // Attempt to create the persistent hash map with metadata containing // corrupted all_crc. This should fail. 
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> - persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + persistent_hash_map_or = + PersistentHashMap::Create(filesystem_, base_dir_, options); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(persistent_hash_map_or.status().error_message(), @@ -356,12 +527,13 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) { TEST_F(PersistentHashMapTest, InitializeExistingFilesWithCorruptedInfoShouldFail) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -386,8 +558,8 @@ TEST_F(PersistentHashMapTest, // Attempt to create the persistent hash map with info that doesn't match // its checksum and confirm that it fails. 
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> - persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + persistent_hash_map_or = + PersistentHashMap::Create(filesystem_, base_dir_, options); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(persistent_hash_map_or.status().error_message(), @@ -397,12 +569,13 @@ TEST_F(PersistentHashMapTest, TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongBucketStorageCrc) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -430,8 +603,8 @@ TEST_F(PersistentHashMapTest, // Attempt to create the persistent hash map with metadata containing // corrupted bucket_storage_crc. This should fail. 
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> - persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + persistent_hash_map_or = + PersistentHashMap::Create(filesystem_, base_dir_, options); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( @@ -441,12 +614,13 @@ TEST_F(PersistentHashMapTest, } TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -474,8 +648,8 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) { // Attempt to create the persistent hash map with metadata containing // corrupted entry_storage_crc. This should fail. 
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> - persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + persistent_hash_map_or = + PersistentHashMap::Create(filesystem_, base_dir_, options); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(persistent_hash_map_or.status().error_message(), @@ -485,12 +659,13 @@ TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) { TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongKeyValueStorageCrc) { + Options options(/*value_type_size_in=*/sizeof(int)); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); @@ -518,8 +693,8 @@ TEST_F(PersistentHashMapTest, // Attempt to create the persistent hash map with metadata containing // corrupted kv_storage_crc. This should fail. 
libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> - persistent_hash_map_or = PersistentHashMap::Create( - filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + persistent_hash_map_or = + PersistentHashMap::Create(filesystem_, base_dir_, options); EXPECT_THAT(persistent_hash_map_or, StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( @@ -530,15 +705,18 @@ TEST_F(PersistentHashMapTest, TEST_F(PersistentHashMapTest, InitializeExistingFilesAllowDifferentMaxLoadFactorPercent) { + Options options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets); + { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create( - filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); @@ -549,18 +727,19 @@ TEST_F(PersistentHashMapTest, ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); } - int32_t new_max_load_factor_percent = 200; { - ASSERT_THAT(new_max_load_factor_percent, - Not(Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent))); + // Set new max_load_factor_percent. + options.max_load_factor_percent = 200; + ASSERT_TRUE(options.IsValid()); + ASSERT_THAT(options.max_load_factor_percent, + Not(Eq(Options::kDefaultMaxLoadFactorPercent))); + // Attempt to create the persistent hash map with different max load factor // percent. This should succeed and metadata should be modified correctly. // Also verify all entries should remain unchanged. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - new_max_load_factor_percent)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); @@ -578,17 +757,15 @@ TEST_F(PersistentHashMapTest, Info info; ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), Info::kFileOffset)); - EXPECT_THAT(info.max_load_factor_percent, Eq(new_max_load_factor_percent)); + EXPECT_THAT(info.max_load_factor_percent, + Eq(options.max_load_factor_percent)); // Also should update crcs correctly. We test it by creating instance again // and make sure it won't get corrupted crcs/info errors. { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - new_max_load_factor_percent, - kTestInitNumBuckets)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); } @@ -596,17 +773,20 @@ TEST_F(PersistentHashMapTest, TEST_F(PersistentHashMapTest, InitializeExistingFilesWithDifferentMaxLoadFactorPercentShouldRehash) { + Options options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets); + double prev_loading_percent; int prev_num_buckets; { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create( - filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + 
PersistentHashMap::Create(filesystem_, base_dir_, options)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data())); @@ -620,47 +800,47 @@ TEST_F(PersistentHashMapTest, persistent_hash_map->num_buckets(); prev_num_buckets = persistent_hash_map->num_buckets(); ASSERT_THAT(prev_loading_percent, - Not(Gt(PersistentHashMap::kDefaultMaxLoadFactorPercent))); + Not(Gt(Options::kDefaultMaxLoadFactorPercent))); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); } - int32_t greater_max_load_factor_percent = 150; { - ASSERT_THAT(greater_max_load_factor_percent, Gt(prev_loading_percent)); + // Set greater max_load_factor_percent. + options.max_load_factor_percent = 150; + ASSERT_TRUE(options.IsValid()); + ASSERT_THAT(options.max_load_factor_percent, Gt(prev_loading_percent)); + // Attempt to create the persistent hash map with max load factor greater // than previous loading. There should be no rehashing and # of buckets // should remain the same. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - greater_max_load_factor_percent, - kTestInitNumBuckets)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); EXPECT_THAT(persistent_hash_map->num_buckets(), Eq(prev_num_buckets)); ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); } - int32_t smaller_max_load_factor_percent = 25; { - ASSERT_THAT(smaller_max_load_factor_percent, Lt(prev_loading_percent)); + // Set smaller max_load_factor_percent. + options.max_load_factor_percent = 50; + ASSERT_TRUE(options.IsValid()); + ASSERT_THAT(options.max_load_factor_percent, Lt(prev_loading_percent)); + // Attempt to create the persistent hash map with max load factor smaller // than previous loading. 
There should be rehashing since the loading // exceeds the limit. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - smaller_max_load_factor_percent, - kTestInitNumBuckets)); + PersistentHashMap::Create(filesystem_, base_dir_, options)); // After changing max_load_factor_percent, there should be rehashing and the // new loading should not be greater than the new max load factor. EXPECT_THAT(persistent_hash_map->size() * 100.0 / persistent_hash_map->num_buckets(), - Not(Gt(smaller_max_load_factor_percent))); + Not(Gt(options.max_load_factor_percent))); EXPECT_THAT(persistent_hash_map->num_buckets(), Not(Eq(prev_num_buckets))); EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); @@ -675,10 +855,15 @@ TEST_F(PersistentHashMapTest, PutAndGet) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), @@ -706,10 +891,15 @@ TEST_F(PersistentHashMapTest, PutShouldOverwriteValueIfKeyExists) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - 
kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com", Serialize(100).data())); @@ -734,10 +924,15 @@ TEST_F(PersistentHashMapTest, ShouldRehash) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); int original_num_buckets = persistent_hash_map->num_buckets(); // Insert 100 key value pairs. 
There should be rehashing so the loading of @@ -749,7 +944,7 @@ TEST_F(PersistentHashMapTest, ShouldRehash) { EXPECT_THAT(persistent_hash_map->size() * 100.0 / persistent_hash_map->num_buckets(), - Not(Gt(PersistentHashMap::kDefaultMaxLoadFactorPercent))); + Not(Gt(Options::kDefaultMaxLoadFactorPercent))); } EXPECT_THAT(persistent_hash_map->num_buckets(), Not(Eq(original_num_buckets))); @@ -765,10 +960,15 @@ TEST_F(PersistentHashMapTest, GetOrPutShouldPutIfKeyDoesNotExist) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); @@ -786,10 +986,15 @@ TEST_F(PersistentHashMapTest, GetOrPutShouldGetIfKeyExists) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ASSERT_THAT( 
persistent_hash_map->Put("default-google.com", Serialize(1).data()), @@ -810,10 +1015,15 @@ TEST_F(PersistentHashMapTest, Delete) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); // Delete a non-existing key should get NOT_FOUND error EXPECT_THAT(persistent_hash_map->Delete("default-google.com"), @@ -856,10 +1066,15 @@ TEST_F(PersistentHashMapTest, DeleteMultiple) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); std::unordered_map<std::string, int> existing_keys; std::unordered_set<std::string> deleted_keys; @@ -909,10 +1124,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketHeadElement) { // Preventing rehashing makes it much easier to test collisions. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - /*max_load_factor_percent=*/1000, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/1000, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -943,10 +1162,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketIntermediateElement) { // Preventing rehashing makes it much easier to test collisions. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - /*max_load_factor_percent=*/1000, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/1000, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -976,10 +1199,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketTailElement) { // Preventing rehashing makes it much easier to test collisions. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - /*max_load_factor_percent=*/1000, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/1000, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -1010,10 +1237,14 @@ TEST_F(PersistentHashMapTest, DeleteBucketOnlySingleElement) { // Preventing rehashing makes it much easier to test collisions. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - /*max_load_factor_percent=*/1000, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/1000, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com", Serialize(100).data())); @@ -1026,12 +1257,48 @@ TEST_F(PersistentHashMapTest, DeleteBucketOnlySingleElement) { StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } +TEST_F(PersistentHashMapTest, OperationsWhenReachingMaxNumEntries) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/1, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + 
/*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/1))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + + // Put new key should fail. + EXPECT_THAT( + persistent_hash_map->Put("default-youtube.com", Serialize(50).data()), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + // Modify existing key should succeed. + EXPECT_THAT( + persistent_hash_map->Put("default-google.com", Serialize(200).data()), + IsOk()); + + // Put after delete should still fail. See the comment in + // PersistentHashMap::Insert for more details. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com")); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(0))); + EXPECT_THAT( + persistent_hash_map->Put("default-youtube.com", Serialize(50).data()), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int))); + Options(/*value_type_size_in=*/sizeof(int)))); const char invalid_key[] = "a\0bc"; std::string_view invalid_key_view(invalid_key, 4); @@ -1051,10 +1318,15 @@ TEST_F(PersistentHashMapTest, EmptyHashMapIterator) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + 
/*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); EXPECT_FALSE(persistent_hash_map->GetIterator().Advance()); } @@ -1063,10 +1335,15 @@ TEST_F(PersistentHashMapTest, Iterator) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); std::unordered_map<std::string, int> kvps; // Insert 100 key value pairs @@ -1085,10 +1362,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingFirstKeyValuePair) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -1109,10 +1391,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingIntermediateKeyValuePair) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - 
PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -1133,10 +1420,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingLastKeyValuePair) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); @@ -1157,10 +1449,15 @@ TEST_F(PersistentHashMapTest, IteratorAfterDeletingAllKeyValuePairs) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<PersistentHashMap> persistent_hash_map, - PersistentHashMap::Create(filesystem_, base_dir_, - /*value_type_size=*/sizeof(int), - PersistentHashMap::kDefaultMaxLoadFactorPercent, - kTestInitNumBuckets)); + PersistentHashMap::Create( + filesystem_, base_dir_, + Options( + /*value_type_size_in=*/sizeof(int), + /*max_num_entries_in=*/Entry::kMaxNumEntries, + /*max_load_factor_percent_in=*/ + 
Options::kDefaultMaxLoadFactorPercent, + /*average_kv_byte_size_in=*/Options::kDefaultAverageKVByteSize, + /*init_num_buckets_in=*/kTestInitNumBuckets))); ICING_ASSERT_OK( persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); diff --git a/icing/file/posting_list/posting-list-accessor.cc b/icing/file/posting_list/posting-list-accessor.cc new file mode 100644 index 0000000..00f4417 --- /dev/null +++ b/icing/file/posting_list/posting-list-accessor.cc @@ -0,0 +1,110 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/posting_list/posting-list-accessor.h" + +#include <cstdint> +#include <memory> + +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/index-block.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +void PostingListAccessor::FlushPreexistingPostingList() { + if (preexisting_posting_list_->block.max_num_posting_lists() == 1) { + // If this is a max-sized posting list, then just keep track of the id for + // chaining. It'll be flushed to disk when preexisting_posting_list_ is + // destructed. 
+ prev_block_identifier_ = preexisting_posting_list_->id; + } else { + // If this is NOT a max-sized posting list, then our data have outgrown this + // particular posting list. Move the data into the in-memory posting list + // and free this posting list. + // + // Move will always succeed since posting_list_buffer_ is max_pl_bytes. + GetSerializer()->MoveFrom(/*dst=*/&posting_list_buffer_, + /*src=*/&preexisting_posting_list_->posting_list); + + // Now that all the contents of this posting list have been copied, there's + // no more use for it. Make it available to be used for another posting + // list. + storage_->FreePostingList(std::move(*preexisting_posting_list_)); + } + preexisting_posting_list_.reset(); +} + +libtextclassifier3::Status PostingListAccessor::FlushInMemoryPostingList() { + // We exceeded max_pl_bytes(). Need to flush posting_list_buffer_ and update + // the chain. + uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( + storage_->block_size(), GetSerializer()->GetDataTypeBytes()); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage_->AllocatePostingList(max_posting_list_bytes)); + holder.block.set_next_block_index(prev_block_identifier_.block_index()); + prev_block_identifier_ = holder.id; + return GetSerializer()->MoveFrom(/*dst=*/&holder.posting_list, + /*src=*/&posting_list_buffer_); +} + +PostingListAccessor::FinalizeResult PostingListAccessor::Finalize() && { + if (preexisting_posting_list_ != nullptr) { + // Our data are already in an existing posting list. Nothing else to do, but + // return its id. + return FinalizeResult(libtextclassifier3::Status::OK, + preexisting_posting_list_->id); + } + if (GetSerializer()->GetBytesUsed(&posting_list_buffer_) <= 0) { + return FinalizeResult(absl_ports::InvalidArgumentError( + "Can't finalize an empty PostingListAccessor. 
" + "There's nothing to Finalize!"), + PostingListIdentifier::kInvalid); + } + uint32_t posting_list_bytes = + GetSerializer()->GetMinPostingListSizeToFit(&posting_list_buffer_); + if (prev_block_identifier_.is_valid()) { + posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( + storage_->block_size(), GetSerializer()->GetDataTypeBytes()); + } + auto holder_or = storage_->AllocatePostingList(posting_list_bytes); + if (!holder_or.ok()) { + return FinalizeResult(std::move(holder_or).status(), + prev_block_identifier_); + } + PostingListHolder holder = std::move(holder_or).ValueOrDie(); + if (prev_block_identifier_.is_valid()) { + holder.block.set_next_block_index(prev_block_identifier_.block_index()); + } + + // Move to allocated area. This should never actually return an error. We know + // that editor.posting_list() is valid because it wouldn't have successfully + // returned by AllocatePostingList if it wasn't. We know posting_list_buffer_ + // is valid because we created it in-memory. And finally, we know that the + // data from posting_list_buffer_ will fit in editor.posting_list() because we + // requested it be at at least posting_list_bytes large. + auto status = GetSerializer()->MoveFrom(/*dst=*/&holder.posting_list, + /*src=*/&posting_list_buffer_); + if (!status.ok()) { + return FinalizeResult(std::move(status), prev_block_identifier_); + } + return FinalizeResult(libtextclassifier3::Status::OK, holder.id); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/main/posting-list-accessor.h b/icing/file/posting_list/posting-list-accessor.h index 3f93c3a..c7d614f 100644 --- a/icing/index/main/posting-list-accessor.h +++ b/icing/file/posting_list/posting-list-accessor.h @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Google LLC +// Copyright (C) 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -12,20 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef ICING_INDEX_POSTING_LIST_ACCESSOR_H_ -#define ICING_INDEX_POSTING_LIST_ACCESSOR_H_ +#ifndef ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_ +#define ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_ #include <cstdint> #include <memory> -#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" -#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/file/posting_list/flash-index-storage.h" #include "icing/file/posting_list/posting-list-identifier.h" #include "icing/file/posting_list/posting-list-used.h" -#include "icing/index/hit/hit.h" -#include "icing/index/main/posting-list-used-hit-serializer.h" namespace icing { namespace lib { @@ -38,62 +34,14 @@ namespace lib { // 3. Ensure that PostingListUseds can only be freed by calling methods which // will also properly maintain the FlashIndexStorage free list and prevent // callers from modifying the Posting List after freeing. - -// This class is used to provide a simple abstraction for adding hits to posting -// lists. PostingListAccessor handles 1) selection of properly-sized posting -// lists for the accumulated hits during Finalize() and 2) chaining of max-sized -// posting lists. class PostingListAccessor { public: - // Creates an empty PostingListAccessor. - // - // RETURNS: - // - On success, a valid instance of PostingListAccessor - // - INVALID_ARGUMENT error if storage has an invalid block_size. - static libtextclassifier3::StatusOr<PostingListAccessor> Create( - FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer); - - // Create a PostingListAccessor with an existing posting list identified by - // existing_posting_list_id. 
- // - // The PostingListAccessor will add hits to this posting list until it is - // necessary either to 1) chain the posting list (if it is max-sized) or 2) - // move its hits to a larger posting list. - // - // RETURNS: - // - On success, a valid instance of PostingListAccessor - // - INVALID_ARGUMENT if storage has an invalid block_size. - static libtextclassifier3::StatusOr<PostingListAccessor> CreateFromExisting( - FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer, - PostingListIdentifier existing_posting_list_id); - - // Retrieve the next batch of hits for the posting list chain - // - // RETURNS: - // - On success, a vector of hits in the posting list chain - // - INTERNAL if called on an instance of PostingListAccessor that was - // created via PostingListAccessor::Create, if unable to read the next - // posting list in the chain or if the posting list has been corrupted - // somehow. - libtextclassifier3::StatusOr<std::vector<Hit>> GetNextHitsBatch(); - - // Prepend one hit. This may result in flushing the posting list to disk (if - // the PostingListAccessor holds a max-sized posting list that is full) or - // freeing a pre-existing posting list if it is too small to fit all hits - // necessary. - // - // RETURNS: - // - OK, on success - // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the - // previously added hit. - // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new - // posting list. - libtextclassifier3::Status PrependHit(const Hit& hit); + virtual ~PostingListAccessor() = default; struct FinalizeResult { // - OK on success // - INVALID_ARGUMENT if there was no pre-existing posting list and no - // hits were added + // data were added // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a // new posting list. libtextclassifier3::Status status; @@ -101,22 +49,27 @@ class PostingListAccessor { // if status is OK. 
May be valid if status is non-OK, but previous blocks // were written. PostingListIdentifier id; + + explicit FinalizeResult(libtextclassifier3::Status status_in, + PostingListIdentifier id_in) + : status(std::move(status_in)), id(std::move(id_in)) {} }; - // Write all accumulated hits to storage. + // Write all accumulated data to storage. // // If accessor points to a posting list chain with multiple posting lists in // the chain and unable to write the last posting list in the chain, Finalize // will return the error and also populate id with the id of the // second-to-last posting list. - static FinalizeResult Finalize(PostingListAccessor accessor); + FinalizeResult Finalize() &&; - private: + virtual PostingListUsedSerializer* GetSerializer() = 0; + + protected: explicit PostingListAccessor( - FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer, + FlashIndexStorage* storage, std::unique_ptr<uint8_t[]> posting_list_buffer_array, PostingListUsed posting_list_buffer) : storage_(storage), - serializer_(serializer), prev_block_identifier_(PostingListIdentifier::kInvalid), posting_list_buffer_array_(std::move(posting_list_buffer_array)), posting_list_buffer_(std::move(posting_list_buffer)), @@ -141,23 +94,21 @@ class PostingListAccessor { FlashIndexStorage* storage_; // Does not own. - PostingListUsedHitSerializer* serializer_; // Does not own. - // The PostingListIdentifier of the first max-sized posting list in the // posting list chain or PostingListIdentifier::kInvalid if there is no // posting list chain. PostingListIdentifier prev_block_identifier_; // An editor to an existing posting list on disk. If available (non-NULL), - // we'll try to add all hits to this posting list. Once this posting list + // we'll try to add all data to this posting list. 
Once this posting list // fills up, we'll either 1) chain it (if a max-sized posting list) and put - // future hits in posting_list_buffer_ or 2) copy all of its hits into + // future data in posting_list_buffer_ or 2) copy all of its data into // posting_list_buffer_ and free this pl (if not a max-sized posting list). // TODO(tjbarron) provide a benchmark to demonstrate the effects that re-using // existing posting lists has on latency. std::unique_ptr<PostingListHolder> preexisting_posting_list_; - // In-memory posting list used to buffer hits before writing them to the + // In-memory posting list used to buffer data before writing them to the // smallest on-disk posting list that will fit them. // posting_list_buffer_array_ owns the memory region that posting_list_buffer_ // interprets. Therefore, posting_list_buffer_array_ must have the same @@ -171,4 +122,4 @@ class PostingListAccessor { } // namespace lib } // namespace icing -#endif // ICING_INDEX_POSTING_LIST_ACCESSOR_H_ +#endif // ICING_FILE_POSTING_LIST_POSTING_LIST_ACCESSOR_H_ diff --git a/icing/file/posting_list/posting-list-common.h b/icing/file/posting_list/posting-list-common.h index cbe2ddf..44c6dd2 100644 --- a/icing/file/posting_list/posting-list-common.h +++ b/icing/file/posting_list/posting-list-common.h @@ -25,8 +25,6 @@ namespace lib { using PostingListIndex = int32_t; inline constexpr PostingListIndex kInvalidPostingListIndex = ~0U; -inline constexpr uint32_t kNumSpecialData = 2; - inline constexpr uint32_t kInvalidBlockIndex = 0; } // namespace lib diff --git a/icing/file/posting_list/posting-list-used.h b/icing/file/posting_list/posting-list-used.h index ec4b067..5821880 100644 --- a/icing/file/posting_list/posting-list-used.h +++ b/icing/file/posting_list/posting-list-used.h @@ -45,6 +45,28 @@ class PostingListUsed; // posting list. class PostingListUsedSerializer { public: + // Special data is either a DataType instance or data_start_offset. 
+ template <typename DataType> + union SpecialData { + explicit SpecialData(const DataType& data) : data_(data) {} + + explicit SpecialData(uint32_t data_start_offset) + : data_start_offset_(data_start_offset) {} + + const DataType& data() const { return data_; } + + uint32_t data_start_offset() const { return data_start_offset_; } + void set_data_start_offset(uint32_t data_start_offset) { + data_start_offset_ = data_start_offset; + } + + private: + DataType data_; + uint32_t data_start_offset_; + } __attribute__((packed)); + + static constexpr uint32_t kNumSpecialData = 2; + virtual ~PostingListUsedSerializer() = default; // Returns byte size of the data type. diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 4c4bf65..60e347e 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -35,6 +35,8 @@ #include "icing/index/index-processor.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/join/join-processor.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/endian.h" #include "icing/proto/debug.pb.h" @@ -477,6 +479,7 @@ void IcingSearchEngine::ResetMembers() { language_segmenter_.reset(); normalizer_.reset(); index_.reset(); + integer_index_.reset(); } libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile( @@ -598,6 +601,11 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( std::string marker_filepath = MakeSetSchemaMarkerFilePath(options_.base_dir()); + + // TODO(b/249829533): switch to use persistent numeric index after + // implementing and initialize numeric index. + integer_index_ = std::make_unique<DummyNumericIndex<int64_t>>(); + libtextclassifier3::Status index_init_status; if (absl_ports::IsNotFound(schema_store_->GetSchema().status())) { // The schema was either lost or never set before. 
Wipe out the doc store @@ -864,6 +872,12 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( return result_proto; } + status = integer_index_->Reset(); + if (!status.ok()) { + TransformStatus(status, result_status); + return result_proto; + } + IndexRestorationResult restore_result = RestoreIndexIfNeeded(); // DATA_LOSS means that we have successfully re-added content to the // index. Some indexed content was lost, but otherwise the index is in a @@ -963,17 +977,17 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { TokenizedDocument tokenized_document( std::move(tokenized_document_or).ValueOrDie()); - auto document_id_or = - document_store_->Put(tokenized_document.document(), - tokenized_document.num_tokens(), put_document_stats); + auto document_id_or = document_store_->Put( + tokenized_document.document(), tokenized_document.num_string_tokens(), + put_document_stats); if (!document_id_or.ok()) { TransformStatus(document_id_or.status(), result_status); return result_proto; } DocumentId document_id = document_id_or.ValueOrDie(); - auto index_processor_or = - IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); + auto index_processor_or = IndexProcessor::Create( + normalizer_.get(), index_.get(), integer_index_.get(), clock_.get()); if (!index_processor_or.ok()) { TransformStatus(index_processor_or.status(), result_status); return result_proto; @@ -1243,8 +1257,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); // Gets unordered results from query processor auto query_processor_or = QueryProcessor::Create( - index_.get(), language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get()); + index_.get(), integer_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), schema_store_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); 
delete_stats->set_parse_query_latency_ms( @@ -1419,6 +1433,8 @@ OptimizeResultProto IcingSearchEngine::Optimize() { optimize_stats->set_index_restoration_mode( OptimizeStatsProto::FULL_INDEX_REBUILD); ICING_LOG(WARNING) << "Resetting the entire index!"; + + // Reset string index libtextclassifier3::Status index_reset_status = index_->Reset(); if (!index_reset_status.ok()) { status = absl_ports::Annotate( @@ -1430,6 +1446,18 @@ OptimizeResultProto IcingSearchEngine::Optimize() { return result_proto; } + // Reset integer index + index_reset_status = integer_index_->Reset(); + if (!index_reset_status.ok()) { + status = absl_ports::Annotate( + absl_ports::InternalError("Failed to reset integer index."), + index_reset_status.error_message()); + TransformStatus(status, result_status); + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); + return result_proto; + } + IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded(); // DATA_LOSS means that we have successfully re-added content to the index. 
// Some indexed content was lost, but otherwise the index is in a valid @@ -1539,6 +1567,8 @@ GetOptimizeInfoResultProto IcingSearchEngine::GetOptimizeInfo() { } int64_t index_elements_size = index_elements_size_or.ValueOrDie(); + // TODO(b/259744228): add stats for integer index + // Sum up the optimizable sizes from DocumentStore and Index result_proto.set_estimated_optimizable_bytes( index_elements_size * doc_store_optimize_info.optimizable_docs / @@ -1568,6 +1598,7 @@ StorageInfoResultProto IcingSearchEngine::GetStorageInfo() { schema_store_->GetStorageInfo(); *result.mutable_storage_info()->mutable_index_storage_info() = index_->GetStorageInfo(); + // TODO(b/259744228): add stats for integer index result.mutable_status()->set_code(StatusProto::OK); return result; } @@ -1588,6 +1619,8 @@ DebugInfoResultProto IcingSearchEngine::GetDebugInfo( *debug_info.mutable_debug_info()->mutable_index_info() = index_->GetDebugInfo(verbosity); + // TODO(b/259744228): add debug info for integer index + // Document Store libtextclassifier3::StatusOr<DocumentDebugInfoProto> document_debug_info = document_store_->GetDebugInfo(verbosity); @@ -1620,6 +1653,7 @@ libtextclassifier3::Status IcingSearchEngine::InternalPersistToDisk( ICING_RETURN_IF_ERROR(schema_store_->PersistToDisk()); ICING_RETURN_IF_ERROR(document_store_->PersistToDisk(PersistType::FULL)); ICING_RETURN_IF_ERROR(index_->PersistToDisk()); + ICING_RETURN_IF_ERROR(integer_index_->PersistToDisk()); return libtextclassifier3::Status::OK; } @@ -1664,77 +1698,82 @@ SearchResultProto IcingSearchEngine::Search( query_stats->set_is_first_page(true); query_stats->set_requested_page_size(result_spec.num_per_page()); - std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); - // Gets unordered results from query processor - auto query_processor_or = QueryProcessor::Create( - index_.get(), language_segmenter_.get(), normalizer_.get(), - document_store_.get(), schema_store_.get()); - if (!query_processor_or.ok()) { - 
TransformStatus(query_processor_or.status(), result_status); - query_stats->set_parse_query_latency_ms( - component_timer->GetElapsedMilliseconds()); - return result_proto; - } - std::unique_ptr<QueryProcessor> query_processor = - std::move(query_processor_or).ValueOrDie(); - - auto query_results_or = - query_processor->ParseSearch(search_spec, scoring_spec.rank_by()); - if (!query_results_or.ok()) { - TransformStatus(query_results_or.status(), result_status); - query_stats->set_parse_query_latency_ms( - component_timer->GetElapsedMilliseconds()); - return result_proto; - } - QueryResults query_results = std::move(query_results_or).ValueOrDie(); - query_stats->set_parse_query_latency_ms( - component_timer->GetElapsedMilliseconds()); - + // Process query and score + QueryScoringResults query_scoring_results = + ProcessQueryAndScore(search_spec, scoring_spec, result_spec); int term_count = 0; - for (const auto& section_and_terms : query_results.query_terms) { + for (const auto& section_and_terms : query_scoring_results.query_terms) { term_count += section_and_terms.second.size(); } query_stats->set_num_terms(term_count); - - component_timer = clock_->GetNewTimer(); - // Scores but does not rank the results. 
- libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> - scoring_processor_or = ScoringProcessor::Create( - scoring_spec, document_store_.get(), schema_store_.get()); - if (!scoring_processor_or.ok()) { - TransformStatus(scoring_processor_or.status(), result_status); - query_stats->set_scoring_latency_ms( - component_timer->GetElapsedMilliseconds()); + query_stats->set_parse_query_latency_ms( + query_scoring_results.parse_query_latency_ms); + query_stats->set_scoring_latency_ms(query_scoring_results.scoring_latency_ms); + if (!query_scoring_results.status.ok()) { + TransformStatus(query_scoring_results.status, result_status); return result_proto; } - std::unique_ptr<ScoringProcessor> scoring_processor = - std::move(scoring_processor_or).ValueOrDie(); - std::vector<ScoredDocumentHit> result_document_hits = - scoring_processor->Score(std::move(query_results.root_iterator), - performance_configuration_.num_to_score, - &query_results.query_term_iterators); - query_stats->set_scoring_latency_ms( - component_timer->GetElapsedMilliseconds()); - query_stats->set_num_documents_scored(result_document_hits.size()); + query_stats->set_num_documents_scored( + query_scoring_results.scored_document_hits.size()); // Returns early for empty result - if (result_document_hits.empty()) { + if (query_scoring_results.scored_document_hits.empty()) { result_status->set_code(StatusProto::OK); return result_proto; } - component_timer = clock_->GetNewTimer(); - // Ranks results - std::unique_ptr<ScoredDocumentHitsRanker> ranker = - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( - std::move(result_document_hits), - /*is_descending=*/scoring_spec.order_by() == - ScoringSpecProto::Order::DESC); - query_stats->set_ranking_latency_ms( - component_timer->GetElapsedMilliseconds()); + std::unique_ptr<ScoredDocumentHitsRanker> ranker; + if (search_spec.has_join_spec()) { + // Process 2nd query + QueryScoringResults nested_query_scoring_results = ProcessQueryAndScore( + 
search_spec.join_spec().nested_spec().search_spec(), + search_spec.join_spec().nested_spec().scoring_spec(), + search_spec.join_spec().nested_spec().result_spec()); + // TODO(b/256022027): set different kinds of latency for 2nd query. + if (!nested_query_scoring_results.status.ok()) { + TransformStatus(nested_query_scoring_results.status, result_status); + return result_proto; + } - component_timer = clock_->GetNewTimer(); - // RanksAndPaginates and retrieves the document protos and snippets if + // Join 2 scored document hits + JoinProcessor join_processor(document_store_.get()); + libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> + joined_result_document_hits_or = join_processor.Join( + search_spec.join_spec(), + std::move(query_scoring_results.scored_document_hits), + std::move(nested_query_scoring_results.scored_document_hits)); + if (!joined_result_document_hits_or.ok()) { + TransformStatus(joined_result_document_hits_or.status(), result_status); + return result_proto; + } + std::vector<JoinedScoredDocumentHit> joined_result_document_hits = + std::move(joined_result_document_hits_or).ValueOrDie(); + // TODO(b/256022027): set join latency + + std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); + // Ranks results + ranker = std::make_unique< + PriorityQueueScoredDocumentHitsRanker<JoinedScoredDocumentHit>>( + std::move(joined_result_document_hits), + /*is_descending=*/scoring_spec.order_by() == + ScoringSpecProto::Order::DESC); + query_stats->set_ranking_latency_ms( + component_timer->GetElapsedMilliseconds()); + } else { + // Non-join query + std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); + // Ranks results + ranker = std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( + std::move(query_scoring_results.scored_document_hits), + /*is_descending=*/scoring_spec.order_by() == + ScoringSpecProto::Order::DESC); + query_stats->set_ranking_latency_ms( + component_timer->GetElapsedMilliseconds()); + 
} + + std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); + // CacheAndRetrieveFirstPage and retrieves the document protos and snippets if // requested auto result_retriever_or = ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), @@ -1750,8 +1789,9 @@ SearchResultProto IcingSearchEngine::Search( libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> page_result_info_or = result_state_manager_->CacheAndRetrieveFirstPage( - std::move(ranker), std::move(query_results.query_terms), search_spec, - scoring_spec, result_spec, *document_store_, *result_retriever); + std::move(ranker), std::move(query_scoring_results.query_terms), + search_spec, scoring_spec, result_spec, *document_store_, + *result_retriever); if (!page_result_info_or.ok()) { TransformStatus(page_result_info_or.status(), result_status); query_stats->set_document_retrieval_latency_ms( @@ -1783,6 +1823,63 @@ SearchResultProto IcingSearchEngine::Search( return result_proto; } +IcingSearchEngine::QueryScoringResults IcingSearchEngine::ProcessQueryAndScore( + const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec) { + std::unique_ptr<Timer> component_timer = clock_->GetNewTimer(); + + // Gets unordered results from query processor + auto query_processor_or = QueryProcessor::Create( + index_.get(), integer_index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), schema_store_.get()); + if (!query_processor_or.ok()) { + return QueryScoringResults( + std::move(query_processor_or).status(), /*query_terms_in=*/{}, + /*scored_document_hits_in=*/{}, + /*parse_query_latency_ms_in=*/component_timer->GetElapsedMilliseconds(), + /*scoring_latency_ms_in=*/0); + } + std::unique_ptr<QueryProcessor> query_processor = + std::move(query_processor_or).ValueOrDie(); + + auto query_results_or = + query_processor->ParseSearch(search_spec, scoring_spec.rank_by()); + if (!query_results_or.ok()) { + return 
QueryScoringResults( + std::move(query_results_or).status(), /*query_terms_in=*/{}, + /*scored_document_hits_in=*/{}, + /*parse_query_latency_ms_in=*/component_timer->GetElapsedMilliseconds(), + /*scoring_latency_ms_in=*/0); + } + QueryResults query_results = std::move(query_results_or).ValueOrDie(); + int64_t parse_query_latency_ms = component_timer->GetElapsedMilliseconds(); + + component_timer = clock_->GetNewTimer(); + // Scores but does not rank the results. + libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> + scoring_processor_or = ScoringProcessor::Create( + scoring_spec, document_store_.get(), schema_store_.get()); + if (!scoring_processor_or.ok()) { + return QueryScoringResults(std::move(scoring_processor_or).status(), + std::move(query_results.query_terms), + /*scored_document_hits_in=*/{}, + parse_query_latency_ms, + /*scoring_latency_ms_in=*/0); + } + std::unique_ptr<ScoringProcessor> scoring_processor = + std::move(scoring_processor_or).ValueOrDie(); + std::vector<ScoredDocumentHit> scored_document_hits = + scoring_processor->Score(std::move(query_results.root_iterator), + performance_configuration_.num_to_score, + &query_results.query_term_iterators); + int64_t scoring_latency_ms = component_timer->GetElapsedMilliseconds(); + + return QueryScoringResults(libtextclassifier3::Status::OK, + std::move(query_results.query_terms), + std::move(scored_document_hits), + parse_query_latency_ms, scoring_latency_ms); +} + SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { SearchResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); @@ -2006,8 +2103,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {libtextclassifier3::Status::OK, false}; } - auto index_processor_or = - IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); + auto index_processor_or = IndexProcessor::Create( + normalizer_.get(), index_.get(), integer_index_.get(), clock_.get()); if 
(!index_processor_or.ok()) { return {index_processor_or.status(), true}; } diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index 4b0576f..221d86c 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -19,6 +19,7 @@ #include <memory> #include <string> #include <string_view> +#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" @@ -26,6 +27,7 @@ #include "icing/absl_ports/thread_annotations.h" #include "icing/file/filesystem.h" #include "icing/index/index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/performance-configuration.h" @@ -41,8 +43,10 @@ #include "icing/proto/search.pb.h" #include "icing/proto/storage.pb.h" #include "icing/proto/usage.pb.h" +#include "icing/query/query-terms.h" #include "icing/result/result-state-manager.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/store/document-store.h" #include "icing/tokenization/language-segmenter.h" #include "icing/transform/normalizer.h" @@ -464,9 +468,14 @@ class IcingSearchEngine { std::unique_ptr<const Normalizer> normalizer_ ICING_GUARDED_BY(mutex_); - // Storage for all hits of content from the document store. + // Storage for all hits of string contents from the document store. std::unique_ptr<Index> index_ ICING_GUARDED_BY(mutex_); + // Storage for all hits of numeric contents from the document store. + // TODO(b/249829533): integrate more functions with integer_index_. + std::unique_ptr<NumericIndex<int64_t>> integer_index_ + ICING_GUARDED_BY(mutex_); + // Pointer to JNI class references const std::unique_ptr<const JniCache> jni_cache_; @@ -552,6 +561,37 @@ class IcingSearchEngine { InitializeStatsProto* initialize_stats) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Processes query and scores according to the specs. 
It is a helper function + // (called by Search) to process and score normal query and the nested child + // query for join search. + // + // Returns a QueryScoringResults + // OK on success with a vector of ScoredDocumentHits, + // SectionRestrictQueryTermsMap, and other stats fields for logging. + // Any other errors when processing the query or scoring + struct QueryScoringResults { + libtextclassifier3::Status status; + SectionRestrictQueryTermsMap query_terms; + std::vector<ScoredDocumentHit> scored_document_hits; + int64_t parse_query_latency_ms; + int64_t scoring_latency_ms; + + explicit QueryScoringResults( + libtextclassifier3::Status status_in, + SectionRestrictQueryTermsMap&& query_terms_in, + std::vector<ScoredDocumentHit>&& scored_document_hits_in, + int64_t parse_query_latency_ms_in, int64_t scoring_latency_ms_in) + : status(std::move(status_in)), + query_terms(std::move(query_terms_in)), + scored_document_hits(std::move(scored_document_hits_in)), + parse_query_latency_ms(parse_query_latency_ms_in), + scoring_latency_ms(scoring_latency_ms_in) {} + }; + QueryScoringResults ProcessQueryAndScore(const SearchSpecProto& search_spec, + const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Many of the internal components rely on other components' derived data. // Check that everything is consistent with each other so that we're not // using outdated derived data in some parts of our system. 
diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index 7a60101..8cb7e7f 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -55,9 +55,9 @@ #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" #include "icing/testing/random-string.h" -#include "icing/testing/snippet-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/util/snippet-helpers.h" namespace icing { namespace lib { @@ -9895,6 +9895,423 @@ TEST_F(IcingSearchEngineTest, IcingShouldWorkFor64Sections) { EqualsSearchResultIgnoreStatsAndScores(expected_no_documents)); } +TEST_F(IcingSearchEngineTest, SimpleJoin) { + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty(PropertyConfigBuilder() + .SetName("firstName") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("lastName") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty( + PropertyConfigBuilder() + .SetName("subjectId") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + DocumentProto person1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("firstName", "first1") + .AddStringProperty("lastName", "last1") + .AddStringProperty("emailAddress", "email1@gmail.com") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto person2 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "person2") + .SetSchema("Person") + 
.AddStringProperty("firstName", "first2") + .AddStringProperty("lastName", "last2") + .AddStringProperty("emailAddress", "email2@gmail.com") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto person3 = + DocumentBuilder() + .SetKey("pkg$db/name#space\\\\", "person3") + .SetSchema("Person") + .AddStringProperty("firstName", "first3") + .AddStringProperty("lastName", "last3") + .AddStringProperty("emailAddress", "email3@gmail.com") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + + DocumentProto email1 = + DocumentBuilder() + .SetKey("namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subjectId", "pkg$db/namespace#person1") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto email2 = + DocumentBuilder() + .SetKey("namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subjectId", "pkg$db/namespace#person2") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto email3 = + DocumentBuilder() + .SetKey("namespace", "email3") + .SetSchema("Email") + .AddStringProperty("subjectId", "pkg$db/name\\#space\\\\#person3") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(person3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email3).status(), ProtoIsOk()); + + // Parent SearchSpec + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query("first"); + + // JoinSpec + JoinSpecProto* join_spec = search_spec.mutable_join_spec(); + // Set 
max_joined_child_count as 2, so only email 3, email2 will be included + // in the nested result and email1 will be truncated. + join_spec->set_max_joined_child_count(2); + join_spec->set_parent_property_expression("this.fullyQualifiedId()"); + join_spec->set_child_property_expression("subjectId"); + JoinSpecProto::NestedSpecProto* nested_spec = + join_spec->mutable_nested_spec(); + SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec(); + nested_search_spec->set_term_match_type(TermMatchType::PREFIX); + nested_search_spec->set_query(""); + *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec(); + *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance(); + + // Parent ScoringSpec + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + + // Parent ResultSpec + ResultSpecProto result_spec; + result_spec.set_num_per_page(1); + + SearchResultProto expected_result1; + expected_result1.mutable_status()->set_code(StatusProto::OK); + SearchResultProto::ResultProto* result_proto1 = + expected_result1.mutable_results()->Add(); + *result_proto1->mutable_document() = person3; + *result_proto1->mutable_joined_results()->Add()->mutable_document() = email3; + + SearchResultProto expected_result2; + expected_result2.mutable_status()->set_code(StatusProto::OK); + SearchResultProto::ResultProto* result_proto2 = + expected_result2.mutable_results()->Add(); + *result_proto2->mutable_document() = person2; + *result_proto2->mutable_joined_results()->Add()->mutable_document() = email2; + + SearchResultProto expected_result3; + expected_result3.mutable_status()->set_code(StatusProto::OK); + SearchResultProto::ResultProto* result_proto3 = + expected_result3.mutable_results()->Add(); + *result_proto3->mutable_document() = person1; + *result_proto3->mutable_joined_results()->Add()->mutable_document() = email1; + + SearchResultProto result1 = + icing.Search(search_spec, scoring_spec, result_spec); + uint64_t next_page_token = 
result1.next_page_token(); + EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken)); + expected_result1.set_next_page_token(next_page_token); + EXPECT_THAT(result1, + EqualsSearchResultIgnoreStatsAndScores(expected_result1)); + + SearchResultProto result2 = icing.GetNextPage(next_page_token); + next_page_token = result2.next_page_token(); + EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken)); + expected_result2.set_next_page_token(next_page_token); + EXPECT_THAT(result2, + EqualsSearchResultIgnoreStatsAndScores(expected_result2)); + + SearchResultProto result3 = icing.GetNextPage(next_page_token); + next_page_token = result3.next_page_token(); + EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken)); + EXPECT_THAT(result3, + EqualsSearchResultIgnoreStatsAndScores(expected_result3)); +} + +TEST_F(IcingSearchEngineTest, InvalidJoins) { + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("Person") + .AddProperty(PropertyConfigBuilder() + .SetName("firstName") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("lastName") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder().SetType("Email").AddProperty( + PropertyConfigBuilder() + .SetName("subjectId") + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + DocumentProto person1 = + DocumentBuilder() + .SetKey("pkg$db/namespace", "person1") + .SetSchema("Person") + .AddStringProperty("firstName", "first1") + .AddStringProperty("lastName", "last1") + .AddStringProperty("emailAddress", "email1@gmail.com") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + 
DocumentProto person2 = + DocumentBuilder() + .SetKey("pkg$db/namespace\\", "person2") + .SetSchema("Person") + .AddStringProperty("firstName", "first2") + .AddStringProperty("lastName", "last2") + .AddStringProperty("emailAddress", "email2@gmail.com") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + + // "invalid format" does not refer to any document, so it will not be joined + // to any document. + DocumentProto email1 = + DocumentBuilder() + .SetKey("namespace", "email1") + .SetSchema("Email") + .AddStringProperty("subjectId", "invalid format") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + // This will not be joined because the # in the subjectId is escaped. + DocumentProto email2 = + DocumentBuilder() + .SetKey("namespace", "email2") + .SetSchema("Email") + .AddStringProperty("subjectId", "pkg$db/namespace\\#person2") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(person1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(person2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); + + // Parent SearchSpec + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query("first"); + + // JoinSpec + JoinSpecProto* join_spec = search_spec.mutable_join_spec(); + // Set max_joined_child_count as 2, so only email 3, email2 will be included + // in the nested result and email1 will be truncated. 
+ join_spec->set_max_joined_child_count(2); + join_spec->set_parent_property_expression("this.fullyQualifiedId()"); + join_spec->set_child_property_expression("subjectId"); + JoinSpecProto::NestedSpecProto* nested_spec = + join_spec->mutable_nested_spec(); + SearchSpecProto* nested_search_spec = nested_spec->mutable_search_spec(); + nested_search_spec->set_term_match_type(TermMatchType::PREFIX); + nested_search_spec->set_query(""); + *nested_spec->mutable_scoring_spec() = GetDefaultScoringSpec(); + *nested_spec->mutable_result_spec() = ResultSpecProto::default_instance(); + + // Parent ScoringSpec + ScoringSpecProto scoring_spec = GetDefaultScoringSpec(); + + // Parent ResultSpec + ResultSpecProto result_spec; + result_spec.set_num_per_page(1); + + SearchResultProto expected_result1; + expected_result1.mutable_status()->set_code(StatusProto::OK); + SearchResultProto::ResultProto* result_proto1 = + expected_result1.mutable_results()->Add(); + *result_proto1->mutable_document() = person2; + + SearchResultProto expected_result2; + expected_result2.mutable_status()->set_code(StatusProto::OK); + SearchResultProto::ResultProto* result_proto2 = + expected_result2.mutable_results()->Add(); + *result_proto2->mutable_document() = person1; + + SearchResultProto result1 = + icing.Search(search_spec, scoring_spec, result_spec); + uint64_t next_page_token = result1.next_page_token(); + EXPECT_THAT(next_page_token, Ne(kInvalidNextPageToken)); + expected_result1.set_next_page_token(next_page_token); + EXPECT_THAT(result1, + EqualsSearchResultIgnoreStatsAndScores(expected_result1)); + + SearchResultProto result2 = icing.GetNextPage(next_page_token); + next_page_token = result2.next_page_token(); + EXPECT_THAT(next_page_token, Eq(kInvalidNextPageToken)); + EXPECT_THAT(result2, + EqualsSearchResultIgnoreStatsAndScores(expected_result2)); +} + +TEST_F(IcingSearchEngineTest, NumericFilterAdvancedQuerySucceeds) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + 
ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // Create the schema and document store + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("transaction") + .AddProperty(PropertyConfigBuilder() + .SetName("price") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("cost") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + DocumentProto document_one = DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("price", 10) + .Build(); + ASSERT_THAT(icing.Put(document_one).status(), ProtoIsOk()); + + DocumentProto document_two = DocumentBuilder() + .SetKey("namespace", "2") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("price", 25) + .Build(); + ASSERT_THAT(icing.Put(document_two).status(), ProtoIsOk()); + + DocumentProto document_three = DocumentBuilder() + .SetKey("namespace", "3") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("cost", 2) + .Build(); + ASSERT_THAT(icing.Put(document_three).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_query("price < 20"); + search_spec.set_search_type( + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY); + + SearchResultProto results = + icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_one)); + + search_spec.set_query("price == 25"); + results = icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_two)); + + 
search_spec.set_query("cost > 2"); + results = icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + EXPECT_THAT(results.results(), IsEmpty()); + + search_spec.set_query("cost >= 2"); + results = icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_three)); + + search_spec.set_query("price <= 25"); + results = icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(results.results(), SizeIs(2)); + EXPECT_THAT(results.results(0).document(), EqualsProto(document_two)); + EXPECT_THAT(results.results(1).document(), EqualsProto(document_one)); +} + +TEST_F(IcingSearchEngineTest, NumericFilterOldQueryFails) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + // Create the schema and document store + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("transaction") + .AddProperty(PropertyConfigBuilder() + .SetName("price") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("cost") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + + DocumentProto document_one = DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("price", 10) + .Build(); + ASSERT_THAT(icing.Put(document_one).status(), ProtoIsOk()); + + DocumentProto document_two = DocumentBuilder() + .SetKey("namespace", "2") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("price", 25) + .Build(); + ASSERT_THAT(icing.Put(document_two).status(), 
ProtoIsOk()); + + DocumentProto document_three = DocumentBuilder() + .SetKey("namespace", "3") + .SetSchema("transaction") + .SetCreationTimestampMs(1) + .AddInt64Property("cost", 2) + .Build(); + ASSERT_THAT(icing.Put(document_three).status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_query("price < 20"); + search_spec.set_search_type(SearchSpecProto::SearchType::ICING_RAW_QUERY); + + SearchResultProto results = + icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + EXPECT_THAT(results.status(), ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/index/index-processor.cc b/icing/index/index-processor.cc index cfeda31..9f21c9d 100644 --- a/icing/index/index-processor.cc +++ b/icing/index/index-processor.cc @@ -21,18 +21,12 @@ #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" -#include "icing/absl_ports/canonical_errors.h" -#include "icing/absl_ports/str_cat.h" #include "icing/index/index.h" -#include "icing/legacy/core/icing-string-util.h" +#include "icing/index/integer-section-indexing-handler.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/index/string-section-indexing-handler.h" #include "icing/proto/logging.pb.h" -#include "icing/proto/schema.pb.h" -#include "icing/schema/section-manager.h" -#include "icing/schema/section.h" #include "icing/store/document-id.h" -#include "icing/tokenization/token.h" -#include "icing/tokenization/tokenizer-factory.h" -#include "icing/tokenization/tokenizer.h" #include "icing/transform/normalizer.h" #include "icing/util/status-macros.h" #include "icing/util/tokenized-document.h" @@ -42,117 +36,33 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> IndexProcessor::Create(const Normalizer* normalizer, Index* index, + NumericIndex<int64_t>* integer_index, const Clock* clock) { 
ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(index); + ICING_RETURN_ERROR_IF_NULL(integer_index); ICING_RETURN_ERROR_IF_NULL(clock); + std::vector<std::unique_ptr<SectionIndexingHandler>> handlers; + handlers.push_back( + std::make_unique<StringSectionIndexingHandler>(clock, normalizer, index)); + handlers.push_back( + std::make_unique<IntegerSectionIndexingHandler>(clock, integer_index)); + return std::unique_ptr<IndexProcessor>( - new IndexProcessor(normalizer, index, clock)); + new IndexProcessor(std::move(handlers), clock)); } libtextclassifier3::Status IndexProcessor::IndexDocument( const TokenizedDocument& tokenized_document, DocumentId document_id, PutDocumentStatsProto* put_document_stats) { - std::unique_ptr<Timer> index_timer = clock_.GetNewTimer(); - - if (index_->last_added_document_id() != kInvalidDocumentId && - document_id <= index_->last_added_document_id()) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "DocumentId %d must be greater than last added document_id %d", - document_id, index_->last_added_document_id())); - } - index_->set_last_added_document_id(document_id); - uint32_t num_tokens = 0; - libtextclassifier3::Status status; - for (const TokenizedSection& section : tokenized_document.sections()) { - if (section.metadata.tokenizer == - StringIndexingConfig::TokenizerType::NONE) { - ICING_LOG(WARNING) - << "Unexpected TokenizerType::NONE found when indexing document."; - } - // TODO(b/152934343): pass real namespace ids in - Index::Editor editor = - index_->Edit(document_id, section.metadata.id, - section.metadata.term_match_type, /*namespace_id=*/0); - for (std::string_view token : section.token_sequence) { - ++num_tokens; - - switch (section.metadata.tokenizer) { - case StringIndexingConfig::TokenizerType::VERBATIM: - // data() is safe to use here because a token created from the - // VERBATIM tokenizer is the entire string value. 
The character at - // data() + token.length() is guaranteed to be a null char. - status = editor.BufferTerm(token.data()); - break; - case StringIndexingConfig::TokenizerType::NONE: - [[fallthrough]]; - case StringIndexingConfig::TokenizerType::RFC822: - [[fallthrough]]; - case StringIndexingConfig::TokenizerType::URL: - [[fallthrough]]; - case StringIndexingConfig::TokenizerType::PLAIN: - std::string normalized_term = normalizer_.NormalizeTerm(token); - status = editor.BufferTerm(normalized_term.c_str()); - } - - if (!status.ok()) { - // We've encountered a failure. Bail out. We'll mark this doc as deleted - // and signal a failure to the client. - ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: " - << status.error_message(); - break; - } - } - if (!status.ok()) { - break; - } - // Add all the seen terms to the index with their term frequency. - status = editor.IndexAllBufferedTerms(); - if (!status.ok()) { - ICING_LOG(WARNING) << "Failed to add hits in lite index due to: " - << status.error_message(); - break; - } - } - - if (put_document_stats != nullptr) { - put_document_stats->set_index_latency_ms( - index_timer->GetElapsedMilliseconds()); - put_document_stats->mutable_tokenization_stats()->set_num_tokens_indexed( - num_tokens); - } - - // If we're either successful or we've hit resource exhausted, then attempt a - // merge. - if ((status.ok() || absl_ports::IsResourceExhausted(status)) && - index_->WantsMerge()) { - ICING_LOG(ERROR) << "Merging the index at docid " << document_id << "."; - - std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer(); - libtextclassifier3::Status merge_status = index_->Merge(); - - if (!merge_status.ok()) { - ICING_LOG(ERROR) << "Index merging failed. Clearing index."; - if (!index_->Reset().ok()) { - return absl_ports::InternalError(IcingStringUtil::StringPrintf( - "Unable to reset to clear index after merge failure. 
Merge " - "failure=%d:%s", - merge_status.error_code(), merge_status.error_message().c_str())); - } else { - return absl_ports::DataLossError(IcingStringUtil::StringPrintf( - "Forced to reset index after merge failure. Merge failure=%d:%s", - merge_status.error_code(), merge_status.error_message().c_str())); - } - } - - if (put_document_stats != nullptr) { - put_document_stats->set_index_merge_latency_ms( - merge_timer->GetElapsedMilliseconds()); - } + // TODO(b/259744228): set overall index latency. + for (auto& section_indexing_handler : section_indexing_handlers_) { + ICING_RETURN_IF_ERROR(section_indexing_handler->Handle( + tokenized_document, document_id, put_document_stats)); } - return status; + return libtextclassifier3::Status::OK; } } // namespace lib diff --git a/icing/index/index-processor.h b/icing/index/index-processor.h index b7ffdb5..45954c4 100644 --- a/icing/index/index-processor.h +++ b/icing/index/index-processor.h @@ -16,14 +16,15 @@ #define ICING_INDEX_INDEX_PROCESSOR_H_ #include <cstdint> -#include <string> +#include <memory> +#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/index/index.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/index/section-indexing-handler.h" #include "icing/proto/logging.pb.h" -#include "icing/schema/section-manager.h" #include "icing/store/document-id.h" -#include "icing/tokenization/token.h" #include "icing/transform/normalizer.h" #include "icing/util/tokenized-document.h" @@ -40,7 +41,8 @@ class IndexProcessor { // An IndexProcessor on success // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> Create( - const Normalizer* normalizer, Index* index, const Clock* clock); + const Normalizer* normalizer, Index* index, + NumericIndex<int64_t>* integer_index_, const Clock* clock); // Add tokenized document to the index, associated with document_id. 
If the // number of tokens in the document exceeds max_tokens_per_document, then only @@ -54,23 +56,21 @@ class IndexProcessor { // populated. // // Returns: - // INVALID_ARGUMENT if document_id is less than the document_id of a - // previously indexed document or tokenization fails. - // RESOURCE_EXHAUSTED if the index is full and can't add anymore content. - // DATA_LOSS if an attempt to merge the index fails and both indices are - // cleared as a result. - // NOT_FOUND if there is no definition for the document's schema type. - // INTERNAL_ERROR if any other errors occur + // - OK on success. + // - Any SectionIndexingHandler errors. libtextclassifier3::Status IndexDocument( const TokenizedDocument& tokenized_document, DocumentId document_id, PutDocumentStatsProto* put_document_stats = nullptr); private: - IndexProcessor(const Normalizer* normalizer, Index* index, const Clock* clock) - : normalizer_(*normalizer), index_(index), clock_(*clock) {} + explicit IndexProcessor(std::vector<std::unique_ptr<SectionIndexingHandler>>&& + section_indexing_handlers, + const Clock* clock) + : section_indexing_handlers_(std::move(section_indexing_handlers)), + clock_(*clock) {} - const Normalizer& normalizer_; - Index* const index_; + std::vector<std::unique_ptr<SectionIndexingHandler>> + section_indexing_handlers_; const Clock& clock_; }; diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc index 68c592c..6123f47 100644 --- a/icing/index/index-processor_benchmark.cc +++ b/icing/index/index-processor_benchmark.cc @@ -18,6 +18,8 @@ #include "icing/file/filesystem.h" #include "icing/index/index-processor.h" #include "icing/index/index.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" @@ -55,7 +57,8 @@ // $ adb push 
blaze-bin/icing/index/index-processor_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/index-processor_benchmark --benchmark_filter=all +// $ adb shell /data/local/tmp/index-processor_benchmark +// --benchmark_filter=all // --adb // Flag to tell the benchmark that it'll be run on an Android device via adb, @@ -183,6 +186,8 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + std::unique_ptr<NumericIndex<int64_t>> integer_index = + std::make_unique<DummyNumericIndex<int64_t>>(); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -191,7 +196,8 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) { std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<IndexProcessor> index_processor, - IndexProcessor::Create(normalizer.get(), index.get(), &clock)); + IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(), + &clock)); DocumentProto input_document = CreateDocumentWithOneProperty(state.range(0)); TokenizedDocument tokenized_document(std::move( TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), @@ -237,6 +243,8 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + std::unique_ptr<NumericIndex<int64_t>> integer_index = + std::make_unique<DummyNumericIndex<int64_t>>(); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -245,7 +253,8 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) { std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<IndexProcessor> index_processor, - IndexProcessor::Create(normalizer.get(), index.get(), &clock)); + IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(), + &clock)); DocumentProto input_document = CreateDocumentWithTenProperties(state.range(0)); @@ -293,6 +302,8 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + std::unique_ptr<NumericIndex<int64_t>> integer_index = + std::make_unique<DummyNumericIndex<int64_t>>(); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -301,7 +312,8 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) { std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<IndexProcessor> index_processor, - IndexProcessor::Create(normalizer.get(), index.get(), &clock)); + IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(), + &clock)); DocumentProto input_document = CreateDocumentWithDiacriticLetters(state.range(0)); @@ -349,6 +361,8 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + std::unique_ptr<NumericIndex<int64_t>> integer_index = + std::make_unique<DummyNumericIndex<int64_t>>(); language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -357,7 +371,8 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) { std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<IndexProcessor> index_processor, - IndexProcessor::Create(normalizer.get(), 
index.get(), &clock)); + IndexProcessor::Create(normalizer.get(), index.get(), integer_index.get(), + &clock)); DocumentProto input_document = CreateDocumentWithHiragana(state.range(0)); TokenizedDocument tokenized_document(std::move( diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc index 3c848d3..b83d33c 100644 --- a/icing/index/index-processor_test.cc +++ b/icing/index/index-processor_test.cc @@ -34,6 +34,8 @@ #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/index/term-property-id.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/legacy/index/icing-mock-filesystem.h" @@ -44,7 +46,6 @@ #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" -#include "icing/schema/section-manager.h" #include "icing/schema/section.h" #include "icing/store/document-id.h" #include "icing/testing/common-matchers.h" @@ -81,49 +82,57 @@ constexpr std::string_view kIpsumText = "vehicula posuere vitae, convallis eu lorem. Donec semper augue eu nibh " "placerat semper."; -// type and property names of FakeType +// schema types constexpr std::string_view kFakeType = "FakeType"; +constexpr std::string_view kNestedType = "NestedType"; + +// Indexable properties and section Id. Section Id is determined by the +// lexicographical order of indexable property path. 
constexpr std::string_view kExactProperty = "exact"; +constexpr std::string_view kIndexableIntegerProperty = "indexableInteger"; constexpr std::string_view kPrefixedProperty = "prefixed"; -constexpr std::string_view kUnindexedProperty1 = "unindexed1"; -constexpr std::string_view kUnindexedProperty2 = "unindexed2"; constexpr std::string_view kRepeatedProperty = "repeated"; -constexpr std::string_view kSubProperty = "submessage"; -constexpr std::string_view kNestedType = "NestedType"; -constexpr std::string_view kNestedProperty = "nested"; -constexpr std::string_view kExactVerbatimProperty = "verbatimExact"; -constexpr std::string_view kPrefixedVerbatimProperty = "verbatimPrefixed"; constexpr std::string_view kRfc822Property = "rfc822"; +constexpr std::string_view kSubProperty = "submessage"; // submessage.nested +constexpr std::string_view kNestedProperty = "nested"; // submessage.nested // TODO (b/246964044): remove ifdef guard when url-tokenizer is ready for export // to Android. #ifdef ENABLE_URL_TOKENIZER -constexpr std::string_view kExactUrlProperty = "urlExact"; -constexpr std::string_view kPrefixedUrlProperty = "urlPrefixed"; +constexpr std::string_view kUrlExactProperty = "urlExact"; +constexpr std::string_view kUrlPrefixedProperty = "urlPrefixed"; #endif // ENABLE_URL_TOKENIZER - -constexpr DocumentId kDocumentId0 = 0; -constexpr DocumentId kDocumentId1 = 1; +constexpr std::string_view kVerbatimExactProperty = "verbatimExact"; +constexpr std::string_view kVerbatimPrefixedProperty = "verbatimPrefixed"; constexpr SectionId kExactSectionId = 0; -constexpr SectionId kPrefixedSectionId = 1; -constexpr SectionId kRepeatedSectionId = 2; -constexpr SectionId kRfc822SectionId = 3; -constexpr SectionId kNestedSectionId = 4; +constexpr SectionId kIndexableIntegerSectionId = 1; +constexpr SectionId kPrefixedSectionId = 2; +constexpr SectionId kRepeatedSectionId = 3; +constexpr SectionId kRfc822SectionId = 4; +constexpr SectionId kNestedSectionId = 5; // 
submessage.nested #ifdef ENABLE_URL_TOKENIZER -constexpr SectionId kUrlExactSectionId = 5; -constexpr SectionId kUrlPrefixedSectionId = 6; -constexpr SectionId kExactVerbatimSectionId = 7; -constexpr SectionId kPrefixedVerbatimSectionId = 8; -#else // !ENABLE_URL_TOKENIZER -constexpr SectionId kExactVerbatimSectionId = 5; -constexpr SectionId kPrefixedVerbatimSectionId = 6; +constexpr SectionId kUrlExactSectionId = 6; +constexpr SectionId kUrlPrefixedSectionId = 7; +constexpr SectionId kVerbatimExactSectionId = 8; +constexpr SectionId kVerbatimPrefixedSectionId = 9; +#else // !ENABLE_URL_TOKENIZER +constexpr SectionId kVerbatimExactSectionId = 6; +constexpr SectionId kVerbatimPrefixedSectionId = 7; #endif // ENABLE_URL_TOKENIZER +// Other non-indexable properties. +constexpr std::string_view kUnindexedProperty1 = "unindexed1"; +constexpr std::string_view kUnindexedProperty2 = "unindexed2"; + +constexpr DocumentId kDocumentId0 = 0; +constexpr DocumentId kDocumentId1 = 1; + using Cardinality = PropertyConfigProto::Cardinality; using DataType = PropertyConfigProto::DataType; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsEmpty; +using ::testing::SizeIs; using ::testing::Test; #ifdef ENABLE_URL_TOKENIZER @@ -146,6 +155,8 @@ class IndexProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( index_, Index::Create(options, &filesystem_, &icing_filesystem_)); + integer_index_ = std::make_unique<DummyNumericIndex<int64_t>>(); + language_segmenter_factory::SegmenterOptions segmenter_options(ULOC_US); ICING_ASSERT_OK_AND_ASSIGN( lang_segmenter_, @@ -191,12 +202,12 @@ class IndexProcessorTest : public Test { TOKENIZER_PLAIN) .SetCardinality(CARDINALITY_REPEATED)) .AddProperty(PropertyConfigBuilder() - .SetName(kExactVerbatimProperty) + .SetName(kVerbatimExactProperty) .SetDataTypeString(TERM_MATCH_EXACT, TOKENIZER_VERBATIM) .SetCardinality(CARDINALITY_REPEATED)) .AddProperty(PropertyConfigBuilder() - .SetName(kPrefixedVerbatimProperty) + 
.SetName(kVerbatimPrefixedProperty) .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_VERBATIM) .SetCardinality(CARDINALITY_REPEATED)) @@ -208,15 +219,19 @@ class IndexProcessorTest : public Test { #ifdef ENABLE_URL_TOKENIZER .AddProperty( PropertyConfigBuilder() - .SetName(kExactUrlProperty) - .SetDataTypeString(MATCH_EXACT, TOKENIZER_URL) + .SetName(kUrlExactProperty) + .SetDataTypeString(TERM_MATCH_EXACT, TOKENIZER_URL) .SetCardinality(CARDINALITY_REPEATED)) .AddProperty( PropertyConfigBuilder() - .SetName(kPrefixedUrlProperty) - .SetDataTypeString(MATCH_PREFIX, TOKENIZER_URL) + .SetName(kUrlPrefixedProperty) + .SetDataTypeString(TERM_MATCH_PREFIX, TOKENIZER_URL) .SetCardinality(CARDINALITY_REPEATED)) #endif // ENABLE_URL_TOKENIZER + .AddProperty(PropertyConfigBuilder() + .SetName(kIndexableIntegerProperty) + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_REPEATED)) .AddProperty( PropertyConfigBuilder() .SetName(kSubProperty) @@ -236,7 +251,8 @@ class IndexProcessorTest : public Test { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), + integer_index_.get(), &fake_clock_)); mock_icing_filesystem_ = std::make_unique<IcingMockFilesystem>(); } @@ -254,6 +270,7 @@ class IndexProcessorTest : public Test { std::unique_ptr<LanguageSegmenter> lang_segmenter_; std::unique_ptr<Normalizer> normalizer_; std::unique_ptr<Index> index_; + std::unique_ptr<NumericIndex<int64_t>> integer_index_; std::unique_ptr<SchemaStore> schema_store_; std::unique_ptr<IndexProcessor> index_processor_; }; @@ -282,11 +299,11 @@ std::vector<DocHitInfoTermFrequencyPair> GetHitsWithTermFrequency( TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) { EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(), - &fake_clock_), + integer_index_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); 
EXPECT_THAT(IndexProcessor::Create(normalizer_.get(), /*index=*/nullptr, - &fake_clock_), + integer_index_.get(), &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -540,7 +557,8 @@ TEST_F(IndexProcessorTest, TooLongTokens) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer.get(), index_.get(), &fake_clock_)); + IndexProcessor::Create(normalizer.get(), index_.get(), + integer_index_.get(), &fake_clock_)); DocumentProto document = DocumentBuilder() @@ -773,7 +791,8 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), + integer_index_.get(), &fake_clock_)); DocumentId doc_id = 0; // Have determined experimentally that indexing 3373 documents with this text // will cause the LiteIndex to fill up. Further indexing will fail unless the @@ -829,7 +848,8 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) { ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), + integer_index_.get(), &fake_clock_)); // 3. Index one document. This should fit in the LiteIndex without requiring a // merge. 
@@ -856,14 +876,14 @@ TEST_F(IndexProcessorTest, ExactVerbatimProperty) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactVerbatimProperty), + .AddStringProperty(std::string(kVerbatimExactProperty), "Hello, world!") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 1); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -876,7 +896,7 @@ TEST_F(IndexProcessorTest, ExactVerbatimProperty) { std::vector<DocHitInfoTermFrequencyPair> hits = GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ - {kExactVerbatimSectionId, 1}}; + {kVerbatimExactSectionId, 1}}; EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency( kDocumentId0, expectedMap))); @@ -887,14 +907,14 @@ TEST_F(IndexProcessorTest, PrefixVerbatimProperty) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kPrefixedVerbatimProperty), + .AddStringProperty(std::string(kVerbatimPrefixedProperty), "Hello, world!") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 1); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -908,7 +928,7 @@ TEST_F(IndexProcessorTest, PrefixVerbatimProperty) { std::vector<DocHitInfoTermFrequencyPair> hits = GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ - {kPrefixedVerbatimSectionId, 1}}; + {kVerbatimPrefixedSectionId, 1}}; 
EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency( kDocumentId0, expectedMap))); @@ -919,14 +939,14 @@ TEST_F(IndexProcessorTest, VerbatimPropertyDoesntMatchSubToken) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kPrefixedVerbatimProperty), + .AddStringProperty(std::string(kVerbatimPrefixedProperty), "Hello, world!") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 1); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(1)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -955,7 +975,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyExact) { TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1000,7 +1020,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyExactShouldNotReturnPrefix) { TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1028,7 +1048,7 @@ TEST_F(IndexProcessorTest, Rfc822PropertyPrefix) { TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1069,7 +1089,7 @@ 
TEST_F(IndexProcessorTest, Rfc822PropertyNoMatch) { TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1091,14 +1111,14 @@ TEST_F(IndexProcessorTest, ExactUrlProperty) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactUrlProperty), + .AddStringProperty(std::string(kUrlExactProperty), "http://www.google.com") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1144,14 +1164,14 @@ TEST_F(IndexProcessorTest, ExactUrlPropertyDoesNotMatchPrefix) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactUrlProperty), + .AddStringProperty(std::string(kUrlExactProperty), "https://mail.google.com/calendar/render") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 8); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(8)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1182,14 +1202,14 @@ TEST_F(IndexProcessorTest, PrefixUrlProperty) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kPrefixedUrlProperty), + .AddStringProperty(std::string(kUrlPrefixedProperty), "http://www.google.com") 
.Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 7); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(7)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1229,14 +1249,14 @@ TEST_F(IndexProcessorTest, PrefixUrlPropertyNoMatch) { DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kPrefixedUrlProperty), + .AddStringProperty(std::string(kUrlPrefixedProperty), "https://mail.google.com/calendar/render") .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); - EXPECT_THAT(tokenized_document.num_tokens(), 8); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(8)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), IsOk()); @@ -1270,6 +1290,61 @@ TEST_F(IndexProcessorTest, PrefixUrlPropertyNoMatch) { } #endif // ENABLE_URL_TOKENIZER +TEST_F(IndexProcessorTest, IndexableIntegerProperty) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddInt64Property(std::string(kIndexableIntegerProperty), 1, 2, 3, 4, + 5) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + // Expected to have 1 integer section. 
+ EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(1)); + + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<DocHitInfoIterator> itr, + integer_index_->GetIterator(kIndexableIntegerProperty, /*key_lower=*/1, + /*key_upper=*/5)); + + EXPECT_THAT( + GetHits(std::move(itr)), + ElementsAre(EqualsDocHitInfo( + kDocumentId0, std::vector<SectionId>{kIndexableIntegerSectionId}))); +} + +TEST_F(IndexProcessorTest, IndexableIntegerPropertyNoMatch) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddInt64Property(std::string(kIndexableIntegerProperty), 1, 2, 3, 4, + 5) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + // Expected to have 1 integer section. + EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(1)); + + EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<DocHitInfoIterator> itr, + integer_index_->GetIterator(kIndexableIntegerProperty, /*key_lower=*/-1, + /*key_upper=*/0)); + + EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); +} + } // namespace } // namespace lib diff --git a/icing/index/integer-section-indexing-handler.cc b/icing/index/integer-section-indexing-handler.cc new file mode 100644 index 0000000..a49b9f3 --- /dev/null +++ b/icing/index/integer-section-indexing-handler.cc @@ -0,0 +1,70 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/integer-section-indexing-handler.h" + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/schema/section-manager.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/util/logging.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +libtextclassifier3::Status IntegerSectionIndexingHandler::Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + PutDocumentStatsProto* put_document_stats) { + // TODO(b/259744228): + // 1. Resolve last_added_document_id for index rebuilding before rollout + // 2. Set integer indexing latency and other stats + + libtextclassifier3::Status status; + // We have to add integer sections into integer index in reverse order because + // sections are sorted by SectionId in ascending order, but BasicHit should be + // added in descending order of SectionId (posting list requirement). 
+ for (auto riter = tokenized_document.integer_sections().rbegin(); + riter != tokenized_document.integer_sections().rend(); ++riter) { + const Section<int64_t>& section = *riter; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = integer_index_.Edit( + section.metadata.path, document_id, section.metadata.id); + + for (int64_t key : section.content) { + status = editor->BufferKey(key); + if (!status.ok()) { + ICING_LOG(WARNING) + << "Failed to buffer keys into integer index due to: " + << status.error_message(); + break; + } + } + if (!status.ok()) { + break; + } + + // Add all the seen keys to the integer index. + status = editor->IndexAllBufferedKeys(); + if (!status.ok()) { + ICING_LOG(WARNING) << "Failed to add keys into integer index due to: " + << status.error_message(); + break; + } + } + + return status; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/integer-section-indexing-handler.h b/icing/index/integer-section-indexing-handler.h new file mode 100644 index 0000000..dd0e46c --- /dev/null +++ b/icing/index/integer-section-indexing-handler.h @@ -0,0 +1,55 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_ +#define ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_ + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/index/section-indexing-handler.h" +#include "icing/store/document-id.h" +#include "icing/util/clock.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +class IntegerSectionIndexingHandler : public SectionIndexingHandler { + public: + explicit IntegerSectionIndexingHandler(const Clock* clock, + NumericIndex<int64_t>* integer_index) + : SectionIndexingHandler(clock), integer_index_(*integer_index) {} + + ~IntegerSectionIndexingHandler() override = default; + + // TODO(b/259744228): update this documentation after resolving + // last_added_document_id problem. + // Handles the integer indexing process: add hits into the integer index for + // all contents in tokenized_document.integer_sections. + // + /// Returns: + // - OK on success + // - Any NumericIndex<int64_t>::Editor errors. 
+ libtextclassifier3::Status Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + PutDocumentStatsProto* put_document_stats) override; + + private: + NumericIndex<int64_t>& integer_index_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_INTEGER_SECTION_INDEXING_HANDLER_H_ diff --git a/icing/index/iterator/doc-hit-info-iterator-test-util.h b/icing/index/iterator/doc-hit-info-iterator-test-util.h index ed6db23..fe3a4b9 100644 --- a/icing/index/iterator/doc-hit-info-iterator-test-util.h +++ b/icing/index/iterator/doc-hit-info-iterator-test-util.h @@ -15,6 +15,7 @@ #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_ #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_ +#include <cinttypes> #include <string> #include <utility> #include <vector> diff --git a/icing/index/main/doc-hit-info-iterator-term-main.cc b/icing/index/main/doc-hit-info-iterator-term-main.cc index 098a450..f06124a 100644 --- a/icing/index/main/doc-hit-info-iterator-term-main.cc +++ b/icing/index/main/doc-hit-info-iterator-term-main.cc @@ -22,7 +22,7 @@ #include "icing/absl_ports/str_cat.h" #include "icing/file/posting_list/posting-list-identifier.h" #include "icing/index/hit/doc-hit-info.h" -#include "icing/index/main/posting-list-accessor.h" +#include "icing/index/main/posting-list-hit-accessor.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/schema/section.h" #include "icing/store/document-id.h" diff --git a/icing/index/main/doc-hit-info-iterator-term-main.h b/icing/index/main/doc-hit-info-iterator-term-main.h index c1b289f..6a21dc3 100644 --- a/icing/index/main/doc-hit-info-iterator-term-main.h +++ b/icing/index/main/doc-hit-info-iterator-term-main.h @@ -23,7 +23,7 @@ #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/main/main-index.h" -#include "icing/index/main/posting-list-accessor.h" +#include 
"icing/index/main/posting-list-hit-accessor.h" #include "icing/schema/section.h" namespace icing { @@ -91,7 +91,7 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { const std::string term_; // The accessor of the posting list chain for the requested term. - std::unique_ptr<PostingListAccessor> posting_list_accessor_; + std::unique_ptr<PostingListHitAccessor> posting_list_accessor_; MainIndex* main_index_; // Stores hits retrieved from the index. This may only be a subset of the hits diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 1c61bfa..fd1630a 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -160,19 +160,16 @@ IndexStorageInfoProto MainIndex::GetStorageInfo( return storage_info; } -libtextclassifier3::StatusOr<std::unique_ptr<PostingListAccessor>> +libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> MainIndex::GetAccessorForExactTerm(const std::string& term) { PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; if (!main_lexicon_->Find(term.c_str(), &posting_list_id)) { return absl_ports::NotFoundError(IcingStringUtil::StringPrintf( "Term %s is not present in main lexicon.", term.c_str())); } - ICING_ASSIGN_OR_RETURN( - PostingListAccessor accessor, - PostingListAccessor::CreateFromExisting( - flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), - posting_list_id)); - return std::make_unique<PostingListAccessor>(std::move(accessor)); + return PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), + posting_list_id); } libtextclassifier3::StatusOr<MainIndex::GetPrefixAccessorResult> @@ -202,13 +199,11 @@ MainIndex::GetAccessorForPrefixTerm(const std::string& prefix) { PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; memcpy(&posting_list_id, main_itr.GetValue(), sizeof(posting_list_id)); ICING_ASSIGN_OR_RETURN( - PostingListAccessor 
pl_accessor, - PostingListAccessor::CreateFromExisting( + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), posting_list_id)); - GetPrefixAccessorResult result = { - std::make_unique<PostingListAccessor>(std::move(pl_accessor)), exact}; - return result; + return GetPrefixAccessorResult(std::move(pl_accessor), exact); } // TODO(tjbarron): Implement a method PropertyReadersAll.HasAnyProperty(). @@ -245,12 +240,12 @@ MainIndex::FindTermsByPrefix( PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; memcpy(&posting_list_id, term_iterator.GetValue(), sizeof(posting_list_id)); ICING_ASSIGN_OR_RETURN( - PostingListAccessor pl_accessor, - PostingListAccessor::CreateFromExisting( + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), posting_list_id)); ICING_ASSIGN_OR_RETURN(std::vector<Hit> hits, - pl_accessor.GetNextHitsBatch()); + pl_accessor->GetNextHitsBatch()); while (!hits.empty()) { for (const Hit& hit : hits) { // Check whether this Hit is desired. @@ -297,7 +292,7 @@ MainIndex::FindTermsByPrefix( // The term is desired and no need to be scored. 
break; } - ICING_ASSIGN_OR_RETURN(hits, pl_accessor.GetNextHitsBatch()); + ICING_ASSIGN_OR_RETURN(hits, pl_accessor->GetNextHitsBatch()); } if (score > 0) { term_metadata_list.push_back(TermMetadata(term_iterator.GetKey(), score)); @@ -559,14 +554,14 @@ libtextclassifier3::Status MainIndex::AddHits( memcpy(&backfill_posting_list_id, main_lexicon_->GetValueAtIndex(other_tvi_main_tvi_pair.second), sizeof(backfill_posting_list_id)); - ICING_ASSIGN_OR_RETURN( - PostingListAccessor hit_accum, - PostingListAccessor::Create(flash_index_storage_.get(), - posting_list_used_hit_serializer_.get())); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<PostingListHitAccessor> hit_accum, + PostingListHitAccessor::Create( + flash_index_storage_.get(), + posting_list_used_hit_serializer_.get())); ICING_RETURN_IF_ERROR( - AddPrefixBackfillHits(backfill_posting_list_id, &hit_accum)); + AddPrefixBackfillHits(backfill_posting_list_id, hit_accum.get())); PostingListAccessor::FinalizeResult result = - PostingListAccessor::Finalize(std::move(hit_accum)); + std::move(*hit_accum).Finalize(); if (result.id.is_valid()) { main_lexicon_->SetValueAtIndex(other_tvi_main_tvi_pair.first, &result.id); } @@ -578,12 +573,12 @@ libtextclassifier3::Status MainIndex::AddHits( libtextclassifier3::Status MainIndex::AddHitsForTerm( uint32_t tvi, PostingListIdentifier backfill_posting_list_id, const TermIdHitPair* hit_elements, size_t len) { - // 1. Create a PostingListAccessor - either from the pre-existing block, if + // 1. Create a PostingListHitAccessor - either from the pre-existing block, if // one exists, or from scratch. 
PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; memcpy(&posting_list_id, main_lexicon_->GetValueAtIndex(tvi), sizeof(posting_list_id)); - std::unique_ptr<PostingListAccessor> pl_accessor; + std::unique_ptr<PostingListHitAccessor> pl_accessor; if (posting_list_id.is_valid()) { if (posting_list_id.block_index() >= flash_index_storage_->num_blocks()) { ICING_LOG(ERROR) << "Index dropped hits. Invalid block index " @@ -597,18 +592,16 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm( "Valid posting list has an invalid block index!"); } ICING_ASSIGN_OR_RETURN( - PostingListAccessor tmp, - PostingListAccessor::CreateFromExisting( + pl_accessor, + PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), posting_list_id)); - pl_accessor = std::make_unique<PostingListAccessor>(std::move(tmp)); } else { // New posting list. - ICING_ASSIGN_OR_RETURN( - PostingListAccessor tmp, - PostingListAccessor::Create(flash_index_storage_.get(), - posting_list_used_hit_serializer_.get())); - pl_accessor = std::make_unique<PostingListAccessor>(std::move(tmp)); + ICING_ASSIGN_OR_RETURN(pl_accessor, + PostingListHitAccessor::Create( + flash_index_storage_.get(), + posting_list_used_hit_serializer_.get())); } // 2. Backfill any hits if necessary. @@ -625,7 +618,7 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm( // 4. Finalize this posting list and put its identifier in the lexicon. 
PostingListAccessor::FinalizeResult result = - PostingListAccessor::Finalize(std::move(*pl_accessor)); + std::move(*pl_accessor).Finalize(); if (result.id.is_valid()) { main_lexicon_->SetValueAtIndex(tvi, &result.id); } @@ -634,18 +627,18 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm( libtextclassifier3::Status MainIndex::AddPrefixBackfillHits( PostingListIdentifier backfill_posting_list_id, - PostingListAccessor* hit_accum) { + PostingListHitAccessor* hit_accum) { ICING_ASSIGN_OR_RETURN( - PostingListAccessor backfill_accessor, - PostingListAccessor::CreateFromExisting( + std::unique_ptr<PostingListHitAccessor> backfill_accessor, + PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), backfill_posting_list_id)); std::vector<Hit> backfill_hits; ICING_ASSIGN_OR_RETURN(std::vector<Hit> tmp, - backfill_accessor.GetNextHitsBatch()); + backfill_accessor->GetNextHitsBatch()); while (!tmp.empty()) { std::copy(tmp.begin(), tmp.end(), std::back_inserter(backfill_hits)); - ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor.GetNextHitsBatch()); + ICING_ASSIGN_OR_RETURN(tmp, backfill_accessor->GetNextHitsBatch()); } Hit last_added_hit; @@ -738,7 +731,7 @@ libtextclassifier3::Status MainIndex::Optimize( libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits( const std::vector<DocumentId>& document_id_old_to_new, const char* term, - PostingListAccessor& old_pl_accessor, MainIndex* new_index) { + PostingListHitAccessor& old_pl_accessor, MainIndex* new_index) { std::vector<Hit> new_hits; bool has_no_exact_hits = true; bool has_hits_in_prefix_section = false; @@ -776,15 +769,14 @@ libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits( } ICING_ASSIGN_OR_RETURN( - PostingListAccessor hit_accum, - PostingListAccessor::Create( + std::unique_ptr<PostingListHitAccessor> hit_accum, + PostingListHitAccessor::Create( new_index->flash_index_storage_.get(), 
new_index->posting_list_used_hit_serializer_.get())); for (auto itr = new_hits.rbegin(); itr != new_hits.rend(); ++itr) { - ICING_RETURN_IF_ERROR(hit_accum.PrependHit(*itr)); + ICING_RETURN_IF_ERROR(hit_accum->PrependHit(*itr)); } - PostingListAccessor::FinalizeResult result = - PostingListAccessor::Finalize(std::move(hit_accum)); + PostingListAccessor::FinalizeResult result = std::move(*hit_accum).Finalize(); if (!result.id.is_valid()) { return absl_ports::InternalError( absl_ports::StrCat("Failed to add translated hits for term: ", term)); @@ -826,14 +818,14 @@ libtextclassifier3::Status MainIndex::TransferIndex( continue; } ICING_ASSIGN_OR_RETURN( - PostingListAccessor pl_accessor, - PostingListAccessor::CreateFromExisting( + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::CreateFromExisting( flash_index_storage_.get(), posting_list_used_hit_serializer_.get(), posting_list_id)); ICING_ASSIGN_OR_RETURN( DocumentId curr_largest_document_id, TransferAndAddHits(document_id_old_to_new, term_itr.GetKey(), - pl_accessor, new_index)); + *pl_accessor, new_index)); if (curr_largest_document_id == kInvalidDocumentId) { continue; } diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index e257a77..70ae6f6 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -22,7 +22,7 @@ #include "icing/file/filesystem.h" #include "icing/file/posting_list/flash-index-storage.h" #include "icing/index/lite/term-id-hit-pair.h" -#include "icing/index/main/posting-list-accessor.h" +#include "icing/index/main/posting-list-hit-accessor.h" #include "icing/index/main/posting-list-used-hit-serializer.h" #include "icing/index/term-id-codec.h" #include "icing/index/term-metadata.h" @@ -48,27 +48,31 @@ class MainIndex { const std::string& index_directory, const Filesystem* filesystem, const IcingFilesystem* icing_filesystem); - // Get a PostingListAccessor that holds the posting list chain for 'term'. 
+ // Get a PostingListHitAccessor that holds the posting list chain for 'term'. // // RETURNS: - // - On success, a valid PostingListAccessor + // - On success, a valid PostingListHitAccessor // - NOT_FOUND if term is not present in the main index. - libtextclassifier3::StatusOr<std::unique_ptr<PostingListAccessor>> + libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> GetAccessorForExactTerm(const std::string& term); - // Get a PostingListAccessor for 'prefix'. + // Get a PostingListHitAccessor for 'prefix'. // // RETURNS: - // - On success, a result containing a valid PostingListAccessor. + // - On success, a result containing a valid PostingListHitAccessor. // - NOT_FOUND if neither 'prefix' nor any terms for which 'prefix' is a // prefix are present in the main index. struct GetPrefixAccessorResult { - // A PostingListAccessor that holds the posting list chain for the term + // A PostingListHitAccessor that holds the posting list chain for the term // that best represents 'prefix' in the main index. - std::unique_ptr<PostingListAccessor> accessor; + std::unique_ptr<PostingListHitAccessor> accessor; // True if the returned posting list chain is for 'prefix' or false if the // returned posting list chain is for a term for which 'prefix' is a prefix. bool exact; + + explicit GetPrefixAccessorResult( + std::unique_ptr<PostingListHitAccessor> accessor_in, bool exact_in) + : accessor(std::move(accessor_in)), exact(exact_in) {} }; libtextclassifier3::StatusOr<GetPrefixAccessorResult> GetAccessorForPrefixTerm(const std::string& prefix); @@ -302,7 +306,7 @@ class MainIndex { // posting list. libtextclassifier3::Status AddPrefixBackfillHits( PostingListIdentifier backfill_posting_list_id, - PostingListAccessor* hit_accum); + PostingListHitAccessor* hit_accum); // Transfer hits from old_pl_accessor to new_index for term. 
// @@ -311,7 +315,7 @@ class MainIndex { // INTERNAL_ERROR on IO error static libtextclassifier3::StatusOr<DocumentId> TransferAndAddHits( const std::vector<DocumentId>& document_id_old_to_new, const char* term, - PostingListAccessor& old_pl_accessor, MainIndex* new_index); + PostingListHitAccessor& old_pl_accessor, MainIndex* new_index); // Transfer hits from the current main index to new_index. // diff --git a/icing/index/main/posting-list-accessor.cc b/icing/index/main/posting-list-accessor.cc deleted file mode 100644 index 06ab0a1..0000000 --- a/icing/index/main/posting-list-accessor.cc +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "icing/index/main/posting-list-accessor.h" - -#include <cstdint> -#include <memory> -#include <vector> - -#include "icing/absl_ports/canonical_errors.h" -#include "icing/file/posting_list/flash-index-storage.h" -#include "icing/file/posting_list/index-block.h" -#include "icing/file/posting_list/posting-list-identifier.h" -#include "icing/file/posting_list/posting-list-used.h" -#include "icing/index/main/posting-list-used-hit-serializer.h" -#include "icing/util/status-macros.h" - -namespace icing { -namespace lib { - -libtextclassifier3::StatusOr<PostingListAccessor> PostingListAccessor::Create( - FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer) { - uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( - storage->block_size(), serializer->GetDataTypeBytes()); - std::unique_ptr<uint8_t[]> posting_list_buffer_array = - std::make_unique<uint8_t[]>(max_posting_list_bytes); - ICING_ASSIGN_OR_RETURN( - PostingListUsed posting_list_buffer, - PostingListUsed::CreateFromUnitializedRegion( - serializer, posting_list_buffer_array.get(), max_posting_list_bytes)); - return PostingListAccessor(storage, serializer, - std::move(posting_list_buffer_array), - std::move(posting_list_buffer)); -} - -libtextclassifier3::StatusOr<PostingListAccessor> -PostingListAccessor::CreateFromExisting( - FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer, - PostingListIdentifier existing_posting_list_id) { - // Our posting_list_buffer_ will start as empty. - ICING_ASSIGN_OR_RETURN(PostingListAccessor pl_accessor, - Create(storage, serializer)); - ICING_ASSIGN_OR_RETURN(PostingListHolder holder, - storage->GetPostingList(existing_posting_list_id)); - pl_accessor.preexisting_posting_list_ = - std::make_unique<PostingListHolder>(std::move(holder)); - return pl_accessor; -} - -// Returns the next batch of hits for the provided posting list. 
-libtextclassifier3::StatusOr<std::vector<Hit>> -PostingListAccessor::GetNextHitsBatch() { - if (preexisting_posting_list_ == nullptr) { - if (has_reached_posting_list_chain_end_) { - return std::vector<Hit>(); - } - return absl_ports::FailedPreconditionError( - "Cannot retrieve hits from a PostingListAccessor that was not created " - "from a preexisting posting list."); - } - ICING_ASSIGN_OR_RETURN( - std::vector<Hit> batch, - serializer_->GetHits(&preexisting_posting_list_->posting_list)); - uint32_t next_block_index; - // Posting lists will only be chained when they are max-sized, in which case - // block.next_block_index() will point to the next block for the next posting - // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be - // used to point to the next free list block, which is not relevant here. - if (preexisting_posting_list_->block.max_num_posting_lists() == 1) { - next_block_index = preexisting_posting_list_->block.next_block_index(); - } else { - next_block_index = kInvalidBlockIndex; - } - if (next_block_index != kInvalidBlockIndex) { - PostingListIdentifier next_posting_list_id( - next_block_index, /*posting_list_index=*/0, - preexisting_posting_list_->block.posting_list_index_bits()); - ICING_ASSIGN_OR_RETURN(PostingListHolder holder, - storage_->GetPostingList(next_posting_list_id)); - preexisting_posting_list_ = - std::make_unique<PostingListHolder>(std::move(holder)); - } else { - has_reached_posting_list_chain_end_ = true; - preexisting_posting_list_.reset(); - } - return batch; -} - -libtextclassifier3::Status PostingListAccessor::PrependHit(const Hit &hit) { - PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr) - ? preexisting_posting_list_->posting_list - : posting_list_buffer_; - libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit); - if (!absl_ports::IsResourceExhausted(status)) { - return status; - } - // There is no more room to add hits to this current posting list! 
Therefore, - // we need to either move those hits to a larger posting list or flush this - // posting list and create another max-sized posting list in the chain. - if (preexisting_posting_list_ != nullptr) { - FlushPreexistingPostingList(); - } else { - ICING_RETURN_IF_ERROR(FlushInMemoryPostingList()); - } - - // Re-add hit. Should always fit since we just cleared posting_list_buffer_. - // It's fine to explicitly reference posting_list_buffer_ here because there's - // no way of reaching this line while preexisting_posting_list_ is still in - // use. - return serializer_->PrependHit(&posting_list_buffer_, hit); -} - -void PostingListAccessor::FlushPreexistingPostingList() { - if (preexisting_posting_list_->block.max_num_posting_lists() == 1) { - // If this is a max-sized posting list, then just keep track of the id for - // chaining. It'll be flushed to disk when preexisting_posting_list_ is - // destructed. - prev_block_identifier_ = preexisting_posting_list_->id; - } else { - // If this is NOT a max-sized posting list, then our hits have outgrown this - // particular posting list. Move the hits into the in-memory posting list - // and free this posting list. - // - // Move will always succeed since posting_list_buffer_ is max_pl_bytes. - serializer_->MoveFrom(/*dst=*/&posting_list_buffer_, - /*src=*/&preexisting_posting_list_->posting_list); - - // Now that all the contents of this posting list have been copied, there's - // no more use for it. Make it available to be used for another posting - // list. - storage_->FreePostingList(std::move(*preexisting_posting_list_)); - } - preexisting_posting_list_.reset(); -} - -libtextclassifier3::Status PostingListAccessor::FlushInMemoryPostingList() { - // We exceeded max_pl_bytes(). Need to flush posting_list_buffer_ and update - // the chain. 
- uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( - storage_->block_size(), serializer_->GetDataTypeBytes()); - ICING_ASSIGN_OR_RETURN(PostingListHolder holder, - storage_->AllocatePostingList(max_posting_list_bytes)); - holder.block.set_next_block_index(prev_block_identifier_.block_index()); - prev_block_identifier_ = holder.id; - return serializer_->MoveFrom(/*dst=*/&holder.posting_list, - /*src=*/&posting_list_buffer_); -} - -PostingListAccessor::FinalizeResult PostingListAccessor::Finalize( - PostingListAccessor accessor) { - if (accessor.preexisting_posting_list_ != nullptr) { - // Our hits are already in an existing posting list. Nothing else to do, but - // return its id. - FinalizeResult result = {libtextclassifier3::Status::OK, - accessor.preexisting_posting_list_->id}; - return result; - } - if (accessor.serializer_->GetBytesUsed(&accessor.posting_list_buffer_) <= 0) { - FinalizeResult result = {absl_ports::InvalidArgumentError( - "Can't finalize an empty PostingListAccessor. " - "There's nothing to Finalize!"), - PostingListIdentifier::kInvalid}; - return result; - } - uint32_t posting_list_bytes = - accessor.serializer_->GetMinPostingListSizeToFit( - &accessor.posting_list_buffer_); - if (accessor.prev_block_identifier_.is_valid()) { - posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( - accessor.storage_->block_size(), - accessor.serializer_->GetDataTypeBytes()); - } - auto holder_or = accessor.storage_->AllocatePostingList(posting_list_bytes); - if (!holder_or.ok()) { - FinalizeResult result = {holder_or.status(), - accessor.prev_block_identifier_}; - return result; - } - PostingListHolder holder = std::move(holder_or).ValueOrDie(); - if (accessor.prev_block_identifier_.is_valid()) { - holder.block.set_next_block_index( - accessor.prev_block_identifier_.block_index()); - } - - // Move to allocated area. This should never actually return an error. 
We know - // that editor.posting_list() is valid because it wouldn't have successfully - // returned by AllocatePostingList if it wasn't. We know posting_list_buffer_ - // is valid because we created it in-memory. And finally, we know that the - // hits from posting_list_buffer_ will fit in editor.posting_list() because we - // requested it be at at least posting_list_bytes large. - auto status = - accessor.serializer_->MoveFrom(/*dst=*/&holder.posting_list, - /*src=*/&accessor.posting_list_buffer_); - if (!status.ok()) { - FinalizeResult result = {std::move(status), - accessor.prev_block_identifier_}; - return result; - } - FinalizeResult result = {libtextclassifier3::Status::OK, holder.id}; - return result; -} - -} // namespace lib -} // namespace icing diff --git a/icing/index/main/posting-list-hit-accessor.cc b/icing/index/main/posting-list-hit-accessor.cc new file mode 100644 index 0000000..30b2410 --- /dev/null +++ b/icing/index/main/posting-list-hit-accessor.cc @@ -0,0 +1,126 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/main/posting-list-hit-accessor.h" + +#include <cstdint> +#include <memory> +#include <vector> + +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/index-block.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/main/posting-list-used-hit-serializer.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> +PostingListHitAccessor::Create(FlashIndexStorage *storage, + PostingListUsedHitSerializer *serializer) { + uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( + storage->block_size(), serializer->GetDataTypeBytes()); + std::unique_ptr<uint8_t[]> posting_list_buffer_array = + std::make_unique<uint8_t[]>(max_posting_list_bytes); + ICING_ASSIGN_OR_RETURN( + PostingListUsed posting_list_buffer, + PostingListUsed::CreateFromUnitializedRegion( + serializer, posting_list_buffer_array.get(), max_posting_list_bytes)); + return std::unique_ptr<PostingListHitAccessor>(new PostingListHitAccessor( + storage, serializer, std::move(posting_list_buffer_array), + std::move(posting_list_buffer))); +} + +libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> +PostingListHitAccessor::CreateFromExisting( + FlashIndexStorage *storage, PostingListUsedHitSerializer *serializer, + PostingListIdentifier existing_posting_list_id) { + // Our posting_list_buffer_ will start as empty. 
+ ICING_ASSIGN_OR_RETURN(std::unique_ptr<PostingListHitAccessor> pl_accessor, + Create(storage, serializer)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage->GetPostingList(existing_posting_list_id)); + pl_accessor->preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + return pl_accessor; +} + +// Returns the next batch of hits for the provided posting list. +libtextclassifier3::StatusOr<std::vector<Hit>> +PostingListHitAccessor::GetNextHitsBatch() { + if (preexisting_posting_list_ == nullptr) { + if (has_reached_posting_list_chain_end_) { + return std::vector<Hit>(); + } + return absl_ports::FailedPreconditionError( + "Cannot retrieve hits from a PostingListHitAccessor that was not " + "created from a preexisting posting list."); + } + ICING_ASSIGN_OR_RETURN( + std::vector<Hit> batch, + serializer_->GetHits(&preexisting_posting_list_->posting_list)); + uint32_t next_block_index; + // Posting lists will only be chained when they are max-sized, in which case + // block.next_block_index() will point to the next block for the next posting + // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be + // used to point to the next free list block, which is not relevant here. 
+ if (preexisting_posting_list_->block.max_num_posting_lists() == 1) { + next_block_index = preexisting_posting_list_->block.next_block_index(); + } else { + next_block_index = kInvalidBlockIndex; + } + if (next_block_index != kInvalidBlockIndex) { + PostingListIdentifier next_posting_list_id( + next_block_index, /*posting_list_index=*/0, + preexisting_posting_list_->block.posting_list_index_bits()); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage_->GetPostingList(next_posting_list_id)); + preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + } else { + has_reached_posting_list_chain_end_ = true; + preexisting_posting_list_.reset(); + } + return batch; +} + +libtextclassifier3::Status PostingListHitAccessor::PrependHit(const Hit &hit) { + PostingListUsed &active_pl = (preexisting_posting_list_ != nullptr) + ? preexisting_posting_list_->posting_list + : posting_list_buffer_; + libtextclassifier3::Status status = serializer_->PrependHit(&active_pl, hit); + if (!absl_ports::IsResourceExhausted(status)) { + return status; + } + // There is no more room to add hits to this current posting list! Therefore, + // we need to either move those hits to a larger posting list or flush this + // posting list and create another max-sized posting list in the chain. + if (preexisting_posting_list_ != nullptr) { + FlushPreexistingPostingList(); + } else { + ICING_RETURN_IF_ERROR(FlushInMemoryPostingList()); + } + + // Re-add hit. Should always fit since we just cleared posting_list_buffer_. + // It's fine to explicitly reference posting_list_buffer_ here because there's + // no way of reaching this line while preexisting_posting_list_ is still in + // use. 
+ return serializer_->PrependHit(&posting_list_buffer_, hit); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/main/posting-list-hit-accessor.h b/icing/index/main/posting-list-hit-accessor.h new file mode 100644 index 0000000..953f2bd --- /dev/null +++ b/icing/index/main/posting-list-hit-accessor.h @@ -0,0 +1,103 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_ +#define ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_ + +#include <cstdint> +#include <memory> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/hit/hit.h" +#include "icing/index/main/posting-list-used-hit-serializer.h" + +namespace icing { +namespace lib { + +// This class is used to provide a simple abstraction for adding hits to posting +// lists. PostingListHitAccessor handles 1) selection of properly-sized posting +// lists for the accumulated hits during Finalize() and 2) chaining of max-sized +// posting lists. +class PostingListHitAccessor : public PostingListAccessor { + public: + // Creates an empty PostingListHitAccessor. 
+ // + // RETURNS: + // - On success, a valid unique_ptr instance of PostingListHitAccessor + // - INVALID_ARGUMENT error if storage has an invalid block_size. + static libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> + Create(FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer); + + // Create a PostingListHitAccessor with an existing posting list identified by + // existing_posting_list_id. + // + // The PostingListHitAccessor will add hits to this posting list until it is + // necessary either to 1) chain the posting list (if it is max-sized) or 2) + // move its hits to a larger posting list. + // + // RETURNS: + // - On success, a valid unique_ptr instance of PostingListHitAccessor + // - INVALID_ARGUMENT if storage has an invalid block_size. + static libtextclassifier3::StatusOr<std::unique_ptr<PostingListHitAccessor>> + CreateFromExisting(FlashIndexStorage* storage, + PostingListUsedHitSerializer* serializer, + PostingListIdentifier existing_posting_list_id); + + PostingListUsedSerializer* GetSerializer() override { return serializer_; } + + // Retrieve the next batch of hits for the posting list chain + // + // RETURNS: + // - On success, a vector of hits in the posting list chain + // - INTERNAL if called on an instance of PostingListHitAccessor that was + // created via PostingListHitAccessor::Create, if unable to read the next + // posting list in the chain or if the posting list has been corrupted + // somehow. + libtextclassifier3::StatusOr<std::vector<Hit>> GetNextHitsBatch(); + + // Prepend one hit. This may result in flushing the posting list to disk (if + // the PostingListHitAccessor holds a max-sized posting list that is full) or + // freeing a pre-existing posting list if it is too small to fit all hits + // necessary. + // + // RETURNS: + // - OK, on success + // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the + // previously added hit. 
+ // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new + // posting list. + libtextclassifier3::Status PrependHit(const Hit& hit); + + private: + explicit PostingListHitAccessor( + FlashIndexStorage* storage, PostingListUsedHitSerializer* serializer, + std::unique_ptr<uint8_t[]> posting_list_buffer_array, + PostingListUsed posting_list_buffer) + : PostingListAccessor(storage, std::move(posting_list_buffer_array), + std::move(posting_list_buffer)), + serializer_(serializer) {} + + PostingListUsedHitSerializer* serializer_; // Does not own. +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_POSTING_LIST_HIT_ACCESSOR_H_ diff --git a/icing/index/main/posting-list-accessor_test.cc b/icing/index/main/posting-list-hit-accessor_test.cc index 3145420..fcdd580 100644 --- a/icing/index/main/posting-list-accessor_test.cc +++ b/icing/index/main/posting-list-hit-accessor_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "icing/index/main/posting-list-accessor.h" +#include "icing/index/main/posting-list-hit-accessor.h" #include <cstdint> @@ -40,7 +40,7 @@ using ::testing::Eq; using ::testing::Lt; using ::testing::SizeIs; -class PostingListAccessorTest : public ::testing::Test { +class PostingListHitAccessorTest : public ::testing::Test { protected: void SetUp() override { test_dir_ = GetTestTempDir() + "/test_dir"; @@ -71,19 +71,19 @@ class PostingListAccessorTest : public ::testing::Test { std::unique_ptr<FlashIndexStorage> flash_index_storage_; }; -TEST_F(PostingListAccessorTest, HitsAddAndRetrieveProperly) { +TEST_F(PostingListHitAccessorTest, HitsAddAndRetrieveProperly) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); // Add some hits! Any hits! std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit& hit : hits1) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result.status); EXPECT_THAT(result.id.block_index(), Eq(1)); EXPECT_THAT(result.id.posting_list_index(), Eq(0)); @@ -96,16 +96,16 @@ TEST_F(PostingListAccessorTest, HitsAddAndRetrieveProperly) { EXPECT_THAT(pl_holder.block.next_block_index(), Eq(kInvalidBlockIndex)); } -TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) { +TEST_F(PostingListHitAccessorTest, PreexistingPLKeepOnSameBlock) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + 
PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); // Add a single hit. This will fit in a min-sized posting list. Hit hit1(/*section_id=*/1, /*document_id=*/0, Hit::kDefaultTermFrequency); - ICING_ASSERT_OK(pl_accessor.PrependHit(hit1)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result1.status); // Should have been allocated to the first block. EXPECT_THAT(result1.id.block_index(), Eq(1)); @@ -116,12 +116,12 @@ TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) { // reallocated. ICING_ASSERT_OK_AND_ASSIGN( pl_accessor, - PostingListAccessor::CreateFromExisting(flash_index_storage_.get(), - serializer_.get(), result1.id)); + PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); Hit hit2 = CreateHit(hit1, /*desired_byte_length=*/1); - ICING_ASSERT_OK(pl_accessor.PrependHit(hit2)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit2)); PostingListAccessor::FinalizeResult result2 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result2.status); // Should have been allocated to the same posting list as the first hit. EXPECT_THAT(result2.id, Eq(result1.id)); @@ -134,11 +134,11 @@ TEST_F(PostingListAccessorTest, PreexistingPLKeepOnSameBlock) { IsOkAndHolds(ElementsAre(hit2, hit1))); } -TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) { +TEST_F(PostingListHitAccessorTest, PreexistingPLReallocateToLargerPL) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); // The smallest posting list size is 15 bytes. 
The first four hits will be // compressed to one byte each and will be able to fit in the 5 byte padded // region. The last hit will fit in one of the special hits. The posting list @@ -146,10 +146,10 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) { std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5, /*desired_byte_length=*/1); for (const Hit& hit : hits1) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result1.status); // Should have been allocated to the first block. EXPECT_THAT(result1.id.block_index(), Eq(1)); @@ -158,8 +158,8 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) { // Now let's add some more hits! ICING_ASSERT_OK_AND_ASSIGN( pl_accessor, - PostingListAccessor::CreateFromExisting(flash_index_storage_.get(), - serializer_.get(), result1.id)); + PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); // The current posting list can fit at most 2 more hits. Adding 12 more hits // should result in these hits being moved to a larger posting list. std::vector<Hit> hits2 = CreateHits( @@ -167,10 +167,10 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) { /*desired_byte_length=*/1); for (const Hit& hit : hits2) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result2 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result2.status); // Should have been allocated to the second (new) block because the posting // list should have grown beyond the size that the first block maintains. 
@@ -188,19 +188,19 @@ TEST_F(PostingListAccessorTest, PreexistingPLReallocateToLargerPL) { IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); } -TEST_F(PostingListAccessorTest, MultiBlockChainsBlocksProperly) { +TEST_F(PostingListHitAccessorTest, MultiBlockChainsBlocksProperly) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); // Add some hits! Any hits! std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5000, /*desired_byte_length=*/1); for (const Hit& hit : hits1) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result1.status); PostingListIdentifier second_block_id = result1.id; // Should have been allocated to the second block, which holds a max-sized @@ -235,19 +235,19 @@ TEST_F(PostingListAccessorTest, MultiBlockChainsBlocksProperly) { IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend()))); } -TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) { +TEST_F(PostingListHitAccessorTest, PreexistingMultiBlockReusesBlocksProperly) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); // Add some hits! Any hits! 
std::vector<Hit> hits1 = CreateHits(/*num_hits=*/5000, /*desired_byte_length=*/1); for (const Hit& hit : hits1) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result1.status); PostingListIdentifier first_add_id = result1.id; EXPECT_THAT(first_add_id, Eq(PostingListIdentifier( @@ -258,17 +258,17 @@ TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) { // second block. ICING_ASSERT_OK_AND_ASSIGN( pl_accessor, - PostingListAccessor::CreateFromExisting(flash_index_storage_.get(), - serializer_.get(), first_add_id)); + PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), first_add_id)); std::vector<Hit> hits2 = CreateHits( /*start_docid=*/hits1.back().document_id() + 1, /*num_hits=*/50, /*desired_byte_length=*/1); for (const Hit& hit : hits2) { - ICING_ASSERT_OK(pl_accessor.PrependHit(hit)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit)); } PostingListAccessor::FinalizeResult result2 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_EXPECT_OK(result2.status); PostingListIdentifier second_add_id = result2.id; EXPECT_THAT(second_add_id, Eq(first_add_id)); @@ -302,61 +302,61 @@ TEST_F(PostingListAccessorTest, PreexistingMultiBlockReusesBlocksProperly) { IsOkAndHolds(ElementsAreArray(first_block_hits_start, hits1.rend()))); } -TEST_F(PostingListAccessorTest, InvalidHitReturnsInvalidArgument) { +TEST_F(PostingListHitAccessorTest, InvalidHitReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); Hit 
invalid_hit; - EXPECT_THAT(pl_accessor.PrependHit(invalid_hit), + EXPECT_THAT(pl_accessor->PrependHit(invalid_hit), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } -TEST_F(PostingListAccessorTest, HitsNotDecreasingReturnsInvalidArgument) { +TEST_F(PostingListHitAccessorTest, HitsNotDecreasingReturnsInvalidArgument) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency); - ICING_ASSERT_OK(pl_accessor.PrependHit(hit1)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); Hit hit2(/*section_id=*/6, /*document_id=*/1, Hit::kDefaultTermFrequency); - EXPECT_THAT(pl_accessor.PrependHit(hit2), + EXPECT_THAT(pl_accessor->PrependHit(hit2), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); Hit hit3(/*section_id=*/2, /*document_id=*/0, Hit::kDefaultTermFrequency); - EXPECT_THAT(pl_accessor.PrependHit(hit3), + EXPECT_THAT(pl_accessor->PrependHit(hit3), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } -TEST_F(PostingListAccessorTest, NewPostingListNoHitsAdded) { +TEST_F(PostingListHitAccessorTest, NewPostingListNoHitsAdded) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); EXPECT_THAT(result1.status, StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } -TEST_F(PostingListAccessorTest, PreexistingPostingListNoHitsAdded) { +TEST_F(PostingListHitAccessorTest, 
PreexistingPostingListNoHitsAdded) { ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor, - PostingListAccessor::Create(flash_index_storage_.get(), - serializer_.get())); + std::unique_ptr<PostingListHitAccessor> pl_accessor, + PostingListHitAccessor::Create(flash_index_storage_.get(), + serializer_.get())); Hit hit1(/*section_id=*/3, /*document_id=*/1, Hit::kDefaultTermFrequency); - ICING_ASSERT_OK(pl_accessor.PrependHit(hit1)); + ICING_ASSERT_OK(pl_accessor->PrependHit(hit1)); PostingListAccessor::FinalizeResult result1 = - PostingListAccessor::Finalize(std::move(pl_accessor)); + std::move(*pl_accessor).Finalize(); ICING_ASSERT_OK(result1.status); ICING_ASSERT_OK_AND_ASSIGN( - PostingListAccessor pl_accessor2, - PostingListAccessor::CreateFromExisting(flash_index_storage_.get(), - serializer_.get(), result1.id)); + std::unique_ptr<PostingListHitAccessor> pl_accessor2, + PostingListHitAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); PostingListAccessor::FinalizeResult result2 = - PostingListAccessor::Finalize(std::move(pl_accessor2)); + std::move(*pl_accessor2).Finalize(); ICING_ASSERT_OK(result2.status); } diff --git a/icing/index/main/posting-list-used-hit-serializer.cc b/icing/index/main/posting-list-used-hit-serializer.cc index d45a428..a163188 100644 --- a/icing/index/main/posting-list-used-hit-serializer.cc +++ b/icing/index/main/posting-list-used-hit-serializer.cc @@ -20,7 +20,6 @@ #include <vector> #include "icing/absl_ports/canonical_errors.h" -#include "icing/file/posting_list/posting-list-common.h" #include "icing/file/posting_list/posting-list-used.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/index/icing-bit-util.h" diff --git a/icing/index/main/posting-list-used-hit-serializer.h b/icing/index/main/posting-list-used-hit-serializer.h index 70e3e6c..1a3cbc2 100644 --- a/icing/index/main/posting-list-used-hit-serializer.h +++ 
b/icing/index/main/posting-list-used-hit-serializer.h @@ -31,7 +31,7 @@ namespace lib { // comments in posting-list-used-hit-serializer.cc. class PostingListUsedHitSerializer : public PostingListUsedSerializer { public: - static constexpr uint32_t kSpecialHitsSize = sizeof(Hit) * kNumSpecialData; + static constexpr uint32_t kSpecialHitsSize = kNumSpecialData * sizeof(Hit); uint32_t GetDataTypeBytes() const override { return sizeof(Hit); } @@ -44,23 +44,14 @@ class PostingListUsedHitSerializer : public PostingListUsedSerializer { return kMinPostingListSize; } - // Min size of posting list that can fit these used bytes (see MoveFrom). uint32_t GetMinPostingListSizeToFit( const PostingListUsed* posting_list_used) const override; - // Returns bytes used by actual hits. uint32_t GetBytesUsed( const PostingListUsed* posting_list_used) const override; void Clear(PostingListUsed* posting_list_used) const override; - // Moves contents from posting list 'src' to 'dst'. Clears 'src'. - // - // RETURNS: - // - OK on success - // - INVALID_ARGUMENT if 'src' is not valid or 'src' is too large to fit in - // 'dst'. - // - FAILED_PRECONDITION if 'dst' posting list is in a corrupted state. 
libtextclassifier3::Status MoveFrom(PostingListUsed* dst, PostingListUsed* src) const override; diff --git a/icing/index/main/posting-list-used-hit-serializer_test.cc b/icing/index/main/posting-list-used-hit-serializer_test.cc index b87adc9..9ecb7ec 100644 --- a/icing/index/main/posting-list-used-hit-serializer_test.cc +++ b/icing/index/main/posting-list-used-hit-serializer_test.cc @@ -579,7 +579,7 @@ TEST(PostingListUsedHitSerializerTest, ICING_ASSERT_OK(serializer.PrependHit(&pl_used1, hit)); } - EXPECT_THAT(serializer.MoveFrom(&pl_used1, /*other=*/nullptr), + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used1, /*src=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(ElementsAreArray(hits.rbegin(), hits.rend()))); @@ -625,7 +625,7 @@ TEST(PostingListUsedHitSerializerTest, } TEST(PostingListUsedHitSerializerTest, - MoveToInvalidPostingListReturnsInvalidArgument) { + MoveToInvalidPostingListReturnsFailedPrecondition) { PostingListUsedHitSerializer serializer; int size = 3 * serializer.GetMinPostingListSize(); @@ -657,7 +657,7 @@ TEST(PostingListUsedHitSerializerTest, *first_hit = invalid_hit; ++first_hit; *first_hit = invalid_hit; - EXPECT_THAT(serializer.MoveFrom(&pl_used2, &pl_used1), + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(serializer.GetHits(&pl_used1), IsOkAndHolds(ElementsAreArray(hits1.rbegin(), hits1.rend()))); diff --git a/icing/index/numeric/doc-hit-info-iterator-numeric.h b/icing/index/numeric/doc-hit-info-iterator-numeric.h new file mode 100644 index 0000000..1bfd193 --- /dev/null +++ b/icing/index/numeric/doc-hit-info-iterator-numeric.h @@ -0,0 +1,63 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_ +#define ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +template <typename T> +class DocHitInfoIteratorNumeric : public DocHitInfoIterator { + public: + explicit DocHitInfoIteratorNumeric( + std::unique_ptr<typename NumericIndex<T>::Iterator> numeric_index_iter) + : numeric_index_iter_(std::move(numeric_index_iter)) {} + + libtextclassifier3::Status Advance() override { + ICING_RETURN_IF_ERROR(numeric_index_iter_->Advance()); + + doc_hit_info_ = numeric_index_iter_->GetDocHitInfo(); + return libtextclassifier3::Status::OK; + } + + int32_t GetNumBlocksInspected() const override { return 0; } + + int32_t GetNumLeafAdvanceCalls() const override { return 0; } + + std::string ToString() const override { return "test"; } + + void PopulateMatchedTermsStats( + std::vector<TermMatchInfo>* matched_terms_stats, + SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { + // For numeric hit iterator, this should do nothing since there is no term. 
+ } + + private: + std::unique_ptr<typename NumericIndex<T>::Iterator> numeric_index_iter_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_DOC_HIT_INFO_ITERATOR_NUMERIC_H_ diff --git a/icing/index/numeric/dummy-numeric-index.h b/icing/index/numeric/dummy-numeric-index.h new file mode 100644 index 0000000..a1d20f8 --- /dev/null +++ b/icing/index/numeric/dummy-numeric-index.h @@ -0,0 +1,239 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_ +#define ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_ + +#include <functional> +#include <map> +#include <memory> +#include <queue> +#include <string> +#include <string_view> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/hit/hit.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/doc-hit-info-iterator-numeric.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +template <typename T> +class DummyNumericIndex : public NumericIndex<T> { + public: + ~DummyNumericIndex() override = default; + + std::unique_ptr<typename NumericIndex<T>::Editor> Edit( + std::string_view property_name, DocumentId document_id, + SectionId section_id) override { + return std::make_unique<Editor>(property_name, document_id, section_id, + storage_); + } + + libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> GetIterator( + std::string_view property_name, T key_lower, T key_upper) const override; + + libtextclassifier3::Status Reset() override { + storage_.clear(); + return libtextclassifier3::Status::OK; + } + + libtextclassifier3::Status PersistToDisk() override { + return libtextclassifier3::Status::OK; + } + + private: + class Editor : public NumericIndex<T>::Editor { + public: + explicit Editor( + std::string_view property_name, DocumentId document_id, + SectionId section_id, + std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>>& + storage) + : NumericIndex<T>::Editor(property_name, document_id, section_id), + storage_(storage) {} + + ~Editor() override = default; + 
+ libtextclassifier3::Status BufferKey(T key) override { + seen_keys_.insert(key); + return libtextclassifier3::Status::OK; + } + + libtextclassifier3::Status IndexAllBufferedKeys() override; + + private: + std::unordered_set<T> seen_keys_; + std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>>& + storage_; + }; + + class Iterator : public NumericIndex<T>::Iterator { + public: + // We group BasicHits (sorted by document_id) of a key into a Bucket (stored + // as std::vector) and store key -> vector in an std::map. When doing range + // query, we may access vectors from multiple keys and want to return + // BasicHits to callers sorted by document_id. Therefore, this problem is + // actually "merge K sorted vectors". + // To implement this algorithm via priority_queue, we create this wrapper + // class to store iterators of map and vector. + class BucketInfo { + public: + explicit BucketInfo( + typename std::map<T, std::vector<BasicHit>>::const_iterator + bucket_iter) + : bucket_iter_(bucket_iter), + vec_iter_(bucket_iter_->second.rbegin()) {} + + bool Advance() { return ++vec_iter_ != bucket_iter_->second.rend(); } + + const BasicHit& GetCurrentBasicHit() const { return *vec_iter_; } + + bool operator<(const BucketInfo& other) const { + // std::priority_queue is a max heap and we should return BasicHits in + // DocumentId descending order. + // - BucketInfo::operator< should have the same order as DocumentId. + // - BasicHit encodes inverted document id and its operator< compares + // the encoded raw value directly. + // - Therefore, BucketInfo::operator< should compare BasicHit reversely. + // - This will make priority_queue return buckets in DocumentId + // descending and SectionId ascending order. + // - Whatever direction we sort SectionId by (or pop by priority_queue) + // doesn't matter because all hits for the same DocumentId will be + // merged into a single DocHitInfo. 
+ return other.GetCurrentBasicHit() < GetCurrentBasicHit(); + } + + private: + typename std::map<T, std::vector<BasicHit>>::const_iterator bucket_iter_; + std::vector<BasicHit>::const_reverse_iterator vec_iter_; + }; + + explicit Iterator(T key_lower, T key_upper, + std::vector<BucketInfo>&& bucket_info_vec) + : NumericIndex<T>::Iterator(key_lower, key_upper), + pq_(std::less<BucketInfo>(), std::move(bucket_info_vec)) {} + + ~Iterator() override = default; + + libtextclassifier3::Status Advance() override; + + DocHitInfo GetDocHitInfo() const override { return doc_hit_info_; } + + private: + std::priority_queue<BucketInfo> pq_; + DocHitInfo doc_hit_info_; + }; + + std::unordered_map<std::string, std::map<T, std::vector<BasicHit>>> storage_; +}; + +template <typename T> +libtextclassifier3::Status +DummyNumericIndex<T>::Editor::IndexAllBufferedKeys() { + auto property_map_iter = storage_.find(this->property_name_); + if (property_map_iter == storage_.end()) { + const auto& [inserted_iter, insert_result] = + storage_.insert({this->property_name_, {}}); + if (!insert_result) { + return absl_ports::InternalError( + absl_ports::StrCat("Failed to create a new map for property \"", + this->property_name_, "\"")); + } + property_map_iter = inserted_iter; + } + + for (const T& key : seen_keys_) { + auto key_map_iter = property_map_iter->second.find(key); + if (key_map_iter == property_map_iter->second.end()) { + const auto& [inserted_iter, insert_result] = + property_map_iter->second.insert({key, {}}); + if (!insert_result) { + return absl_ports::InternalError("Failed to create a new map for key"); + } + key_map_iter = inserted_iter; + } + key_map_iter->second.push_back( + BasicHit(this->section_id_, this->document_id_)); + } + return libtextclassifier3::Status::OK; +} + +template <typename T> +libtextclassifier3::Status DummyNumericIndex<T>::Iterator::Advance() { + if (pq_.empty()) { + return absl_ports::OutOfRangeError("End of iterator"); + } + + DocumentId document_id = 
pq_.top().GetCurrentBasicHit().document_id(); + doc_hit_info_ = DocHitInfo(document_id); + // Merge sections with same document_id into a single DocHitInfo + while (!pq_.empty() && + pq_.top().GetCurrentBasicHit().document_id() == document_id) { + doc_hit_info_.UpdateSection(pq_.top().GetCurrentBasicHit().section_id()); + + BucketInfo info = pq_.top(); + pq_.pop(); + + if (info.Advance()) { + pq_.push(std::move(info)); + } + } + + return libtextclassifier3::Status::OK; +} + +template <typename T> +libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> +DummyNumericIndex<T>::GetIterator(std::string_view property_name, T key_lower, + T key_upper) const { + if (key_lower > key_upper) { + return absl_ports::InvalidArgumentError( + "key_lower should not be greater than key_upper"); + } + + auto property_map_iter = storage_.find(std::string(property_name)); + if (property_map_iter == storage_.end()) { + return absl_ports::NotFoundError( + absl_ports::StrCat("Property \"", property_name, "\" not found")); + } + + std::vector<typename Iterator::BucketInfo> bucket_info_vec; + for (auto key_map_iter = property_map_iter->second.lower_bound(key_lower); + key_map_iter != property_map_iter->second.cend() && + key_map_iter->first <= key_upper; + ++key_map_iter) { + bucket_info_vec.push_back(typename Iterator::BucketInfo(key_map_iter)); + } + + return std::make_unique<DocHitInfoIteratorNumeric<T>>( + std::make_unique<Iterator>(key_lower, key_upper, + std::move(bucket_info_vec))); +} + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_DUMMY_NUMERIC_INDEX_H_ diff --git a/icing/index/numeric/integer-index-data.h b/icing/index/numeric/integer-index-data.h new file mode 100644 index 0000000..92653fa --- /dev/null +++ b/icing/index/numeric/integer-index-data.h @@ -0,0 +1,59 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_ +#define ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_ + +#include <cstdint> + +#include "icing/index/hit/hit.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +// Data wrapper to store BasicHit and key for integer index. +class IntegerIndexData { + public: + explicit IntegerIndexData(SectionId section_id, DocumentId document_id, + int64_t key) + : basic_hit_(section_id, document_id), key_(key) {} + + explicit IntegerIndexData() : basic_hit_(), key_(0) {} + + const BasicHit& basic_hit() const { return basic_hit_; } + + int64_t key() const { return key_; } + + bool is_valid() const { return basic_hit_.is_valid(); } + + bool operator<(const IntegerIndexData& other) const { + return basic_hit_ < other.basic_hit_; + } + + bool operator==(const IntegerIndexData& other) const { + return basic_hit_ == other.basic_hit_ && key_ == other.key_; + } + + private: + BasicHit basic_hit_; + int64_t key_; +} __attribute__((packed)); +static_assert(sizeof(IntegerIndexData) == 12, ""); + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_INTEGER_INDEX_DATA_H_ diff --git a/icing/index/numeric/numeric-index.h b/icing/index/numeric/numeric-index.h new file mode 100644 index 0000000..6798f8d --- /dev/null +++ b/icing/index/numeric/numeric-index.h @@ -0,0 +1,146 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the 
License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_ +#define ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_ + +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" + +namespace icing { +namespace lib { + +template <typename T> +class NumericIndex { + public: + using value_type = T; + + // Editor class for batch adding new records into numeric index for a given + // property, DocumentId and SectionId. The caller should use BufferKey to + // buffer a key (calls several times for multiple keys) and finally call + // IndexAllBufferedKeys to batch add all buffered keys (with DocumentId + + // SectionId info, i.e. BasicHit) into numeric index. + // + // For example, there are values = [5, 1, 10, -100] in DocumentId = 5, + // SectionId = 1 (property "timestamp"). + // Then the client should call BufferKey(5), BufferKey(1), BufferKey(10), + // BufferKey(-100) first, and finally call IndexAllBufferedKeys once to batch + // add these records into numeric index. + class Editor { + public: + explicit Editor(std::string_view property_name, DocumentId document_id, + SectionId section_id) + : property_name_(property_name), + document_id_(document_id), + section_id_(section_id) {} + + virtual ~Editor() = default; + + // Buffers a new key. 
+ // + // Returns: + // - OK on success + // - Any other errors, depending on the actual implementation + virtual libtextclassifier3::Status BufferKey(T key) = 0; + + // Adds all buffered keys into numeric index. + // + // Returns: + // - OK on success + // - Any other errors, depending on the actual implementation + virtual libtextclassifier3::Status IndexAllBufferedKeys() = 0; + + protected: + std::string property_name_; + DocumentId document_id_; + SectionId section_id_; + }; + + // Iterator class for numeric index range query [key_lower, key_upper] + // (inclusive on both sides) on a given property (see GetIterator). There are + // some basic requirements for implementation: + // - Iterates through all relevant doc hits. + // - Merges multiple SectionIds of doc hits with same DocumentId into a single + // SectionIdMask and constructs DocHitInfo. + // - Returns DocHitInfo in descending DocumentId order. + // + // For example, relevant doc hits (DocumentId, SectionId) are [(2, 0), (4, 3), + // (2, 1), (6, 2), (4, 2)]. Advance() and GetDocHitInfo() should return + // DocHitInfo(6, SectionIdMask(2)), DocHitInfo(4, SectionIdMask(2, 3)) and + // DocHitInfo(2, SectionIdMask(0, 1)). + class Iterator { + public: + explicit Iterator(T key_lower, T key_upper) + : key_lower_(key_lower), key_upper_(key_upper) {} + + virtual ~Iterator() = default; + + virtual libtextclassifier3::Status Advance() = 0; + + virtual DocHitInfo GetDocHitInfo() const = 0; + + protected: + T key_lower_; + T key_upper_; + }; + + virtual ~NumericIndex() = default; + + // Returns an Editor instance for adding new records into numeric index for a + // given property, DocumentId and SectionId. See Editor for more details.
+ virtual std::unique_ptr<Editor> Edit(std::string_view property_name, + DocumentId document_id, + SectionId section_id) = 0; + + // Returns a DocHitInfoIteratorNumeric (in DocHitInfoIterator interface type + // format) for iterating through all docs which have the specified (numeric) + // property contents in range [key_lower, key_upper]. + // + // In general, different numeric index implementations require different data + // iterator implementations, so class Iterator is an abstraction of the data + // iterator and DocHitInfoIteratorNumeric can work with any implementation of + // it. See Iterator and DocHitInfoIteratorNumeric for more details. + // + // Returns: + // - std::unique_ptr<DocHitInfoIterator> on success + // - NOT_FOUND_ERROR if there is no numeric index for property_name + // - INVALID_ARGUMENT_ERROR if key_lower > key_upper + // - Any other errors, depending on the actual implementation + virtual libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> + GetIterator(std::string_view property_name, T key_lower, + T key_upper) const = 0; + + // Clears all files created by the index. Returns OK if all files were + // cleared. + virtual libtextclassifier3::Status Reset() = 0; + + // Syncs all the data and metadata changes to disk. + // + // Returns: + // OK on success + // INTERNAL_ERROR on I/O errors + virtual libtextclassifier3::Status PersistToDisk() = 0; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_NUMERIC_INDEX_H_ diff --git a/icing/index/numeric/numeric-index_test.cc b/icing/index/numeric/numeric-index_test.cc new file mode 100644 index 0000000..38769f6 --- /dev/null +++ b/icing/index/numeric/numeric-index_test.cc @@ -0,0 +1,361 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/numeric/numeric-index.h" + +#include <limits> +#include <string> +#include <string_view> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::NotNull; + +constexpr static std::string_view kDefaultTestPropertyName = "test"; + +constexpr SectionId kDefaultSectionId = 0; + +template <typename T> +class NumericIndexTest : public ::testing::Test { + protected: + using INDEX_IMPL_TYPE = T; + + void SetUp() override { + if (std::is_same_v< + INDEX_IMPL_TYPE, + DummyNumericIndex<typename INDEX_IMPL_TYPE::value_type>>) { + numeric_index_ = std::make_unique< + DummyNumericIndex<typename INDEX_IMPL_TYPE::value_type>>(); + } + + ASSERT_THAT(numeric_index_, NotNull()); + } + + void Index(std::string_view property_name, DocumentId document_id, + SectionId section_id, + std::vector<typename INDEX_IMPL_TYPE::value_type> keys) { + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + this->numeric_index_->Edit(property_name, document_id, section_id); + + for (const auto& key : keys) { + 
ICING_EXPECT_OK(editor->BufferKey(key)); + } + ICING_EXPECT_OK(editor->IndexAllBufferedKeys()); + } + + libtextclassifier3::StatusOr<std::vector<DocHitInfo>> Query( + std::string_view property_name, + typename INDEX_IMPL_TYPE::value_type key_lower, + typename INDEX_IMPL_TYPE::value_type key_upper) { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<DocHitInfoIterator> iter, + this->numeric_index_->GetIterator(property_name, key_lower, key_upper)); + + std::vector<DocHitInfo> result; + while (iter->Advance().ok()) { + result.push_back(iter->doc_hit_info()); + } + return result; + } + + std::unique_ptr<NumericIndex<typename INDEX_IMPL_TYPE::value_type>> + numeric_index_; +}; + +using TestTypes = ::testing::Types<DummyNumericIndex<int64_t>>; +TYPED_TEST_SUITE(NumericIndexTest, TestTypes); + +TYPED_TEST(NumericIndexTest, SingleKeyExactQuery) { + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{3}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2}); + + int64_t query_key = 2; + std::vector<SectionId> expected_sections{kDefaultSectionId}; + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/query_key, + /*key_upper=*/query_key), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/5, expected_sections), + EqualsDocHitInfo(/*document_id=*/2, expected_sections)))); +} + +TYPED_TEST(NumericIndexTest, SingleKeyRangeQuery) { + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + 
/*keys=*/{3}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2}); + + std::vector<SectionId> expected_sections{kDefaultSectionId}; + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1, + /*key_upper=*/3), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/5, expected_sections), + EqualsDocHitInfo(/*document_id=*/2, expected_sections), + EqualsDocHitInfo(/*document_id=*/1, expected_sections), + EqualsDocHitInfo(/*document_id=*/0, expected_sections)))); +} + +TYPED_TEST(NumericIndexTest, EmptyResult) { + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{3}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2}); + + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/100, + /*key_upper=*/200), + IsOkAndHolds(IsEmpty())); +} + +TYPED_TEST(NumericIndexTest, MultipleKeysShouldMergeAndDedupeDocHitInfo) { + // Construct several documents with multiple keys under the same section. + // Range query [1, 3] will find hits with same (DocumentId, SectionId) for + // multiple times. For example, (2, kDefaultSectionId) will be found twice + // (once for key = 1 and once for key = 3). + // Test if the iterator dedupes correctly.
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{-1000, 0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{-100, 0, 1, 2, 3, 4, 5}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{3, 1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{4, 1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{1, 6}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2, 100}); + this->Index(kDefaultTestPropertyName, /*document_id=*/6, kDefaultSectionId, + /*keys=*/{1000, 2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/7, kDefaultSectionId, + /*keys=*/{4, -1000}); + + std::vector<SectionId> expected_sections{kDefaultSectionId}; + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1, + /*key_upper=*/3), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/6, expected_sections), + EqualsDocHitInfo(/*document_id=*/5, expected_sections), + EqualsDocHitInfo(/*document_id=*/4, expected_sections), + EqualsDocHitInfo(/*document_id=*/3, expected_sections), + EqualsDocHitInfo(/*document_id=*/2, expected_sections), + EqualsDocHitInfo(/*document_id=*/1, expected_sections)))); +} + +TYPED_TEST(NumericIndexTest, EdgeNumericValues) { + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{-100}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{-80}); + this->Index( + kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::max()}); + this->Index( + kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{std::numeric_limits<typename 
TypeParam::value_type>::min()}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{200}); + this->Index(kDefaultTestPropertyName, /*document_id=*/6, kDefaultSectionId, + /*keys=*/{100}); + this->Index( + kDefaultTestPropertyName, /*document_id=*/7, kDefaultSectionId, + /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::max()}); + this->Index(kDefaultTestPropertyName, /*document_id=*/8, kDefaultSectionId, + /*keys=*/{0}); + this->Index( + kDefaultTestPropertyName, /*document_id=*/9, kDefaultSectionId, + /*keys=*/{std::numeric_limits<typename TypeParam::value_type>::min()}); + + std::vector<SectionId> expected_sections{kDefaultSectionId}; + + // Negative key + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/-100, + /*key_upper=*/-70), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/2, expected_sections), + EqualsDocHitInfo(/*document_id=*/1, expected_sections)))); + + // value_type max key + EXPECT_THAT( + this->Query(kDefaultTestPropertyName, /*key_lower=*/ + std::numeric_limits<typename TypeParam::value_type>::max(), + /*key_upper=*/ + std::numeric_limits<typename TypeParam::value_type>::max()), + IsOkAndHolds( + ElementsAre(EqualsDocHitInfo(/*document_id=*/7, expected_sections), + EqualsDocHitInfo(/*document_id=*/3, expected_sections)))); + + // value_type min key + EXPECT_THAT( + this->Query(kDefaultTestPropertyName, /*key_lower=*/ + std::numeric_limits<typename TypeParam::value_type>::min(), + /*key_upper=*/ + std::numeric_limits<typename TypeParam::value_type>::min()), + IsOkAndHolds( + ElementsAre(EqualsDocHitInfo(/*document_id=*/9, expected_sections), + EqualsDocHitInfo(/*document_id=*/4, expected_sections)))); + + // Key = 0 + EXPECT_THAT( + this->Query(kDefaultTestPropertyName, /*key_lower=*/0, /*key_upper=*/0), + IsOkAndHolds( + ElementsAre(EqualsDocHitInfo(/*document_id=*/8, expected_sections), + EqualsDocHitInfo(/*document_id=*/0, expected_sections)))); + + // All keys 
from value_type min to value_type max + EXPECT_THAT( + this->Query(kDefaultTestPropertyName, /*key_lower=*/ + std::numeric_limits<typename TypeParam::value_type>::min(), + /*key_upper=*/ + std::numeric_limits<typename TypeParam::value_type>::max()), + IsOkAndHolds( + ElementsAre(EqualsDocHitInfo(/*document_id=*/9, expected_sections), + EqualsDocHitInfo(/*document_id=*/8, expected_sections), + EqualsDocHitInfo(/*document_id=*/7, expected_sections), + EqualsDocHitInfo(/*document_id=*/6, expected_sections), + EqualsDocHitInfo(/*document_id=*/5, expected_sections), + EqualsDocHitInfo(/*document_id=*/4, expected_sections), + EqualsDocHitInfo(/*document_id=*/3, expected_sections), + EqualsDocHitInfo(/*document_id=*/2, expected_sections), + EqualsDocHitInfo(/*document_id=*/1, expected_sections), + EqualsDocHitInfo(/*document_id=*/0, expected_sections)))); +} + +TYPED_TEST(NumericIndexTest, + MultipleSectionsShouldMergeSectionsAndDedupeDocHitInfo) { + // Construct several documents with multiple numeric sections. + // Range query [1, 3] will find hits with same DocumentIds but multiple + // different SectionIds. For example, there will be 2 hits (1, 0), (1, 1) for + // DocumentId=1. + // Test if the iterator merges multiple sections into a single SectionIdMask + // correctly. 
+ this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/0, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/1, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/0, /*section_id=*/2, + /*keys=*/{-1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/0, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/1, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, /*section_id=*/2, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/3, + /*keys=*/{3}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/4, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, /*section_id=*/5, + /*keys=*/{5}); + + EXPECT_THAT( + this->Query(kDefaultTestPropertyName, /*key_lower=*/1, + /*key_upper=*/3), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/2, std::vector<SectionId>{3, 4}), + EqualsDocHitInfo(/*document_id=*/1, std::vector<SectionId>{0, 1}), + EqualsDocHitInfo(/*document_id=*/0, std::vector<SectionId>{1})))); +} + +TYPED_TEST(NumericIndexTest, NonRelevantPropertyShouldNotBeIncluded) { + constexpr std::string_view kNonRelevantProperty = "non_relevant_property"; + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{3}); + this->Index(kNonRelevantProperty, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kNonRelevantProperty, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2}); + + std::vector<SectionId> expected_sections{kDefaultSectionId}; + 
EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/1, + /*key_upper=*/3), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(/*document_id=*/5, expected_sections), + EqualsDocHitInfo(/*document_id=*/1, expected_sections), + EqualsDocHitInfo(/*document_id=*/0, expected_sections)))); +} + +TYPED_TEST(NumericIndexTest, + RangeQueryKeyLowerGreaterThanKeyUpperShouldReturnError) { + this->Index(kDefaultTestPropertyName, /*document_id=*/0, kDefaultSectionId, + /*keys=*/{1}); + this->Index(kDefaultTestPropertyName, /*document_id=*/1, kDefaultSectionId, + /*keys=*/{3}); + this->Index(kDefaultTestPropertyName, /*document_id=*/2, kDefaultSectionId, + /*keys=*/{2}); + this->Index(kDefaultTestPropertyName, /*document_id=*/3, kDefaultSectionId, + /*keys=*/{0}); + this->Index(kDefaultTestPropertyName, /*document_id=*/4, kDefaultSectionId, + /*keys=*/{4}); + this->Index(kDefaultTestPropertyName, /*document_id=*/5, kDefaultSectionId, + /*keys=*/{2}); + + EXPECT_THAT(this->Query(kDefaultTestPropertyName, /*key_lower=*/3, + /*key_upper=*/1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor.cc b/icing/index/numeric/posting-list-integer-index-data-accessor.cc new file mode 100644 index 0000000..73b48e2 --- /dev/null +++ b/icing/index/numeric/posting-list-integer-index-data-accessor.cc @@ -0,0 +1,136 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/numeric/posting-list-integer-index-data-accessor.h" + +#include <cstdint> +#include <memory> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/index-block.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/numeric/integer-index-data.h" +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +/* static */ libtextclassifier3::StatusOr< + std::unique_ptr<PostingListIntegerIndexDataAccessor>> +PostingListIntegerIndexDataAccessor::Create( + FlashIndexStorage* storage, + PostingListUsedIntegerIndexDataSerializer* serializer) { + uint32_t max_posting_list_bytes = IndexBlock::CalculateMaxPostingListBytes( + storage->block_size(), serializer->GetDataTypeBytes()); + std::unique_ptr<uint8_t[]> posting_list_buffer_array = + std::make_unique<uint8_t[]>(max_posting_list_bytes); + ICING_ASSIGN_OR_RETURN( + PostingListUsed posting_list_buffer, + PostingListUsed::CreateFromUnitializedRegion( + serializer, posting_list_buffer_array.get(), max_posting_list_bytes)); + return std::unique_ptr<PostingListIntegerIndexDataAccessor>( + new PostingListIntegerIndexDataAccessor( + storage, std::move(posting_list_buffer_array), + std::move(posting_list_buffer), serializer)); +} + +/* static */ libtextclassifier3::StatusOr< + std::unique_ptr<PostingListIntegerIndexDataAccessor>> +PostingListIntegerIndexDataAccessor::CreateFromExisting( + FlashIndexStorage* storage, + PostingListUsedIntegerIndexDataSerializer* serializer, + PostingListIdentifier 
existing_posting_list_id) { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + Create(storage, serializer)); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage->GetPostingList(existing_posting_list_id)); + pl_accessor->preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + return pl_accessor; +} + +// Returns the next batch of integer index data for the provided posting list. +libtextclassifier3::StatusOr<std::vector<IntegerIndexData>> +PostingListIntegerIndexDataAccessor::GetNextDataBatch() { + if (preexisting_posting_list_ == nullptr) { + if (has_reached_posting_list_chain_end_) { + return std::vector<IntegerIndexData>(); + } + return absl_ports::FailedPreconditionError( + "Cannot retrieve data from a PostingListIntegerIndexDataAccessor that " + "was not created from a preexisting posting list."); + } + ICING_ASSIGN_OR_RETURN( + std::vector<IntegerIndexData> batch, + serializer_->GetData(&preexisting_posting_list_->posting_list)); + uint32_t next_block_index; + // Posting lists will only be chained when they are max-sized, in which case + // block.next_block_index() will point to the next block for the next posting + // list. Otherwise, block.next_block_index() can be kInvalidBlockIndex or be + // used to point to the next free list block, which is not relevant here. 
+ if (preexisting_posting_list_->block.max_num_posting_lists() == 1) { + next_block_index = preexisting_posting_list_->block.next_block_index(); + } else { + next_block_index = kInvalidBlockIndex; + } + if (next_block_index != kInvalidBlockIndex) { + PostingListIdentifier next_posting_list_id( + next_block_index, /*posting_list_index=*/0, + preexisting_posting_list_->block.posting_list_index_bits()); + ICING_ASSIGN_OR_RETURN(PostingListHolder holder, + storage_->GetPostingList(next_posting_list_id)); + preexisting_posting_list_ = + std::make_unique<PostingListHolder>(std::move(holder)); + } else { + has_reached_posting_list_chain_end_ = true; + preexisting_posting_list_.reset(); + } + return batch; +} + +libtextclassifier3::Status PostingListIntegerIndexDataAccessor::PrependData( + const IntegerIndexData& data) { + PostingListUsed& active_pl = (preexisting_posting_list_ != nullptr) + ? preexisting_posting_list_->posting_list + : posting_list_buffer_; + libtextclassifier3::Status status = + serializer_->PrependData(&active_pl, data); + if (!absl_ports::IsResourceExhausted(status)) { + return status; + } + // There is no more room to add data to this current posting list! Therefore, + // we need to either move those data to a larger posting list or flush this + // posting list and create another max-sized posting list in the chain. + if (preexisting_posting_list_ != nullptr) { + FlushPreexistingPostingList(); + } else { + ICING_RETURN_IF_ERROR(FlushInMemoryPostingList()); + } + + // Re-add data. Should always fit since we just cleared posting_list_buffer_. + // It's fine to explicitly reference posting_list_buffer_ here because there's + // no way of reaching this line while preexisting_posting_list_ is still in + // use. 
+ return serializer_->PrependData(&posting_list_buffer_, data); +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor.h b/icing/index/numeric/posting-list-integer-index-data-accessor.h new file mode 100644 index 0000000..7835bf9 --- /dev/null +++ b/icing/index/numeric/posting-list-integer-index-data-accessor.h @@ -0,0 +1,108 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_ +#define ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_ + +#include <cstdint> +#include <memory> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-accessor.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/numeric/integer-index-data.h" +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" + +namespace icing { +namespace lib { + +// TODO(b/259743562): Refactor PostingListAccessor derived classes + +// This class is used to provide a simple abstraction for adding integer index +// data to posting lists. 
PostingListIntegerIndexDataAccessor handles: +// 1) selection of properly-sized posting lists for the accumulated integer +// index data during Finalize() +// 2) chaining of max-sized posting lists. +class PostingListIntegerIndexDataAccessor : public PostingListAccessor { + public: + // Creates an empty PostingListIntegerIndexDataAccessor. + // + // RETURNS: + // - On success, a valid instance of PostingListIntegerIndexDataAccessor + // - INVALID_ARGUMENT error if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListIntegerIndexDataAccessor>> + Create(FlashIndexStorage* storage, + PostingListUsedIntegerIndexDataSerializer* serializer); + + // Create a PostingListIntegerIndexDataAccessor with an existing posting list + // identified by existing_posting_list_id. + // + // RETURNS: + // - On success, a valid instance of PostingListIntegerIndexDataAccessor + // - INVALID_ARGUMENT if storage has an invalid block_size. + static libtextclassifier3::StatusOr< + std::unique_ptr<PostingListIntegerIndexDataAccessor>> + CreateFromExisting(FlashIndexStorage* storage, + PostingListUsedIntegerIndexDataSerializer* serializer, + PostingListIdentifier existing_posting_list_id); + + PostingListUsedSerializer* GetSerializer() override { return serializer_; } + + // Retrieve the next batch of data in the posting list chain + // + // RETURNS: + // - On success, a vector of integer index data in the posting list chain + // - INTERNAL if called on an instance that was created via Create, if + // unable to read the next posting list in the chain or if the posting + // list has been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<IntegerIndexData>> + GetNextDataBatch(); + + // Prepend one data. 
This may result in flushing the posting list to disk (if + // the PostingListIntegerIndexDataAccessor holds a max-sized posting list that + // is full) or freeing a pre-existing posting list if it is too small to fit + // all data necessary. + // + // RETURNS: + // - OK, on success + // - INVALID_ARGUMENT if !data.is_valid() or if data is greater than the + // previously added data. + // - RESOURCE_EXHAUSTED error if unable to grow the index to allocate a new + // posting list. + libtextclassifier3::Status PrependData(const IntegerIndexData& data); + + // TODO(b/259743562): [Optimization 1] add GetAndClear, IsFull for split + + private: + explicit PostingListIntegerIndexDataAccessor( + FlashIndexStorage* storage, + std::unique_ptr<uint8_t[]> posting_list_buffer_array, + PostingListUsed posting_list_buffer, + PostingListUsedIntegerIndexDataSerializer* serializer) + : PostingListAccessor(storage, std::move(posting_list_buffer_array), + std::move(posting_list_buffer)), + serializer_(serializer) {} + + PostingListUsedIntegerIndexDataSerializer* serializer_; // Does not own. +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_POSTING_LIST_INTEGER_INDEX_DATA_ACCESSOR_H_ diff --git a/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc b/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc new file mode 100644 index 0000000..ca0804e --- /dev/null +++ b/icing/index/numeric/posting-list-integer-index-data-accessor_test.cc @@ -0,0 +1,410 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/numeric/posting-list-integer-index-data-accessor.h" + +#include <cstdint> +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/file/posting_list/flash-index-storage.h" +#include "icing/file/posting_list/posting-list-identifier.h" +#include "icing/index/numeric/integer-index-data.h" +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Lt; +using ::testing::SizeIs; + +class PostingListIntegerIndexDataAccessorTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/test_dir"; + file_name_ = test_dir_ + "/test_file.idx.index"; + + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(test_dir_.c_str())); + + serializer_ = std::make_unique<PostingListUsedIntegerIndexDataSerializer>(); + + ICING_ASSERT_OK_AND_ASSIGN( + FlashIndexStorage flash_index_storage, + FlashIndexStorage::Create(file_name_, &filesystem_, serializer_.get())); + flash_index_storage_ = + 
std::make_unique<FlashIndexStorage>(std::move(flash_index_storage)); + } + + void TearDown() override { + flash_index_storage_.reset(); + serializer_.reset(); + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + } + + Filesystem filesystem_; + std::string test_dir_; + std::string file_name_; + std::unique_ptr<PostingListUsedIntegerIndexDataSerializer> serializer_; + std::unique_ptr<FlashIndexStorage> flash_index_storage_; +}; + +std::vector<IntegerIndexData> CreateData(int num_data, + DocumentId start_document_id, + int64_t start_key) { + SectionId section_id = kMaxSectionId; + + std::vector<IntegerIndexData> data; + data.reserve(num_data); + for (int i = 0; i < num_data; ++i) { + data.push_back(IntegerIndexData(section_id, start_document_id, start_key)); + + if (section_id == kMinSectionId) { + section_id = kMaxSectionId; + } else { + --section_id; + } + ++start_document_id; + ++start_key; + } + return data; +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, DataAddAndRetrieveProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add some integer index data + std::vector<IntegerIndexData> data_vec = + CreateData(/*num_data=*/5, /*start_document_id=*/0, /*start_key=*/819); + for (const IntegerIndexData& data : data_vec) { + EXPECT_THAT(pl_accessor->PrependData(data), IsOk()); + } + PostingListAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + EXPECT_THAT(result.status, IsOk()); + EXPECT_THAT(result.id.block_index(), Eq(1)); + EXPECT_THAT(result.id.posting_list_index(), Eq(0)); + + // Retrieve some data. 
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result.id)); + EXPECT_THAT( + serializer_->GetData(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(data_vec.rbegin(), data_vec.rend()))); + EXPECT_THAT(pl_holder.block.next_block_index(), Eq(kInvalidBlockIndex)); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, PreexistingPLKeepOnSameBlock) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Add a single data. This will fit in a min-sized posting list. + IntegerIndexData data1(/*section_id=*/1, /*document_id=*/0, /*key=*/12345); + ICING_ASSERT_OK(pl_accessor->PrependData(data1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + // Should be allocated to the first block. + ASSERT_THAT(result1.id.block_index(), Eq(1)); + ASSERT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Add one more data. The minimum size for a posting list must be able to fit + // two data, so this should NOT cause the previous pl to be reallocated. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListIntegerIndexDataAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + IntegerIndexData data2(/*section_id=*/1, /*document_id=*/1, /*key=*/23456); + ICING_ASSERT_OK(pl_accessor->PrependData(data2)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result2.status); + // Should be in the same posting list. + EXPECT_THAT(result2.id, Eq(result1.id)); + + // The posting list at result2.id should hold all of the data that have been + // added. 
+ ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result2.id)); + EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list), + IsOkAndHolds(ElementsAre(data2, data1))); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + PreexistingPLReallocateToLargerPL) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Adding 3 data should cause Finalize allocating a 48-byte posting list, + // which can store at most 4 data. + std::vector<IntegerIndexData> data_vec1 = + CreateData(/*num_data=*/3, /*start_document_id=*/0, /*start_key=*/819); + for (const IntegerIndexData& data : data_vec1) { + ICING_ASSERT_OK(pl_accessor->PrependData(data)); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + // Should be allocated to the first block. + ASSERT_THAT(result1.id.block_index(), Eq(1)); + ASSERT_THAT(result1.id.posting_list_index(), Eq(0)); + + // Now add more data. + ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListIntegerIndexDataAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + // The current posting list can fit 1 more data. Adding 12 more data should + // result in these data being moved to a larger posting list. Also the total + // size of these data won't exceed max size posting list, so there will be + // only one single posting list and no chain. 
+ std::vector<IntegerIndexData> data_vec2 = CreateData( + /*num_data=*/12, + /*start_document_id=*/data_vec1.back().basic_hit().document_id() + 1, + /*start_key=*/819); + + for (const IntegerIndexData& data : data_vec2) { + ICING_ASSERT_OK(pl_accessor->PrependData(data)); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result2.status); + // Should be allocated to the second (new) block because the posting list + // should grow beyond the size that the first block maintains. + EXPECT_THAT(result2.id.block_index(), Eq(2)); + EXPECT_THAT(result2.id.posting_list_index(), Eq(0)); + + // The posting list at result2.id should hold all of the data that have been + // added. + std::vector<IntegerIndexData> all_data_vec; + all_data_vec.reserve(data_vec1.size() + data_vec2.size()); + all_data_vec.insert(all_data_vec.end(), data_vec1.begin(), data_vec1.end()); + all_data_vec.insert(all_data_vec.end(), data_vec2.begin(), data_vec2.end()); + ICING_ASSERT_OK_AND_ASSIGN(PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(result2.id)); + EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(all_data_vec.rbegin(), + all_data_vec.rend()))); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + MultiBlockChainsBlocksProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Block size is 4096, sizeof(BlockHeader) is 12 and sizeof(IntegerIndexData) + // is 12, so the max size posting list can store (4096 - 12) / 12 = 340 data. + // Adding 341 data should cause: + // - 2 max size posting lists being allocated to block 1 and block 2. 
+ // - Chaining: block 2 -> block 1 + std::vector<IntegerIndexData> data_vec = + CreateData(/*num_data=*/341, /*start_document_id=*/0, /*start_key=*/819); + for (const IntegerIndexData& data : data_vec) { + ICING_ASSERT_OK(pl_accessor->PrependData(data)); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + PostingListIdentifier second_block_id = result1.id; + // Should be allocated to the second block. + EXPECT_THAT(second_block_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // We should be able to retrieve all data. + ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_block_id)); + // This pl_holder will only hold a posting list with the data that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<IntegerIndexData> second_block_data, + serializer_->GetData(&pl_holder.posting_list)); + ASSERT_THAT(second_block_data, SizeIs(Lt(data_vec.size()))); + auto first_block_data_start = data_vec.rbegin() + second_block_data.size(); + EXPECT_THAT(second_block_data, + ElementsAreArray(data_vec.rbegin(), first_block_data_start)); + + // Now retrieve all of the data that were on the first block. 
+ uint32_t first_block_id = pl_holder.block.next_block_index(); + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT( + serializer_->GetData(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_data_start, data_vec.rend()))); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + PreexistingMultiBlockReusesBlocksProperly) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + // Block size is 4096, sizeof(BlockHeader) is 12 and sizeof(IntegerIndexData) + // is 12, so the max size posting list can store (4096 - 12) / 12 = 340 data. + // Adding 341 data will cause: + // - 2 max size posting lists being allocated to block 1 and block 2. + // - Chaining: block 2 -> block 1 + std::vector<IntegerIndexData> data_vec1 = + CreateData(/*num_data=*/341, /*start_document_id=*/0, /*start_key=*/819); + for (const IntegerIndexData& data : data_vec1) { + ICING_ASSERT_OK(pl_accessor->PrependData(data)); + } + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result1.status); + PostingListIdentifier first_add_id = result1.id; + EXPECT_THAT(first_add_id, Eq(PostingListIdentifier( + /*block_index=*/2, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0))); + + // Now add more data. These should fit on the existing second block and not + // fill it up. 
+ ICING_ASSERT_OK_AND_ASSIGN( + pl_accessor, + PostingListIntegerIndexDataAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), first_add_id)); + std::vector<IntegerIndexData> data_vec2 = CreateData( + /*num_data=*/10, + /*start_document_id=*/data_vec1.back().basic_hit().document_id() + 1, + /*start_key=*/819); + for (const IntegerIndexData& data : data_vec2) { + ICING_ASSERT_OK(pl_accessor->PrependData(data)); + } + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor).Finalize(); + ICING_ASSERT_OK(result2.status); + PostingListIdentifier second_add_id = result2.id; + EXPECT_THAT(second_add_id, Eq(first_add_id)); + + // We should be able to retrieve all data. + std::vector<IntegerIndexData> all_data_vec; + all_data_vec.reserve(data_vec1.size() + data_vec2.size()); + all_data_vec.insert(all_data_vec.end(), data_vec1.begin(), data_vec1.end()); + all_data_vec.insert(all_data_vec.end(), data_vec2.begin(), data_vec2.end()); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListHolder pl_holder, + flash_index_storage_->GetPostingList(second_add_id)); + // This pl_holder will only hold a posting list with the data that didn't fit + // on the first block. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<IntegerIndexData> second_block_data, + serializer_->GetData(&pl_holder.posting_list)); + ASSERT_THAT(second_block_data, SizeIs(Lt(all_data_vec.size()))); + auto first_block_data_start = + all_data_vec.rbegin() + second_block_data.size(); + EXPECT_THAT(second_block_data, + ElementsAreArray(all_data_vec.rbegin(), first_block_data_start)); + + // Now retrieve all of the data that were on the first block. 
+ uint32_t first_block_id = pl_holder.block.next_block_index(); + EXPECT_THAT(first_block_id, Eq(1)); + + PostingListIdentifier pl_id(first_block_id, /*posting_list_index=*/0, + /*posting_list_index_bits=*/0); + ICING_ASSERT_OK_AND_ASSIGN(pl_holder, + flash_index_storage_->GetPostingList(pl_id)); + EXPECT_THAT(serializer_->GetData(&pl_holder.posting_list), + IsOkAndHolds(ElementsAreArray(first_block_data_start, + all_data_vec.rend()))); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + InvalidDataShouldReturnInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + IntegerIndexData invalid_data; + EXPECT_THAT(pl_accessor->PrependData(invalid_data), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + BasicHitIncreasingShouldReturnInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + IntegerIndexData data1(/*section_id=*/3, /*document_id=*/1, /*key=*/12345); + ICING_ASSERT_OK(pl_accessor->PrependData(data1)); + + IntegerIndexData data2(/*section_id=*/6, /*document_id=*/1, /*key=*/12345); + EXPECT_THAT(pl_accessor->PrependData(data2), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + IntegerIndexData data3(/*section_id=*/2, /*document_id=*/0, /*key=*/12345); + EXPECT_THAT(pl_accessor->PrependData(data3), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + NewPostingListNoDataAddedShouldReturnInvalidArgument) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + 
serializer_.get())); + PostingListAccessor::FinalizeResult result = + std::move(*pl_accessor).Finalize(); + EXPECT_THAT(result.status, + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PostingListIntegerIndexDataAccessorTest, + PreexistingPostingListNoDataAddedShouldSucceed) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor1, + PostingListIntegerIndexDataAccessor::Create(flash_index_storage_.get(), + serializer_.get())); + IntegerIndexData data1(/*section_id=*/3, /*document_id=*/1, /*key=*/12345); + ICING_ASSERT_OK(pl_accessor1->PrependData(data1)); + PostingListAccessor::FinalizeResult result1 = + std::move(*pl_accessor1).Finalize(); + ICING_ASSERT_OK(result1.status); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PostingListIntegerIndexDataAccessor> pl_accessor2, + PostingListIntegerIndexDataAccessor::CreateFromExisting( + flash_index_storage_.get(), serializer_.get(), result1.id)); + PostingListAccessor::FinalizeResult result2 = + std::move(*pl_accessor2).Finalize(); + EXPECT_THAT(result2.status, IsOk()); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc b/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc new file mode 100644 index 0000000..800fd6b --- /dev/null +++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer.cc @@ -0,0 +1,514 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" + +#include <cstdint> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/numeric/integer-index-data.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/util/logging.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +uint32_t PostingListUsedIntegerIndexDataSerializer::GetBytesUsed( + const PostingListUsed* posting_list_used) const { + // The special data will be included if they represent actual data. If they + // represent the data start offset or the invalid data sentinel, they are not + // included. + return posting_list_used->size_in_bytes() - + GetStartByteOffset(posting_list_used); +} + +uint32_t PostingListUsedIntegerIndexDataSerializer::GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used) || IsAlmostFull(posting_list_used)) { + // If in either the FULL state or ALMOST_FULL state, this posting list *is* + // the minimum size posting list that can fit these data. So just return the + // size of the posting list. + return posting_list_used->size_in_bytes(); + } + + // In NOT_FULL state, BytesUsed contains no special data. The minimum sized + // posting list that would be guaranteed to fit these data would be + // ALMOST_FULL, with kInvalidData in special data 0, the uncompressed data in + // special data 1 and the n compressed data in the compressed region. + // BytesUsed contains one uncompressed data and n compressed data. Therefore, + // fitting these data into a posting list would require BytesUsed plus one + // extra data. 
+ return GetBytesUsed(posting_list_used) + GetDataTypeBytes(); +} + +void PostingListUsedIntegerIndexDataSerializer::Clear( + PostingListUsed* posting_list_used) const { + // Safe to ignore return value because posting_list_used->size_in_bytes() is + // a valid argument. + SetStartByteOffset(posting_list_used, + /*offset=*/posting_list_used->size_in_bytes()); +} + +libtextclassifier3::Status PostingListUsedIntegerIndexDataSerializer::MoveFrom( + PostingListUsed* dst, PostingListUsed* src) const { + ICING_RETURN_ERROR_IF_NULL(dst); + ICING_RETURN_ERROR_IF_NULL(src); + if (GetMinPostingListSizeToFit(src) > dst->size_in_bytes()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "src MinPostingListSizeToFit %d must be larger than size %d.", + GetMinPostingListSizeToFit(src), dst->size_in_bytes())); + } + + if (!IsPostingListValid(dst)) { + return absl_ports::FailedPreconditionError( + "Dst posting list is in an invalid state and can't be used!"); + } + if (!IsPostingListValid(src)) { + return absl_ports::InvalidArgumentError( + "Cannot MoveFrom an invalid src posting list!"); + } + + // Pop just enough data that all of src's compressed data fit in + // dst posting_list's compressed area. Then we can memcpy that area. + std::vector<IntegerIndexData> data_arr; + while (IsFull(src) || IsAlmostFull(src) || + (dst->size_in_bytes() - kSpecialDataSize < GetBytesUsed(src))) { + if (!GetDataInternal(src, /*limit=*/1, /*pop=*/true, &data_arr).ok()) { + return absl_ports::AbortedError( + "Unable to retrieve data from src posting list."); + } + } + + // memcpy the area and set up start byte offset. + Clear(dst); + memcpy(dst->posting_list_buffer() + dst->size_in_bytes() - GetBytesUsed(src), + src->posting_list_buffer() + GetStartByteOffset(src), + GetBytesUsed(src)); + // Because we popped all data from src outside of the compressed area and we + // guaranteed that GetBytesUsed(src) is less than dst->size_in_bytes() - + // kSpecialDataSize. 
This is guaranteed to be a valid byte offset for the + // NOT_FULL state, so ignoring the value is safe. + SetStartByteOffset(dst, dst->size_in_bytes() - GetBytesUsed(src)); + + // Put back remaining data. + for (auto riter = data_arr.rbegin(); riter != data_arr.rend(); ++riter) { + // PrependData may return: + // - INVALID_ARGUMENT: if data is invalid or not less than the previous data + // - RESOURCE_EXHAUSTED + // RESOURCE_EXHAUSTED should be impossible because we've already assured + // that there is enough room above. + ICING_RETURN_IF_ERROR(PrependData(dst, *riter)); + } + + Clear(src); + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status +PostingListUsedIntegerIndexDataSerializer::PrependDataToAlmostFull( + PostingListUsed* posting_list_used, const IntegerIndexData& data) const { + SpecialDataType special_data = GetSpecialData(posting_list_used, /*index=*/1); + if (special_data.data().basic_hit() < data.basic_hit()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "BasicHit %d being prepended must not be greater than the most recent" + "BasicHit %d", + data.basic_hit().value(), special_data.data().basic_hit().value())); + } + + // TODO(b/259743562): [Optimization 2] compression + // Without compression, prepend a new data into ALMOST_FULL posting list will + // change the posting list to FULL state. Therefore, set special data 0 + // directly. + SetSpecialData(posting_list_used, /*index=*/0, SpecialDataType(data)); + return libtextclassifier3::Status::OK; +} + +void PostingListUsedIntegerIndexDataSerializer::PrependDataToEmpty( + PostingListUsed* posting_list_used, const IntegerIndexData& data) const { + // First data to be added. Just add verbatim, no compression. + if (posting_list_used->size_in_bytes() == kSpecialDataSize) { + // First data will be stored at special data 1. 
+ // Safe to ignore the return value because 1 < kNumSpecialData + SetSpecialData(posting_list_used, /*index=*/1, SpecialDataType(data)); + // Safe to ignore the return value because sizeof(IntegerIndexData) is a + // valid argument. + SetStartByteOffset(posting_list_used, + /*offset=*/sizeof(IntegerIndexData)); + } else { + // Since this is the first data, size != kSpecialDataSize and + // size % sizeof(IntegerIndexData) == 0, we know that there is room to fit + // 'data' into the compressed region, so ValueOrDie is safe. + uint32_t offset = + PrependDataUncompressed(posting_list_used, data, + /*offset=*/posting_list_used->size_in_bytes()) + .ValueOrDie(); + // Safe to ignore the return value because PrependDataUncompressed is + // guaranteed to return a valid offset. + SetStartByteOffset(posting_list_used, offset); + } +} + +libtextclassifier3::Status +PostingListUsedIntegerIndexDataSerializer::PrependDataToNotFull( + PostingListUsed* posting_list_used, const IntegerIndexData& data, + uint32_t offset) const { + IntegerIndexData cur; + memcpy(&cur, posting_list_used->posting_list_buffer() + offset, + sizeof(IntegerIndexData)); + if (cur.basic_hit() < data.basic_hit()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "BasicHit %d being prepended must not be greater than the most recent" + "BasicHit %d", + data.basic_hit().value(), cur.basic_hit().value())); + } + + // TODO(b/259743562): [Optimization 2] compression + if (offset >= kSpecialDataSize + sizeof(IntegerIndexData)) { + offset = + PrependDataUncompressed(posting_list_used, data, offset).ValueOrDie(); + SetStartByteOffset(posting_list_used, offset); + } else { + // The new data must be put in special data 1. + SetSpecialData(posting_list_used, /*index=*/1, SpecialDataType(data)); + // State ALMOST_FULL. Safe to ignore the return value because + // sizeof(IntegerIndexData) is a valid argument. 
+ SetStartByteOffset(posting_list_used, /*offset=*/sizeof(IntegerIndexData)); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status +PostingListUsedIntegerIndexDataSerializer::PrependData( + PostingListUsed* posting_list_used, const IntegerIndexData& data) const { + static_assert( + sizeof(BasicHit::Value) <= sizeof(uint64_t), + "BasicHit::Value cannot be larger than 8 bytes because the delta " + "must be able to fit in 8 bytes."); + + if (!data.is_valid()) { + return absl_ports::InvalidArgumentError("Cannot prepend an invalid data!"); + } + if (!IsPostingListValid(posting_list_used)) { + return absl_ports::FailedPreconditionError( + "This PostingListUsed is in an invalid state and can't add any data!"); + } + + if (IsFull(posting_list_used)) { + // State FULL: no space left. + return absl_ports::ResourceExhaustedError("No more room for data"); + } else if (IsAlmostFull(posting_list_used)) { + return PrependDataToAlmostFull(posting_list_used, data); + } else if (IsEmpty(posting_list_used)) { + PrependDataToEmpty(posting_list_used, data); + return libtextclassifier3::Status::OK; + } else { + uint32_t offset = GetStartByteOffset(posting_list_used); + return PrependDataToNotFull(posting_list_used, data, offset); + } +} + +uint32_t PostingListUsedIntegerIndexDataSerializer::PrependDataArray( + PostingListUsed* posting_list_used, const IntegerIndexData* array, + uint32_t num_data, bool keep_prepended) const { + if (!IsPostingListValid(posting_list_used)) { + return 0; + } + + uint32_t i; + for (i = 0; i < num_data; ++i) { + if (!PrependData(posting_list_used, array[i]).ok()) { + break; + } + } + if (i != num_data && !keep_prepended) { + // Didn't fit. Undo everything and check that we have the same offset as + // before. PopFrontData guarantees that it will remove all 'i' data so long + // as there are at least 'i' data in the posting list, which we know there + // are. 
+ PopFrontData(posting_list_used, /*num_data=*/i); + return 0; + } + return i; +} + +libtextclassifier3::StatusOr<std::vector<IntegerIndexData>> +PostingListUsedIntegerIndexDataSerializer::GetData( + const PostingListUsed* posting_list_used) const { + std::vector<IntegerIndexData> data_arr_out; + ICING_RETURN_IF_ERROR(GetData(posting_list_used, &data_arr_out)); + return data_arr_out; +} + +libtextclassifier3::Status PostingListUsedIntegerIndexDataSerializer::GetData( + const PostingListUsed* posting_list_used, + std::vector<IntegerIndexData>* data_arr_out) const { + return GetDataInternal(posting_list_used, + /*limit=*/std::numeric_limits<uint32_t>::max(), + /*pop=*/false, data_arr_out); +} + +libtextclassifier3::Status +PostingListUsedIntegerIndexDataSerializer::PopFrontData( + PostingListUsed* posting_list_used, uint32_t num_data) const { + if (num_data == 1 && IsFull(posting_list_used)) { + // The PL is in FULL state which means that we save 2 uncompressed data in + // the 2 special positions. But FULL state may be reached by 2 different + // states. 
+ // (1) In ALMOST_FULL state + // +------------------+-----------------+-----+---------------------------+ + // |Data::Invalid |1st data |(pad)|(compressed) data | + // | | | | | + // +------------------+-----------------+-----+---------------------------+ + // When we prepend another data, we can only put it at special data 0, and + // thus get a FULL PL + // +------------------+-----------------+-----+---------------------------+ + // |new 1st data |original 1st data|(pad)|(compressed) data | + // | | | | | + // +------------------+-----------------+-----+---------------------------+ + // + // (2) In NOT_FULL state + // +------------------+-----------------+-------+---------+---------------+ + // |data-start-offset |Data::Invalid |(pad) |1st data |(compressed) | + // | | | | |data | + // +------------------+-----------------+-------+---------+---------------+ + // When we prepend another data, we can reach any of the 3 following + // scenarios: + // (2.1) NOT_FULL + // if the space of pad and original 1st data can accommodate the new 1st + // data and the encoded delta value. + // +------------------+-----------------+-----+--------+------------------+ + // |data-start-offset |Data::Invalid |(pad)|new |(compressed) data | + // | | | |1st data| | + // +------------------+-----------------+-----+--------+------------------+ + // (2.2) ALMOST_FULL + // If the space of pad and original 1st data cannot accommodate the new 1st + // data and the encoded delta value but can accommodate the encoded delta + // value only. We can put the new 1st data at special position 1. + // +------------------+-----------------+---------+-----------------------+ + // |Data::Invalid |new 1st data |(pad) |(compressed) data | + // | | | | | + // +------------------+-----------------+---------+-----------------------+ + // (2.3) FULL + // In very rare case, it cannot even accommodate only the encoded delta + // value. 
we can move the original 1st data into special position 1 and the + // new 1st data into special position 0. This may happen because we use + // VarInt encoding method which may make the encoded value longer (about + // 4/3 times of original) + // +------------------+-----------------+--------------+------------------+ + // |new 1st data |original 1st data|(pad) |(compressed) data | + // | | | | | + // +------------------+-----------------+--------------+------------------+ + // + // Suppose now the PL is in FULL state. But we don't know whether it arrived + // this state from NOT_FULL (like (2.3)) or from ALMOST_FULL (like (1)). + // We'll return to ALMOST_FULL state like (1) if we simply pop the new 1st + // data, but we want to make the prepending operation "reversible". So + // there should be some way to return to NOT_FULL if possible. A simple way + // to do is: + // - Pop 2 data out of the PL to state ALMOST_FULL or NOT_FULL. + // - Add the second data ("original 1st data") back. + // + // Then we can return to the correct original states of (2.1) or (1). This + // makes our prepending operation reversible. + std::vector<IntegerIndexData> out; + + // Popping 2 data should never fail because we've just ensured that the + // posting list is in the FULL state. + ICING_RETURN_IF_ERROR( + GetDataInternal(posting_list_used, /*limit=*/2, /*pop=*/true, &out)); + + // PrependData should never fail because: + // - out[1] is a valid data less than all previous data in the posting list. + // - There's no way that the posting list could run out of room because it + // previously stored these 2 data. 
+ PrependData(posting_list_used, out[1]); + } else if (num_data > 0) { + return GetDataInternal(posting_list_used, /*limit=*/num_data, /*pop=*/true, + /*out=*/nullptr); + } + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status +PostingListUsedIntegerIndexDataSerializer::GetDataInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<IntegerIndexData>* out) const { + // TODO(b/259743562): [Optimization 2] handle compressed data + + uint32_t offset = GetStartByteOffset(posting_list_used); + uint32_t count = 0; + + // First traverse the first two special positions. + while (count < limit && offset < kSpecialDataSize) { + // offset / sizeof(IntegerIndexData) < kNumSpecialData because of the check + // above. + SpecialDataType special_data = + GetSpecialData(posting_list_used, + /*index=*/offset / sizeof(IntegerIndexData)); + if (out != nullptr) { + out->push_back(special_data.data()); + } + offset += sizeof(IntegerIndexData); + ++count; + } + + // - We don't compress the data now. + // - The posting list size is a multiple of data type bytes. + // So offset of the first non-special data is guaranteed to be at + // kSpecialDataSize if in ALMOST_FULL or FULL state. In fact, we must not + // apply padding skipping logic here when still storing uncompressed data, + // because in this case 0 bytes are meaningful (e.g. inverted doc id byte = 0). + // TODO(b/259743562): [Optimization 2] deal with padding skipping logic when + // applying data compression. 
+ + while (count < limit && offset < posting_list_used->size_in_bytes()) { + IntegerIndexData data; + memcpy(&data, posting_list_used->posting_list_buffer() + offset, + sizeof(IntegerIndexData)); + offset += sizeof(IntegerIndexData); + if (out != nullptr) { + out->push_back(data); + } + ++count; + } + + if (pop) { + PostingListUsed* mutable_posting_list_used = + const_cast<PostingListUsed*>(posting_list_used); + // Modify the posting list so that we pop all data actually traversed. + if (offset >= kSpecialDataSize && + offset < posting_list_used->size_in_bytes()) { + memset( + mutable_posting_list_used->posting_list_buffer() + kSpecialDataSize, + 0, offset - kSpecialDataSize); + } + SetStartByteOffset(mutable_posting_list_used, offset); + } + + return libtextclassifier3::Status::OK; +} + +PostingListUsedIntegerIndexDataSerializer::SpecialDataType +PostingListUsedIntegerIndexDataSerializer::GetSpecialData( + const PostingListUsed* posting_list_used, uint32_t index) const { + // It is ok to temporarily construct a SpecialData with offset = 0 since we're + // going to overwrite it by memcpy. + SpecialDataType special_data(0); + memcpy(&special_data, + posting_list_used->posting_list_buffer() + + index * sizeof(SpecialDataType), + sizeof(SpecialDataType)); + return special_data; +} + +void PostingListUsedIntegerIndexDataSerializer::SetSpecialData( + PostingListUsed* posting_list_used, uint32_t index, + const SpecialDataType& special_data) const { + memcpy(posting_list_used->posting_list_buffer() + + index * sizeof(SpecialDataType), + &special_data, sizeof(SpecialDataType)); +} + +bool PostingListUsedIntegerIndexDataSerializer::IsPostingListValid( + const PostingListUsed* posting_list_used) const { + if (IsAlmostFull(posting_list_used)) { + // Special data 1 should hold a valid data. 
+ if (!GetSpecialData(posting_list_used, /*index=*/1).data().is_valid()) { + ICING_LOG(ERROR) + << "Both special data cannot be invalid at the same time."; + return false; + } + } else if (!IsFull(posting_list_used)) { + // NOT_FULL. Special data 0 should hold a valid offset. + SpecialDataType special_data = + GetSpecialData(posting_list_used, /*index=*/0); + if (special_data.data_start_offset() > posting_list_used->size_in_bytes() || + special_data.data_start_offset() < kSpecialDataSize) { + ICING_LOG(ERROR) << "Offset: " << special_data.data_start_offset() + << " size: " << posting_list_used->size_in_bytes() + << " sp size: " << kSpecialDataSize; + return false; + } + } + return true; +} + +uint32_t PostingListUsedIntegerIndexDataSerializer::GetStartByteOffset( + const PostingListUsed* posting_list_used) const { + if (IsFull(posting_list_used)) { + return 0; + } else if (IsAlmostFull(posting_list_used)) { + return sizeof(IntegerIndexData); + } else { + return GetSpecialData(posting_list_used, /*index=*/0).data_start_offset(); + } +} + +bool PostingListUsedIntegerIndexDataSerializer::SetStartByteOffset( + PostingListUsed* posting_list_used, uint32_t offset) const { + if (offset > posting_list_used->size_in_bytes()) { + ICING_LOG(ERROR) << "offset cannot be a value greater than size " + << posting_list_used->size_in_bytes() << ". offset is " + << offset << "."; + return false; + } + if (offset < kSpecialDataSize && offset > sizeof(IntegerIndexData)) { + ICING_LOG(ERROR) << "offset cannot be a value between (" + << sizeof(IntegerIndexData) << ", " << kSpecialDataSize + << "). offset is " << offset << "."; + return false; + } + if (offset < sizeof(IntegerIndexData) && offset != 0) { + ICING_LOG(ERROR) << "offset cannot be a value between (0, " + << sizeof(IntegerIndexData) << "). offset is " << offset + << "."; + return false; + } + + if (offset >= kSpecialDataSize) { + // NOT_FULL state. 
+ SetSpecialData(posting_list_used, /*index=*/0, SpecialDataType(offset)); + SetSpecialData(posting_list_used, /*index=*/1, + SpecialDataType(IntegerIndexData())); + } else if (offset == sizeof(IntegerIndexData)) { + // ALMOST_FULL state. + SetSpecialData(posting_list_used, /*index=*/0, + SpecialDataType(IntegerIndexData())); + } + // Nothing to do for the FULL state - the offset isn't actually stored + // anywhere and both 2 special data hold valid data. + return true; +} + +libtextclassifier3::StatusOr<uint32_t> +PostingListUsedIntegerIndexDataSerializer::PrependDataUncompressed( + PostingListUsed* posting_list_used, const IntegerIndexData& data, + uint32_t offset) const { + if (offset < kSpecialDataSize + sizeof(IntegerIndexData)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Not enough room to prepend IntegerIndexData at offset %d.", offset)); + } + offset -= sizeof(IntegerIndexData); + memcpy(posting_list_used->posting_list_buffer() + offset, &data, + sizeof(IntegerIndexData)); + return offset; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer.h b/icing/index/numeric/posting-list-used-integer-index-data-serializer.h new file mode 100644 index 0000000..49007e3 --- /dev/null +++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer.h @@ -0,0 +1,338 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_ +#define ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_ + +#include <cstdint> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/posting_list/posting-list-common.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/numeric/integer-index-data.h" + +namespace icing { +namespace lib { + +// A serializer class to serialize IntegerIndexData to PostingListUsed. +class PostingListUsedIntegerIndexDataSerializer + : public PostingListUsedSerializer { + public: + using SpecialDataType = SpecialData<IntegerIndexData>; + static_assert(sizeof(SpecialDataType) == sizeof(IntegerIndexData), ""); + + static constexpr uint32_t kSpecialDataSize = + kNumSpecialData * sizeof(SpecialDataType); + + uint32_t GetDataTypeBytes() const override { + return sizeof(IntegerIndexData); + } + + uint32_t GetMinPostingListSize() const override { + static constexpr uint32_t kMinPostingListSize = kSpecialDataSize; + static_assert(sizeof(PostingListIndex) <= kMinPostingListSize, + "PostingListIndex must be small enough to fit in a " + "minimum-sized Posting List."); + + return kMinPostingListSize; + } + + uint32_t GetMinPostingListSizeToFit( + const PostingListUsed* posting_list_used) const override; + + uint32_t GetBytesUsed( + const PostingListUsed* posting_list_used) const override; + + void Clear(PostingListUsed* posting_list_used) const override; + + libtextclassifier3::Status MoveFrom(PostingListUsed* dst, + PostingListUsed* src) const override; + + // Prepend an IntegerIndexData to the posting list. + // + // RETURNS: + // - INVALID_ARGUMENT if !data.is_valid() or if data is not less than the + // previously added data. 
+ // - RESOURCE_EXHAUSTED if there is no more room to add data to the posting + // list. + libtextclassifier3::Status PrependData(PostingListUsed* posting_list_used, + const IntegerIndexData& data) const; + + // Prepend multiple IntegerIndexData to the posting list. Data should be + // sorted in ascending order (as defined by the less than operator for + // IntegerIndexData) + // If keep_prepended is true, whatever could be prepended is kept, otherwise + // the posting list is reverted and left in its original state. + // + // RETURNS: + // The number of data that have been prepended to the posting list. If + // keep_prepended is false and reverted, then it returns 0. + uint32_t PrependDataArray(PostingListUsed* posting_list_used, + const IntegerIndexData* array, uint32_t num_data, + bool keep_prepended) const; + + // Retrieves all data stored in the posting list. + // + // RETURNS: + // - On success, a vector of IntegerIndexData sorted by the reverse order of + // prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::StatusOr<std::vector<IntegerIndexData>> GetData( + const PostingListUsed* posting_list_used) const; + + // Same as GetData but appends data to data_arr_out. + // + // RETURNS: + // - OK on success, and data_arr_out will be appended IntegerIndexData + // sorted by the reverse order of prepending. + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::Status GetData( + const PostingListUsed* posting_list_used, + std::vector<IntegerIndexData>* data_arr_out) const; + + // Undo the last num_data data prepended. If num_data > number of data, then + // we clear all data. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. 
+ libtextclassifier3::Status PopFrontData(PostingListUsed* posting_list_used, + uint32_t num_data) const; + + private: + // Posting list layout formats: + // + // NOT_FULL + // +-special-data-0--+-special-data-1--+------------+-----------------------+ + // | | | | | + // |data-start-offset| Data::Invalid | 0x00000000 | (compressed) data | + // | | | | | + // +-----------------+-----------------+------------+-----------------------+ + // + // ALMOST_FULL + // +-special-data-0--+-special-data-1--+-----+------------------------------+ + // | | | | | + // | Data::Invalid | 1st data |(pad)| (compressed) data | + // | | | | | + // +-----------------+-----------------+-----+------------------------------+ + // + // FULL + // +-special-data-0--+-special-data-1--+-----+------------------------------+ + // | | | | | + // | 1st data | 2nd data |(pad)| (compressed) data | + // | | | | | + // +-----------------+-----------------+-----+------------------------------+ + // + // The first two uncompressed (special) data also implicitly encode + // information about the size of the compressed data region. + // + // 1. If the posting list is NOT_FULL, then special_data_0 contains the byte + // offset of the start of the compressed data. Thus, the size of the + // compressed data is + // posting_list_used->size_in_bytes() - special_data_0.data_start_offset(). + // + // 2. If posting list is ALMOST_FULL or FULL, then the compressed data region + // starts somewhere between + // [kSpecialDataSize, kSpecialDataSize + sizeof(IntegerIndexData) - 1] and + // ends at posting_list_used->size_in_bytes() - 1. + // + // EXAMPLE + // Posting list storage. Posting list size: 36 bytes + // + // EMPTY! 
+ // +--- byte 0-11 ---+----- 12-23 -----+-------------- 24-35 ---------------+ + // | | | | + // | 36 | Data::Invalid | 0x00000000 | + // | | | | + // +-----------------+-----------------+------------------------------------+ + // + // Add IntegerIndexData(0x0FFFFCC3, 5) + // (DocumentId = 12, SectionId = 3; Key = 5) + // (VarInt64(5) is encoded as 10 (b'1010), requires 1 byte) + // NOT FULL! + // +--- byte 0-11 ---+----- 12-23 -----+------- 24-30 -------+--- 31-35 ----+ + // | | | | 0x0FFFFCC3 | + // | 31 | Data::Invalid | 0x00000000 | VI64(5) | + // | | | | | + // +-----------------+-----------------+---------------------+--------------+ + // + // Add IntegerIndexData(0x0FFFFB40, -2) + // (DocumentId = 18, SectionId = 0; Key = -2) + // (VarInt64(-2) is encoded as 3 (b'11), requires 1 byte) + // Previous IntegerIndexData BasicHit delta varint encoding: + // 0x0FFFFCC3 - 0x0FFFFB40 = 387, VarUnsignedInt(387) requires 2 bytes + // +--- byte 0-11 ---+----- 12-23 -----+-- 24-27 ---+--- 28-32 ----+ 33-35 -+ + // | | | | 0x0FFFFB40 |VUI(387)| + // | 28 | Data::Invalid | 0x00 | VI64(-2) |VI64(5) | + // | | | | | | + // +-----------------+-----------------+------------+--------------+--------+ + // + // Add IntegerIndexData(0x0FFFFA4A, 3) + // (DocumentId = 22, SectionId = 10; Key = 3) + // (VarInt64(3) is encoded as 6 (b'110), requires 1 byte) + // Previous IntegerIndexData BasicHit delta varint encoding: + // 0x0FFFFB40 - 0x0FFFFA4A = 246, VarUnsignedInt(246) requires 2 bytes + // +--- byte 0-11 ---+----- 12-23 -----+---+--- 25-29 ----+ 30-32 -+ 33-35 -+ + // | | | | 0x0FFFFA4A |VUI(246)|VUI(387)| + // | 25 | Data::Invalid | | VI64(3) |VI64(-2)|VI64(5) | + // | | | | | | | + // +-----------------+-----------------+---+--------------+--------+--------+ + // + // Add IntegerIndexData(0x0FFFFA01, -4) + // (DocumentId = 23, SectionId = 1; Key = -4) + // (No VarInt64 for key, since it is stored in special data section) + // Previous IntegerIndexData BasicHit delta varint 
encoding: + // 0x0FFFFA4A - 0x0FFFFA01 = 73, VarUnsignedInt(73) requires 1 byte) + // ALMOST_FULL! + // +--- byte 0-11 ---+----- 12-23 -----+-- 24-27 ---+28-29+ 30-32 -+ 33-35 -+ + // | | 0x0FFFFA01 | |(73) |VUI(246)|VUI(387)| + // | Data::Invalid | 0xFFFFFFFF | (pad) |(3) |VI64(-2)|VI64(5) | + // | | 0xFFFFFFFC | | | | | + // +-----------------+-----------------+------------+-----+--------+--------+ + // + // Add IntegerIndexData(0x0FFFF904, 0) + // (DocumentId = 27, SectionId = 4; Key = 0) + // (No VarInt64 for key, since it is stored in special data section) + // Previous IntegerIndexData: + // Since 0x0FFFFA01 - 0x0FFFF904 = 253 and VarInt64(-4) is encoded as 7 + // (b'111), it requires only 3 bytes after compression. It's able to fit + // into the padding section. + // Still ALMOST_FULL! + // +--- byte 0-11 ---+----- 12-23 -----+---+ 25-27 -+28-29+ 30-32 -+ 33-35 -+ + // | | 0x0FFFF904 | |VUI(253)|(73) |VUI(246)|VUI(387)| + // | Data::Invalid | 0x00000000 | |VI64(-4)|(3) |VI64(-2)|VI64(5) | + // | | 0x00000000 | | | | | | + // +-----------------+-----------------+---+--------+-----+--------+--------+ + // + // Add IntegerIndexData(0x0FFFF8C3, -1) + // (DocumentId = 28, SectionId = 3; Key = -1) + // (No VarInt64 for key, since it is stored in special data section) + // (No VarUnsignedInt for previous IntegerIndexData BasicHit) + // FULL! + // +--- byte 0-11 ---+----- 12-23 -----+---+ 25-27 -+28-29+ 30-32 -+ 33-35 -+ + // | 0x0FFFF8C3 | 0x0FFFF904 | |VUI(253)|(73) |VUI(246)|VUI(387)| + // | 0xFFFFFFFF | 0x00000000 | |VI64(-4)|(3) |VI64(-2)|VI64(5) | + // | 0xFFFFFFFF | 0x00000000 | | | | | | + // +-----------------+-----------------+---+--------+-----+--------+--------+ + + // Helpers to determine what state the posting list is in. 
+ bool IsFull(const PostingListUsed* posting_list_used) const { + return GetSpecialData(posting_list_used, /*index=*/0).data().is_valid() && + GetSpecialData(posting_list_used, /*index=*/1).data().is_valid(); + } + + bool IsAlmostFull(const PostingListUsed* posting_list_used) const { + return !GetSpecialData(posting_list_used, /*index=*/0).data().is_valid() && + GetSpecialData(posting_list_used, /*index=*/1).data().is_valid(); + } + + bool IsEmpty(const PostingListUsed* posting_list_used) const { + return GetSpecialData(posting_list_used, /*index=*/0).data_start_offset() == + posting_list_used->size_in_bytes() && + !GetSpecialData(posting_list_used, /*index=*/1).data().is_valid(); + } + + // Returns false if both special data are invalid or if data start offset + // stored in the special data is less than kSpecialDataSize or greater than + // posting_list_used->size_in_bytes(). Returns true, otherwise. + bool IsPostingListValid(const PostingListUsed* posting_list_used) const; + + // Prepend data to a posting list that is in the ALMOST_FULL state. + // + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if data is not less than the previously added data. + libtextclassifier3::Status PrependDataToAlmostFull( + PostingListUsed* posting_list_used, const IntegerIndexData& data) const; + + // Prepend data to a posting list that is in the EMPTY state. This will always + // succeed because there are no pre-existing data and no validly constructed + // posting list could fail to fit one data. + void PrependDataToEmpty(PostingListUsed* posting_list_used, + const IntegerIndexData& data) const; + + // Prepend data to a posting list that is in the NOT_FULL state. + // + // RETURNS: + // - OK, if successful + // - INVALID_ARGUMENT if data is not less than the previously added data. 
+ libtextclassifier3::Status PrependDataToNotFull( + PostingListUsed* posting_list_used, const IntegerIndexData& data, + uint32_t offset) const; + + // Returns either 0 (FULL state), sizeof(IntegerIndexData) (ALMOST_FULL state) + // or a byte offset between kSpecialDataSize and + // posting_list_used->size_in_bytes() (inclusive) (NOT_FULL state). + uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const; + + // Sets special data 0 to properly reflect what start byte offset is (see + // layout comment for further details). + // + // Returns false if offset > posting_list_used->size_in_bytes() or offset is + // in range (kSpecialDataSize, sizeof(IntegerIndexData)) or + // (sizeof(IntegerIndexData), 0). True, otherwise. + bool SetStartByteOffset(PostingListUsed* posting_list_used, + uint32_t offset) const; + + // Helper for MoveFrom/GetData/PopFrontData. Adds limit number of data to out + // or all data in the posting list if the posting list contains less than + // limit number of data. out can be NULL. + // + // NOTE: If called with limit=1, pop=true on a posting list that transitioned + // from NOT_FULL directly to FULL, GetDataInternal will not return the posting + // list to NOT_FULL. Instead it will leave it in a valid state, but it will be + // ALMOST_FULL. + // + // RETURNS: + // - OK on success + // - INTERNAL_ERROR if the posting list has been corrupted somehow. + libtextclassifier3::Status GetDataInternal( + const PostingListUsed* posting_list_used, uint32_t limit, bool pop, + std::vector<IntegerIndexData>* out) const; + + // Retrieves the value stored in the index-th special data. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. + // + // RETURNS: + // - A valid SpecialData<IntegerIndexData>. + SpecialDataType GetSpecialData(const PostingListUsed* posting_list_used, + uint32_t index) const; + + // Sets the value stored in the index-th special data to special_data. + // + // REQUIRES: + // 0 <= index < kNumSpecialData. 
+ void SetSpecialData(PostingListUsed* posting_list_used, uint32_t index, + const SpecialDataType& special_data) const; + + // Prepends data to the memory region [offset - sizeof(IntegerIndexData), + // offset - 1] and returns the new beginning of the region. + // + // RETURNS: + // - The new beginning of the padded region, if successful. + // - INVALID_ARGUMENT if data will not fit (uncompressed) between + // [kSpecialDataSize, offset - 1] + libtextclassifier3::StatusOr<uint32_t> PrependDataUncompressed( + PostingListUsed* posting_list_used, const IntegerIndexData& data, + uint32_t offset) const; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_NUMERIC_POSTING_LIST_USED_INTEGER_INDEX_DATA_SERIALIZER_H_ diff --git a/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc b/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc new file mode 100644 index 0000000..c270137 --- /dev/null +++ b/icing/index/numeric/posting-list-used-integer-index-data-serializer_test.cc @@ -0,0 +1,523 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/index/numeric/posting-list-used-integer-index-data-serializer.h" + +#include <memory> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/posting_list/posting-list-used.h" +#include "icing/index/numeric/integer-index-data.h" +#include "icing/testing/common-matchers.h" + +using testing::ElementsAre; +using testing::ElementsAreArray; +using testing::Eq; +using testing::IsEmpty; + +namespace icing { +namespace lib { + +namespace { + +// TODO(b/259743562): [Optimization 2] update unit tests after applying +// compression. Remember to create varint/delta encoding +// overflow (which causes state NOT_FULL -> FULL directly +// without ALMOST_FULL) test cases, including for +// PopFrontData. + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + GetMinPostingListSizeToFitNotNull) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 2551 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/0, /*key=*/2)), + IsOk()); + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), + Eq(2 * sizeof(IntegerIndexData))); + + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/1, /*key=*/5)), + IsOk()); + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), + Eq(3 * sizeof(IntegerIndexData))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + GetMinPostingListSizeToFitAlmostFull) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 3 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + 
PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/0, /*key=*/2)), + IsOk()); + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/1, /*key=*/5)), + IsOk()); + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), Eq(size)); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + GetMinPostingListSizeToFitFull) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 3 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/0, /*key=*/2)), + IsOk()); + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/1, /*key=*/5)), + IsOk()); + ASSERT_THAT(serializer.PrependData( + &pl_used, IntegerIndexData(/*section_id=*/0, + /*document_id=*/2, /*key=*/0)), + IsOk()); + EXPECT_THAT(serializer.GetMinPostingListSizeToFit(&pl_used), Eq(size)); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, PrependDataNotFull) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 2551 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + // Make used. 
+ IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2); + EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk()); + // Size = sizeof(uncompressed data0) + int expected_size = sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(ElementsAre(data0))); + + IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5); + EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk()); + // Size = sizeof(uncompressed data1) + // + sizeof(uncompressed data0) + expected_size += sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data1, data0))); + + IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0); + EXPECT_THAT(serializer.PrependData(&pl_used, data2), IsOk()); + // Size = sizeof(uncompressed data2) + // + sizeof(uncompressed data1) + // + sizeof(uncompressed data0) + expected_size += sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data2, data1, data0))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, PrependDataAlmostFull) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 4 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + // Fill up the compressed region. 
+ // Transitions: + // Adding data0: EMPTY -> NOT_FULL + // Adding data1: NOT_FULL -> NOT_FULL + IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2); + IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5); + EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk()); + EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk()); + int expected_size = 2 * sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data1, data0))); + + // Add one more data to transition NOT_FULL -> ALMOST_FULL + IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0); + EXPECT_THAT(serializer.PrependData(&pl_used, data2), IsOk()); + expected_size = 3 * sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data2, data1, data0))); + + // Add one more data to transition ALMOST_FULL -> FULL + IntegerIndexData data3(/*section_id=*/0, /*document_id=*/3, /*key=*/-3); + EXPECT_THAT(serializer.PrependData(&pl_used, data3), IsOk()); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data3, data2, data1, data0))); + + // The posting list is FULL. Adding another data should fail. 
+ IntegerIndexData data4(/*section_id=*/0, /*document_id=*/4, /*key=*/100); + EXPECT_THAT(serializer.PrependData(&pl_used, data4), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + PrependDataPostingListUsedMinSize) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + // PL State: EMPTY + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(0)); + EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(IsEmpty())); + + // Add a data. PL should shift to ALMOST_FULL state + IntegerIndexData data0(/*section_id=*/0, /*document_id=*/0, /*key=*/2); + EXPECT_THAT(serializer.PrependData(&pl_used, data0), IsOk()); + // Size = sizeof(uncompressed data0) + int expected_size = sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), IsOkAndHolds(ElementsAre(data0))); + + // Add another data. PL should shift to FULL state. + IntegerIndexData data1(/*section_id=*/0, /*document_id=*/1, /*key=*/5); + EXPECT_THAT(serializer.PrependData(&pl_used, data1), IsOk()); + // Size = sizeof(uncompressed data1) + sizeof(uncompressed data0) + expected_size += sizeof(IntegerIndexData); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), Eq(expected_size)); + EXPECT_THAT(serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAre(data1, data0))); + + // The posting list is FULL. Adding another data should fail. 
+ IntegerIndexData data2(/*section_id=*/0, /*document_id=*/2, /*key=*/0); + EXPECT_THAT(serializer.PrependData(&pl_used, data2), + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + PrependDataArrayDoNotKeepPrepended) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 6 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + std::vector<IntegerIndexData> data_in; + std::vector<IntegerIndexData> data_pushed; + + // Add 3 data. The PL is in the empty state and should be able to fit all 3 + // data without issue, transitioning the PL from EMPTY -> NOT_FULL. + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0)); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/false), + Eq(data_in.size())); + std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); + + // Add 2 data. The PL should transition from NOT_FULL to ALMOST_FULL. 
+ data_in.clear(); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100)); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/false), + Eq(data_in.size())); + std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); + + // Add 2 data. The PL should remain ALMOST_FULL since the remaining space can + // only fit 1 data. + data_in.clear(); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)); + data_in.push_back(IntegerIndexData(/*section_id=*/0, /*document_id=*/6, + /*key=*/2147483647)); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/false), + Eq(0)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); + + // Add 1 data. The PL should transition from ALMOST_FULL to FULL. 
+ data_in.resize(1); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/false), + Eq(data_in.size())); + std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + PrependDataArrayKeepPrepended) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 6 * sizeof(IntegerIndexData); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + + std::vector<IntegerIndexData> data_in; + std::vector<IntegerIndexData> data_pushed; + + // Add 3 data. The PL is in the empty state and should be able to fit all 3 + // data without issue, transitioning the PL from EMPTY -> NOT_FULL. + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0)); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/true), + Eq(data_in.size())); + std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); + + // Add 4 data. The PL should prepend 3 data and transition from NOT_FULL to + // FULL. 
+ data_in.clear(); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100)); + data_in.push_back( + IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)); + data_in.push_back(IntegerIndexData(/*section_id=*/0, /*document_id=*/6, + /*key=*/2147483647)); + EXPECT_THAT( + serializer.PrependDataArray(&pl_used, data_in.data(), data_in.size(), + /*keep_prepended=*/true), + Eq(3)); + data_in.resize(3); + std::move(data_in.begin(), data_in.end(), std::back_inserter(data_pushed)); + EXPECT_THAT(serializer.GetBytesUsed(&pl_used), + Eq(data_pushed.size() * sizeof(IntegerIndexData))); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_pushed.rbegin(), data_pushed.rend()))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, MoveFrom) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 3 * serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf1 = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf1.get()), size)); + + std::vector<IntegerIndexData> data_arr1 = { + IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2), + IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)}; + ASSERT_THAT( + serializer.PrependDataArray(&pl_used1, data_arr1.data(), data_arr1.size(), + /*keep_prepended=*/false), + Eq(data_arr1.size())); + + std::unique_ptr<char[]> buf2 = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf2.get()), size)); + std::vector<IntegerIndexData> data_arr2 = { + IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0), + IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3), + 
IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100), + IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)}; + ASSERT_THAT( + serializer.PrependDataArray(&pl_used2, data_arr2.data(), data_arr2.size(), + /*keep_prepended=*/false), + Eq(data_arr2.size())); + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + IsOk()); + EXPECT_THAT( + serializer.GetData(&pl_used2), + IsOkAndHolds(ElementsAreArray(data_arr1.rbegin(), data_arr1.rend()))); + EXPECT_THAT(serializer.GetData(&pl_used1), IsOkAndHolds(IsEmpty())); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, + MoveToNullReturnsFailedPrecondition) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 3 * serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf.get()), size)); + std::vector<IntegerIndexData> data_arr = { + IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2), + IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5)}; + ASSERT_THAT( + serializer.PrependDataArray(&pl_used, data_arr.data(), data_arr.size(), + /*keep_prepended=*/false), + Eq(data_arr.size())); + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used, /*src=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend()))); + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/nullptr, /*src=*/&pl_used), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + serializer.GetData(&pl_used), + IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend()))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, MoveToPostingListTooSmall) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size1 = 3 * 
serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf1 = std::make_unique<char[]>(size1); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used1, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf1.get()), size1)); + std::vector<IntegerIndexData> data_arr1 = { + IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2), + IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5), + IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0), + IntegerIndexData(/*section_id=*/0, /*document_id=*/3, /*key=*/-3), + IntegerIndexData(/*section_id=*/0, /*document_id=*/4, /*key=*/100)}; + ASSERT_THAT( + serializer.PrependDataArray(&pl_used1, data_arr1.data(), data_arr1.size(), + /*keep_prepended=*/false), + Eq(data_arr1.size())); + + int size2 = serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf2 = std::make_unique<char[]>(size2); + ICING_ASSERT_OK_AND_ASSIGN( + PostingListUsed pl_used2, + PostingListUsed::CreateFromUnitializedRegion( + &serializer, static_cast<void*>(buf2.get()), size2)); + std::vector<IntegerIndexData> data_arr2 = { + IntegerIndexData(/*section_id=*/0, /*document_id=*/5, /*key=*/-200)}; + ASSERT_THAT( + serializer.PrependDataArray(&pl_used2, data_arr2.data(), data_arr2.size(), + /*keep_prepended=*/false), + Eq(data_arr2.size())); + + EXPECT_THAT(serializer.MoveFrom(/*dst=*/&pl_used2, /*src=*/&pl_used1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT( + serializer.GetData(&pl_used1), + IsOkAndHolds(ElementsAreArray(data_arr1.rbegin(), data_arr1.rend()))); + EXPECT_THAT( + serializer.GetData(&pl_used2), + IsOkAndHolds(ElementsAreArray(data_arr2.rbegin(), data_arr2.rend()))); +} + +TEST(PostingListUsedIntegerIndexDataSerializerTest, PopFrontData) { + PostingListUsedIntegerIndexDataSerializer serializer; + + int size = 2 * serializer.GetMinPostingListSize(); + std::unique_ptr<char[]> buf = std::make_unique<char[]>(size); + 
ICING_ASSERT_OK_AND_ASSIGN(
+      PostingListUsed pl_used,
+      PostingListUsed::CreateFromUnitializedRegion(
+          &serializer, static_cast<void*>(buf.get()), size));
+
+  std::vector<IntegerIndexData> data_arr = {
+      IntegerIndexData(/*section_id=*/0, /*document_id=*/0, /*key=*/2),
+      IntegerIndexData(/*section_id=*/0, /*document_id=*/1, /*key=*/5),
+      IntegerIndexData(/*section_id=*/0, /*document_id=*/2, /*key=*/0)};
+  ASSERT_THAT(
+      serializer.PrependDataArray(&pl_used, data_arr.data(), data_arr.size(),
+                                  /*keep_prepended=*/false),
+      Eq(data_arr.size()));
+  ASSERT_THAT(
+      serializer.GetData(&pl_used),
+      IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+
+  // Now, pop the last data. The posting list should contain the first two
+  // data.
+  EXPECT_THAT(serializer.PopFrontData(&pl_used, /*num_data=*/1), IsOk());
+  data_arr.pop_back();
+  EXPECT_THAT(
+      serializer.GetData(&pl_used),
+      IsOkAndHolds(ElementsAreArray(data_arr.rbegin(), data_arr.rend())));
+}
+
+}  // namespace
+
+}  // namespace lib
+}  // namespace icing
diff --git a/icing/index/section-indexing-handler.h b/icing/index/section-indexing-handler.h
new file mode 100644
index 0000000..ff461cb
--- /dev/null
+++ b/icing/index/section-indexing-handler.h
@@ -0,0 +1,60 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef ICING_INDEX_SECTION_INDEXING_HANDLER_H_
+#define ICING_INDEX_SECTION_INDEXING_HANDLER_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/proto/logging.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+// Parent class for indexing different types of sections in TokenizedDocument.
+class SectionIndexingHandler {
+ public:
+  explicit SectionIndexingHandler(const Clock* clock) : clock_(*clock) {}
+
+  virtual ~SectionIndexingHandler() = default;
+
+  // Handles the indexing process: add data (hits) into the specific type index
+  // (e.g. string index, integer index) for all contents in the corresponding
+  // type of sections in tokenized_document.
+  // For example, IntegerSectionIndexingHandler::Handle should add data into
+  // integer index for all contents in tokenized_document.integer_sections.
+  //
+  // tokenized_document: document object with different types of tokenized
+  //                     sections.
+  // document_id: id of the document.
+  // put_document_stats: object for collecting stats during indexing. It can be
+  //                     nullptr.
+  //
+  // Returns:
+  //   - OK on success
+  //   - Any other errors. It depends on each implementation.
+  virtual libtextclassifier3::Status Handle(
+      const TokenizedDocument& tokenized_document, DocumentId document_id,
+      PutDocumentStatsProto* put_document_stats) = 0;
+
+ protected:
+  const Clock& clock_;
+};
+
+}  // namespace lib
+}  // namespace icing
+
+#endif  // ICING_INDEX_SECTION_INDEXING_HANDLER_H_
diff --git a/icing/index/string-section-indexing-handler.cc b/icing/index/string-section-indexing-handler.cc
new file mode 100644
index 0000000..9b1db7e
--- /dev/null
+++ b/icing/index/string-section-indexing-handler.cc
@@ -0,0 +1,146 @@
+// Copyright (C) 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/index/string-section-indexing-handler.h" + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/index.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/proto/logging.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" +#include "icing/transform/normalizer.h" +#include "icing/util/clock.h" +#include "icing/util/tokenized-document.h" + +namespace icing { +namespace lib { + +libtextclassifier3::Status StringSectionIndexingHandler::Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + PutDocumentStatsProto* put_document_stats) { + std::unique_ptr<Timer> index_timer = clock_.GetNewTimer(); + + if (index_.last_added_document_id() != kInvalidDocumentId && + document_id <= index_.last_added_document_id()) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "DocumentId %d must be greater than last added document_id %d", + document_id, index_.last_added_document_id())); + } + // TODO(b/259744228): revisit last_added_document_id with numeric index for + // index rebuilding before rollout. 
+ index_.set_last_added_document_id(document_id); + uint32_t num_tokens = 0; + libtextclassifier3::Status status; + for (const TokenizedSection& section : + tokenized_document.tokenized_string_sections()) { + if (section.metadata.tokenizer == + StringIndexingConfig::TokenizerType::NONE) { + ICING_LOG(WARNING) + << "Unexpected TokenizerType::NONE found when indexing document."; + } + // TODO(b/152934343): pass real namespace ids in + Index::Editor editor = + index_.Edit(document_id, section.metadata.id, + section.metadata.term_match_type, /*namespace_id=*/0); + for (std::string_view token : section.token_sequence) { + ++num_tokens; + + switch (section.metadata.tokenizer) { + case StringIndexingConfig::TokenizerType::VERBATIM: + // data() is safe to use here because a token created from the + // VERBATIM tokenizer is the entire string value. The character at + // data() + token.length() is guaranteed to be a null char. + status = editor.BufferTerm(token.data()); + break; + case StringIndexingConfig::TokenizerType::NONE: + [[fallthrough]]; + case StringIndexingConfig::TokenizerType::RFC822: + [[fallthrough]]; + case StringIndexingConfig::TokenizerType::URL: + [[fallthrough]]; + case StringIndexingConfig::TokenizerType::PLAIN: + std::string normalized_term = normalizer_.NormalizeTerm(token); + status = editor.BufferTerm(normalized_term.c_str()); + } + + if (!status.ok()) { + // We've encountered a failure. Bail out. We'll mark this doc as deleted + // and signal a failure to the client. + ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: " + << status.error_message(); + break; + } + } + if (!status.ok()) { + break; + } + // Add all the seen terms to the index with their term frequency. 
+ status = editor.IndexAllBufferedTerms(); + if (!status.ok()) { + ICING_LOG(WARNING) << "Failed to add hits in lite index due to: " + << status.error_message(); + break; + } + } + + if (put_document_stats != nullptr) { + // TODO(b/259744228): switch to set individual index latency. + put_document_stats->set_index_latency_ms( + index_timer->GetElapsedMilliseconds()); + put_document_stats->mutable_tokenization_stats()->set_num_tokens_indexed( + num_tokens); + } + + // If we're either successful or we've hit resource exhausted, then attempt a + // merge. + if ((status.ok() || absl_ports::IsResourceExhausted(status)) && + index_.WantsMerge()) { + ICING_LOG(ERROR) << "Merging the index at docid " << document_id << "."; + + std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer(); + libtextclassifier3::Status merge_status = index_.Merge(); + + if (!merge_status.ok()) { + ICING_LOG(ERROR) << "Index merging failed. Clearing index."; + if (!index_.Reset().ok()) { + return absl_ports::InternalError(IcingStringUtil::StringPrintf( + "Unable to reset to clear index after merge failure. Merge " + "failure=%d:%s", + merge_status.error_code(), merge_status.error_message().c_str())); + } else { + return absl_ports::DataLossError(IcingStringUtil::StringPrintf( + "Forced to reset index after merge failure. 
Merge failure=%d:%s", + merge_status.error_code(), merge_status.error_message().c_str())); + } + } + + if (put_document_stats != nullptr) { + put_document_stats->set_index_merge_latency_ms( + merge_timer->GetElapsedMilliseconds()); + } + } + + return status; +} + +} // namespace lib +} // namespace icing diff --git a/icing/index/string-section-indexing-handler.h b/icing/index/string-section-indexing-handler.h new file mode 100644 index 0000000..4906f97 --- /dev/null +++ b/icing/index/string-section-indexing-handler.h @@ -0,0 +1,67 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_
+#define ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_
+
+#include "icing/text_classifier/lib3/utils/base/status.h"
+#include "icing/index/index.h"
+#include "icing/index/section-indexing-handler.h"
+#include "icing/proto/logging.pb.h"
+#include "icing/store/document-id.h"
+#include "icing/transform/normalizer.h"
+#include "icing/util/clock.h"
+#include "icing/util/tokenized-document.h"
+
+namespace icing {
+namespace lib {
+
+class StringSectionIndexingHandler : public SectionIndexingHandler {
+ public:
+  explicit StringSectionIndexingHandler(const Clock* clock,
+                                        const Normalizer* normalizer,
+                                        Index* index)
+      : SectionIndexingHandler(clock),
+        normalizer_(*normalizer),
+        index_(*index) {}
+
+  ~StringSectionIndexingHandler() override = default;
+
+  // Handles the string indexing process: add hits into the lite index for all
+  // contents in tokenized_document.tokenized_string_sections and merge lite
+  // index into main index if necessary.
+  //
+  // Returns:
+  //   - OK on success
+  //   - INVALID_ARGUMENT_ERROR if document_id is less than the document_id of a
+  //     previously indexed document.
+  //   - RESOURCE_EXHAUSTED_ERROR if the index is full and can't add anymore
+  //     content.
+  //   - DATA_LOSS_ERROR if an attempt to merge the index fails and both indices
+  //     are cleared as a result.
+  //   - INTERNAL_ERROR if any other errors occur.
+  //   - Any main/lite index errors.
+ libtextclassifier3::Status Handle( + const TokenizedDocument& tokenized_document, DocumentId document_id, + PutDocumentStatsProto* put_document_stats) override; + + private: + const Normalizer& normalizer_; + Index& index_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_INDEX_STRING_SECTION_INDEXING_HANDLER_H_ diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index 283c6f5..9a7df38 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -83,7 +83,7 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { } JNIEXPORT jlong JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeCreate( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate( JNIEnv* env, jclass clazz, jbyteArray icing_search_engine_options_bytes) { icing::lib::IcingSearchEngineOptions options; if (!ParseProtoFromJniByteArray(env, icing_search_engine_options_bytes, @@ -103,7 +103,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeCreate( } JNIEXPORT void JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeDestroy( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeDestroy( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -111,7 +111,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDestroy( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeInitialize( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeInitialize( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -123,7 +123,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeInitialize( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema( JNIEnv* env, jclass clazz, jobject object, 
jbyteArray schema_bytes, jboolean ignore_errors_and_delete_documents) { icing::lib::IcingSearchEngine* icing = @@ -142,7 +142,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchema( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -153,7 +153,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchemaType( JNIEnv* env, jclass clazz, jobject object, jstring schema_type) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -166,7 +166,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativePut( +Java_com_google_android_icing_IcingSearchEngineImpl_nativePut( JNIEnv* env, jclass clazz, jobject object, jbyteArray document_bytes) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -184,7 +184,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativePut( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGet( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet( JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri, jbyteArray result_spec_bytes) { icing::lib::IcingSearchEngine* icing = @@ -205,7 +205,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGet( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage( JNIEnv* env, jclass clazz, jobject object, jbyteArray usage_report_bytes) { 
icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -223,7 +223,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetAllNamespaces( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -235,7 +235,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetNextPage( JNIEnv* env, jclass clazz, jobject object, jlong next_page_token, jlong java_to_native_start_timestamp_ms) { icing::lib::IcingSearchEngine* icing = @@ -252,13 +252,14 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage( icing::lib::QueryStatsProto* query_stats = next_page_result_proto.mutable_query_stats(); query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms); - query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds()); + query_stats->set_native_to_java_start_timestamp_ms( + clock->GetSystemTimeMilliseconds()); return SerializeProtoToJniByteArray(env, next_page_result_proto); } JNIEXPORT void JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeInvalidateNextPageToken( JNIEnv* env, jclass clazz, jobject object, jlong next_page_token) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -269,7 +270,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeSearch( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch( JNIEnv* 
env, jclass clazz, jobject object, jbyteArray search_spec_bytes, jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes, jlong java_to_native_start_timestamp_ms) { @@ -306,13 +307,14 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearch( icing::lib::QueryStatsProto* query_stats = search_result_proto.mutable_query_stats(); query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms); - query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds()); + query_stats->set_native_to_java_start_timestamp_ms( + clock->GetSystemTimeMilliseconds()); return SerializeProtoToJniByteArray(env, search_result_proto); } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeDelete( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeDelete( JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri) { icing::lib::IcingSearchEngine* icing = @@ -327,7 +329,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDelete( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByNamespace( JNIEnv* env, jclass clazz, jobject object, jstring name_space) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -340,7 +342,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteBySchemaType( JNIEnv* env, jclass clazz, jobject object, jstring schema_type) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -353,7 +355,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery( 
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery( JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes, jboolean return_deleted_document_info) { icing::lib::IcingSearchEngine* icing = @@ -371,7 +373,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk( +Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk( JNIEnv* env, jclass clazz, jobject object, jint persist_type_code) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -390,7 +392,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeOptimize( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeOptimize( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -401,7 +403,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeOptimize( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetOptimizeInfo( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -413,7 +415,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetStorageInfo( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -425,7 +427,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeReset( 
+Java_com_google_android_icing_IcingSearchEngineImpl_nativeReset( JNIEnv* env, jclass clazz, jobject object) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -436,7 +438,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReset( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions( JNIEnv* env, jclass clazz, jobject object, jbyteArray suggestion_spec_bytes) { icing::lib::IcingSearchEngine* icing = @@ -455,7 +457,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( } JNIEXPORT jbyteArray JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo( JNIEnv* env, jclass clazz, jobject object, jint verbosity) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -473,7 +475,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo( } JNIEXPORT jboolean JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog( JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { if (!icing::lib::LogSeverity::Code_IsValid(severity)) { ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; @@ -484,7 +486,7 @@ Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog( } JNIEXPORT jboolean JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel( JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { if (!icing::lib::LogSeverity::Code_IsValid(severity)) { ICING_LOG(ERROR) << "Invalid value for logging severity: " << severity; @@ -495,8 +497,215 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel( } JNIEXPORT jstring 
JNICALL -Java_com_google_android_icing_IcingSearchEngine_nativeGetLoggingTag( +Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetLoggingTag( JNIEnv* env, jclass clazz) { return env->NewStringUTF(icing::lib::kIcingLoggingTag); } + +// TODO(b/240333360) Remove the methods below for IcingSearchEngine once we have +// a sync from Jetpack to g3 to contain the refactored IcingSearchEngine(with +// IcingSearchEngineImpl). +JNIEXPORT jlong JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeCreate( + JNIEnv* env, jclass clazz, jbyteArray icing_search_engine_options_bytes) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeCreate( + env, clazz, icing_search_engine_options_bytes); +} + +JNIEXPORT void JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeDestroy(JNIEnv* env, + jclass clazz, + jobject object) { + Java_com_google_android_icing_IcingSearchEngineImpl_nativeDestroy(env, clazz, + object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeInitialize( + JNIEnv* env, jclass clazz, jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeInitialize( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSetSchema( + JNIEnv* env, jclass clazz, jobject object, jbyteArray schema_bytes, + jboolean ignore_errors_and_delete_documents) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetSchema( + env, clazz, object, schema_bytes, ignore_errors_and_delete_documents); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetSchema( + JNIEnv* env, jclass clazz, jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchema( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetSchemaType( + JNIEnv* env, jclass clazz, jobject object, jstring schema_type) { + 
return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetSchemaType( + env, clazz, object, schema_type); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativePut( + JNIEnv* env, jclass clazz, jobject object, jbyteArray document_bytes) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativePut( + env, clazz, object, document_bytes); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGet( + JNIEnv* env, jclass clazz, jobject object, jstring name_space, jstring uri, + jbyteArray result_spec_bytes) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGet( + env, clazz, object, name_space, uri, result_spec_bytes); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeReportUsage( + JNIEnv* env, jclass clazz, jobject object, jbyteArray usage_report_bytes) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeReportUsage( + env, clazz, object, usage_report_bytes); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces( + JNIEnv* env, jclass clazz, jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetAllNamespaces( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage( + JNIEnv* env, jclass clazz, jobject object, jlong next_page_token, + jlong java_to_native_start_timestamp_ms) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetNextPage( + env, clazz, object, next_page_token, java_to_native_start_timestamp_ms); +} + +JNIEXPORT void JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken( + JNIEnv* env, jclass clazz, jobject object, jlong next_page_token) { + Java_com_google_android_icing_IcingSearchEngineImpl_nativeInvalidateNextPageToken( + env, clazz, object, next_page_token); +} + +JNIEXPORT 
jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSearch( + JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes, + jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes, + jlong java_to_native_start_timestamp_ms) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearch( + env, clazz, object, search_spec_bytes, scoring_spec_bytes, + result_spec_bytes, java_to_native_start_timestamp_ms); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeDelete(JNIEnv* env, + jclass clazz, + jobject object, + jstring name_space, + jstring uri) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDelete( + env, clazz, object, name_space, uri); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByNamespace( + JNIEnv* env, jclass clazz, jobject object, jstring name_space) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByNamespace( + env, clazz, object, name_space); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeDeleteBySchemaType( + JNIEnv* env, jclass clazz, jobject object, jstring schema_type) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteBySchemaType( + env, clazz, object, schema_type); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeDeleteByQuery( + JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes, + jboolean return_deleted_document_info) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeDeleteByQuery( + env, clazz, object, search_spec_bytes, return_deleted_document_info); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativePersistToDisk( + JNIEnv* env, jclass clazz, jobject object, jint persist_type_code) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativePersistToDisk( + env, 
clazz, object, persist_type_code); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeOptimize(JNIEnv* env, + jclass clazz, + jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeOptimize( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetOptimizeInfo( + JNIEnv* env, jclass clazz, jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetOptimizeInfo( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetStorageInfo( + JNIEnv* env, jclass clazz, jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetStorageInfo( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeReset(JNIEnv* env, + jclass clazz, + jobject object) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeReset( + env, clazz, object); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( + JNIEnv* env, jclass clazz, jobject object, + jbyteArray suggestion_spec_bytes) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSearchSuggestions( + env, clazz, object, suggestion_spec_bytes); +} + +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetDebugInfo( + JNIEnv* env, jclass clazz, jobject object, jint verbosity) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetDebugInfo( + env, clazz, object, verbosity); +} + +JNIEXPORT jboolean JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeShouldLog( + JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeShouldLog( + env, clazz, severity, verbosity); +} + +JNIEXPORT jboolean JNICALL 
+Java_com_google_android_icing_IcingSearchEngine_nativeSetLoggingLevel( + JNIEnv* env, jclass clazz, jshort severity, jshort verbosity) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeSetLoggingLevel( + env, clazz, severity, verbosity); +} + +JNIEXPORT jstring JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeGetLoggingTag( + JNIEnv* env, jclass clazz) { + return Java_com_google_android_icing_IcingSearchEngineImpl_nativeGetLoggingTag( + env, clazz); +} + } // extern "C" diff --git a/icing/join/aggregate-scorer.cc b/icing/join/aggregate-scorer.cc new file mode 100644 index 0000000..7b17482 --- /dev/null +++ b/icing/join/aggregate-scorer.cc @@ -0,0 +1,117 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/join/aggregate-scorer.h" + +#include <algorithm> +#include <memory> +#include <numeric> +#include <vector> + +#include "icing/proto/search.pb.h" +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +class MinAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + return std::min_element(children.begin(), children.end(), + [](const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) -> bool { + return lhs.score() < rhs.score(); + }) + ->score(); + } +}; + +class MaxAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + return std::max_element(children.begin(), children.end(), + [](const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) -> bool { + return lhs.score() < rhs.score(); + }) + ->score(); + } +}; + +class AverageAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + if (children.empty()) return 0.0; + return std::reduce( + children.begin(), children.end(), 0.0, + [](const double& prev, const ScoredDocumentHit& item) -> double { + return prev + item.score(); + }) / + children.size(); + } +}; + +class CountAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + return children.size(); + } +}; + +class SumAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + return std::reduce( + children.begin(), children.end(), 0.0, + [](const double& prev, const ScoredDocumentHit& item) -> double { + return prev + item.score(); + }); + } +}; + +class 
DefaultAggregateScorer : public AggregateScorer { + public: + double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) override { + return parent.score(); + } +}; + +std::unique_ptr<AggregateScorer> AggregateScorer::Create( + const JoinSpecProto& join_spec) { + switch (join_spec.aggregation_score_strategy()) { + case JoinSpecProto_AggregationScore_MIN: + return std::make_unique<MinAggregateScorer>(); + case JoinSpecProto_AggregationScore_MAX: + return std::make_unique<MaxAggregateScorer>(); + case JoinSpecProto_AggregationScore_COUNT: + return std::make_unique<CountAggregateScorer>(); + case JoinSpecProto_AggregationScore_AVG: + return std::make_unique<AverageAggregateScorer>(); + case JoinSpecProto_AggregationScore_SUM: + return std::make_unique<SumAggregateScorer>(); + case JoinSpecProto_AggregationScore_UNDEFINED: + [[fallthrough]]; + default: + return std::make_unique<DefaultAggregateScorer>(); + } +} + +} // namespace lib +} // namespace icing diff --git a/icing/join/aggregate-scorer.h b/icing/join/aggregate-scorer.h new file mode 100644 index 0000000..27731b9 --- /dev/null +++ b/icing/join/aggregate-scorer.h @@ -0,0 +1,41 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_JOIN_AGGREGATE_SCORER_H_ +#define ICING_JOIN_AGGREGATE_SCORER_H_ + +#include <memory> +#include <vector> + +#include "icing/proto/search.pb.h" +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +class AggregateScorer { + public: + static std::unique_ptr<AggregateScorer> Create( + const JoinSpecProto& join_spec); + + virtual ~AggregateScorer() = default; + + virtual double GetScore(const ScoredDocumentHit& parent, + const std::vector<ScoredDocumentHit>& children) = 0; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_JOIN_AGGREGATE_SCORER_H_ diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc new file mode 100644 index 0000000..7abd821 --- /dev/null +++ b/icing/join/join-processor.cc @@ -0,0 +1,180 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/join/join-processor.h" + +#include <algorithm> +#include <functional> +#include <string_view> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/search.pb.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-id.h" +#include "icing/util/snippet-helpers.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> +JoinProcessor::Join( + const JoinSpecProto& join_spec, + std::vector<ScoredDocumentHit>&& parent_scored_document_hits, + std::vector<ScoredDocumentHit>&& child_scored_document_hits) { + std::sort( + child_scored_document_hits.begin(), child_scored_document_hits.end(), + ScoredDocumentHitComparator( + /*is_descending=*/join_spec.nested_spec().scoring_spec().order_by() == + ScoringSpecProto::Order::DESC)); + + // TODO(b/256022027): + // - Aggregate scoring + // - Calculate the aggregated score if strategy is AGGREGATION_SCORING. + // - Optimization + // - Cache property to speed up property retrieval. + // - If there is no cache, then we still have the flexibility to fetch it + // from actual docs via DocumentStore. + + // Break children down into maps. The keys of this map are the DocumentIds of + // the parent docs the child ScoredDocumentHits refer to. The values in this + // map are vectors of child ScoredDocumentHits that refer to a parent + // DocumentId. + std::unordered_map<DocumentId, std::vector<ScoredDocumentHit>> + parent_to_child_map; + for (const ScoredDocumentHit& child : child_scored_document_hits) { + std::string property_content = FetchPropertyExpressionValue( + child.document_id(), join_spec.child_property_expression()); + + // Try to split the property content by separators. 
+  std::vector<int> separators_in_property_content =
+      GetSeparatorLocations(property_content, "#");
+
+  if (separators_in_property_content.size() != 1) {
+    // Skip the document unless the property content contains exactly one
+    // unescaped '#' separating the namespace from the uri; otherwise the
+    // qualified id is malformed.
+    continue;
+  }
+
+  std::string ns =
+      property_content.substr(0, separators_in_property_content[0]);
+  std::string uri =
+      property_content.substr(separators_in_property_content[0] + 1);
+
+  UnescapeSeparator(ns, "#");
+  UnescapeSeparator(uri, "#");
+
+  libtextclassifier3::StatusOr<DocumentId> doc_id_or =
+      doc_store_->GetDocumentId(ns, uri);
+
+  if (!doc_id_or.ok()) {
+    // Skip the document if getting errors.
+    continue;
+  }
+
+  DocumentId parent_doc_id = std::move(doc_id_or).ValueOrDie();
+
+  // This assumes the child docs are already sorted.
+  if (parent_to_child_map[parent_doc_id].size() <
+      join_spec.max_joined_child_count()) {
+    parent_to_child_map[parent_doc_id].push_back(std::move(child));
+  }
+}
+
+std::vector<JoinedScoredDocumentHit> joined_scored_document_hits;
+joined_scored_document_hits.reserve(parent_scored_document_hits.size());
+
+// Then use the child map to attach children to the parent ScoredDocumentHits.
+for (ScoredDocumentHit& parent : parent_scored_document_hits) {
+  DocumentId parent_doc_id = kInvalidDocumentId;
+  if (join_spec.parent_property_expression() == kFullyQualifiedIdExpr) {
+    parent_doc_id = parent.document_id();
+  } else {
+    // TODO(b/256022027): So far we only support kFullyQualifiedIdExpr for
+    // parent_property_expression, we could support more.
+ return absl_ports::UnimplementedError( + join_spec.parent_property_expression() + + " must be \"fullyQualifiedId(this)\""); + } + + // TODO(b/256022027): Derive final score from + // parent_to_child_map[parent_doc_id] and + // join_spec.aggregation_score_strategy() + double final_score = parent.score(); + joined_scored_document_hits.emplace_back( + final_score, std::move(parent), + std::vector<ScoredDocumentHit>( + std::move(parent_to_child_map[parent_doc_id]))); + } + + return joined_scored_document_hits; +} + +// This loads a document and uses a property expression to fetch the value of +// the property from the document. The property expression may refer to nested +// document properties. We do not allow for repeated values in this property +// path, as that would allow for a single document to join to multiple +// documents. +// +// Returns: +// "" on document load error. +// "" if the property path is not found in the document. +// "" if part of the property path is a repeated value. +std::string JoinProcessor::FetchPropertyExpressionValue( + const DocumentId& document_id, + const std::string& property_expression) const { + // TODO(b/256022027): Add caching of document_id -> {expression -> value} + libtextclassifier3::StatusOr<DocumentProto> document_or = + doc_store_->Get(document_id); + if (!document_or.ok()) { + // Skip the document if getting errors. 
+ return ""; + } + + DocumentProto document = std::move(document_or).ValueOrDie(); + + return std::string(GetString(&document, property_expression)); +} + +std::vector<int> JoinProcessor::GetSeparatorLocations( + const std::string& content, const std::string& separator) const { + std::vector<int> separators_in_property_content; + + for (int i = 0; i < content.length(); ++i) { + if (content[i] == '\\') { + // Skip the following character + i++; + } else if (content[i] == '#') { + // Unescaped separator + separators_in_property_content.push_back(i); + } + } + return separators_in_property_content; +} + +void JoinProcessor::UnescapeSeparator(std::string& property, + const std::string& separator) { + size_t start_pos = 0; + while ((start_pos = property.find("\\" + separator, start_pos)) != + std::string::npos) { + property.replace(start_pos, 2, "#"); + start_pos += 1; + } +} + +} // namespace lib +} // namespace icing diff --git a/icing/join/join-processor.h b/icing/join/join-processor.h new file mode 100644 index 0000000..c919b22 --- /dev/null +++ b/icing/join/join-processor.h @@ -0,0 +1,57 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_JOIN_JOIN_PROCESSOR_H_ +#define ICING_JOIN_JOIN_PROCESSOR_H_ + +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/proto/search.pb.h" +#include "icing/scoring/scored-document-hit.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +class JoinProcessor { + public: + static constexpr std::string_view kFullyQualifiedIdExpr = + "this.fullyQualifiedId()"; + + explicit JoinProcessor(const DocumentStore* doc_store) + : doc_store_(doc_store) {} + + libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> Join( + const JoinSpecProto& join_spec, + std::vector<ScoredDocumentHit>&& parent_scored_document_hits, + std::vector<ScoredDocumentHit>&& child_scored_document_hits); + + private: + std::string FetchPropertyExpressionValue( + const DocumentId& document_id, + const std::string& property_expression) const; + + void UnescapeSeparator(std::string& property, const std::string& separator); + + std::vector<int> GetSeparatorLocations(const std::string& content, + const std::string& separator) const; + + const DocumentStore* doc_store_; // Does not own. 
+}; + +} // namespace lib +} // namespace icing + +#endif // ICING_JOIN_JOIN_PROCESSOR_H_ diff --git a/icing/monkey_test/icing-monkey-test-runner.cc b/icing/monkey_test/icing-monkey-test-runner.cc index 2dd5a03..a2a6c9b 100644 --- a/icing/monkey_test/icing-monkey-test-runner.cc +++ b/icing/monkey_test/icing-monkey-test-runner.cc @@ -40,42 +40,12 @@ using ::testing::Le; using ::testing::SizeIs; using ::testing::UnorderedElementsAreArray; -inline constexpr int kNumTypes = 30; -const std::vector<int> kPossibleNumProperties = {0, - 1, - 2, - 4, - 8, - 16, - kTotalNumSections / 2, - kTotalNumSections, - kTotalNumSections + 1, - kTotalNumSections * 2}; -inline constexpr int kNumNamespaces = 100; -inline constexpr int kNumURIs = 1000; - -// Merge per 131072 hits -const int kIndexMergeSize = 1024 * 1024; - -// An array of pairs of monkey test APIs with frequencies. -// If f_sum is the sum of all the frequencies, an operation with frequency f -// means for every f_sum iterations, the operation is expected to run f times. 
-const std::vector< - std::pair<std::function<void(IcingMonkeyTestRunner*)>, uint32_t>> - kMonkeyAPISchedules = {{&IcingMonkeyTestRunner::DoPut, 500}, - {&IcingMonkeyTestRunner::DoSearch, 200}, - {&IcingMonkeyTestRunner::DoGet, 70}, - {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50}, - {&IcingMonkeyTestRunner::DoDelete, 50}, - {&IcingMonkeyTestRunner::DoDeleteByNamespace, 50}, - {&IcingMonkeyTestRunner::DoDeleteBySchemaType, 50}, - {&IcingMonkeyTestRunner::DoDeleteByQuery, 20}, - {&IcingMonkeyTestRunner::DoOptimize, 5}, - {&IcingMonkeyTestRunner::ReloadFromDisk, 5}}; - -SchemaProto GenerateRandomSchema(MonkeyTestRandomEngine* random) { +SchemaProto GenerateRandomSchema( + const IcingMonkeyTestRunnerConfiguration& config, + MonkeyTestRandomEngine* random) { MonkeySchemaGenerator schema_generator(random); - return schema_generator.GenerateSchema(kNumTypes, kPossibleNumProperties); + return schema_generator.GenerateSchema(config.num_types, + config.possible_num_properties); } SearchSpecProto GenerateRandomSearchSpecProto( @@ -166,18 +136,20 @@ void SortDocuments(std::vector<DocumentProto>& documents) { } // namespace -IcingMonkeyTestRunner::IcingMonkeyTestRunner(uint32_t seed) - : random_(seed), in_memory_icing_() { - ICING_LOG(INFO) << "Monkey test runner started with seed: " << seed; +IcingMonkeyTestRunner::IcingMonkeyTestRunner( + const IcingMonkeyTestRunnerConfiguration& config) + : config_(config), random_(config.seed), in_memory_icing_() { + ICING_LOG(INFO) << "Monkey test runner started with seed: " << config_.seed; - SchemaProto schema = GenerateRandomSchema(&random_); + SchemaProto schema = GenerateRandomSchema(config_, &random_); ICING_LOG(DBG) << "Schema Generated: " << schema.DebugString(); in_memory_icing_ = std::make_unique<InMemoryIcingSearchEngine>(&random_, std::move(schema)); document_generator_ = std::make_unique<MonkeyDocumentGenerator>( - &random_, in_memory_icing_->GetSchema(), kNumNamespaces, kNumURIs); + &random_, in_memory_icing_->GetSchema(), 
config_.possible_num_tokens_, + config_.num_namespaces, config_.num_uris); std::string dir = GetTestTempDir() + "/icing/monkey"; filesystem_.DeleteDirectoryRecursively(dir.c_str()); @@ -190,13 +162,13 @@ void IcingMonkeyTestRunner::Run(uint32_t num) { "CreateIcingSearchEngineWithSchema() first"; uint32_t frequency_sum = 0; - for (const auto& schedule : kMonkeyAPISchedules) { + for (const auto& schedule : config_.monkey_api_schedules) { frequency_sum += schedule.second; } std::uniform_int_distribution<> dist(0, frequency_sum - 1); for (; num; --num) { int p = dist(random_); - for (const auto& schedule : kMonkeyAPISchedules) { + for (const auto& schedule : config_.monkey_api_schedules) { if (p < schedule.second) { ASSERT_NO_FATAL_FAILURE(schedule.first(this)); break; @@ -404,6 +376,11 @@ void IcingMonkeyTestRunner::DoSearch() { search_result = icing_->GetNextPage(search_result.next_page_token()); ASSERT_THAT(search_result.status(), ProtoIsOk()); } + // The maximum number of scored documents allowed in Icing is 30000, in which + // case we are not able to compare the results with the in-memory Icing. 
+ if (exp_documents.size() >= 30000) { + return; + } if (snippet_spec.num_matches_per_property() > 0) { ASSERT_THAT(num_snippeted, Eq(std::min<uint32_t>(exp_documents.size(), @@ -432,7 +409,7 @@ void IcingMonkeyTestRunner::DoOptimize() { void IcingMonkeyTestRunner::CreateIcingSearchEngine() { IcingSearchEngineOptions icing_options; - icing_options.set_index_merge_size(kIndexMergeSize); + icing_options.set_index_merge_size(config_.index_merge_size); icing_options.set_base_dir(icing_dir_->dir()); icing_ = std::make_unique<IcingSearchEngine>(icing_options); ASSERT_THAT(icing_->Initialize().status(), ProtoIsOk()); diff --git a/icing/monkey_test/icing-monkey-test-runner.h b/icing/monkey_test/icing-monkey-test-runner.h index 5f5649c..fbaaaaa 100644 --- a/icing/monkey_test/icing-monkey-test-runner.h +++ b/icing/monkey_test/icing-monkey-test-runner.h @@ -26,9 +26,42 @@ namespace icing { namespace lib { +class IcingMonkeyTestRunner; + +struct IcingMonkeyTestRunnerConfiguration { + explicit IcingMonkeyTestRunnerConfiguration(uint32_t seed, int num_types, + int num_namespaces, int num_uris, + int index_merge_size) + : seed(seed), + num_types(num_types), + num_namespaces(num_namespaces), + num_uris(num_uris), + index_merge_size(index_merge_size) {} + + uint32_t seed; + int num_types; + int num_namespaces; + int num_uris; + int index_merge_size; + + // The possible number of properties that may appear in generated schema + // types. + std::vector<int> possible_num_properties; + + // The possible number of tokens that may appear in generated documents, with + // a noise factor from 0.5 to 1 applied. + std::vector<int> possible_num_tokens_; + + // An array of pairs of monkey test APIs with frequencies. + // If f_sum is the sum of all the frequencies, an operation with frequency f + // means for every f_sum iterations, the operation is expected to run f times. 
+ std::vector<std::pair<std::function<void(IcingMonkeyTestRunner*)>, uint32_t>> + monkey_api_schedules; +}; + class IcingMonkeyTestRunner { public: - IcingMonkeyTestRunner(uint32_t seed = std::random_device()()); + IcingMonkeyTestRunner(const IcingMonkeyTestRunnerConfiguration& config); IcingMonkeyTestRunner(const IcingMonkeyTestRunner&) = delete; IcingMonkeyTestRunner& operator=(const IcingMonkeyTestRunner&) = delete; @@ -54,6 +87,7 @@ class IcingMonkeyTestRunner { void DoOptimize(); private: + IcingMonkeyTestRunnerConfiguration config_; MonkeyTestRandomEngine random_; Filesystem filesystem_; std::unique_ptr<DestructibleDirectory> icing_dir_; diff --git a/icing/monkey_test/icing-search-engine_monkey_test.cc b/icing/monkey_test/icing-search-engine_monkey_test.cc index ad887b8..a24e57f 100644 --- a/icing/monkey_test/icing-search-engine_monkey_test.cc +++ b/icing/monkey_test/icing-search-engine_monkey_test.cc @@ -20,11 +20,71 @@ namespace icing { namespace lib { TEST(IcingSearchEngineMonkeyTest, MonkeyTest) { + IcingMonkeyTestRunnerConfiguration config( + /*seed=*/std::random_device()(), + /*num_types=*/30, + /*num_namespaces=*/100, + /*num_uris=*/1000, + /*index_merge_size=*/1024 * 1024); + config.possible_num_properties = {0, + 1, + 2, + 4, + 8, + 16, + kTotalNumSections / 2, + kTotalNumSections, + kTotalNumSections + 1, + kTotalNumSections * 2}; + config.possible_num_tokens_ = {0, 1, 4, 16, 64, 256}; + config.monkey_api_schedules = { + {&IcingMonkeyTestRunner::DoPut, 500}, + {&IcingMonkeyTestRunner::DoSearch, 200}, + {&IcingMonkeyTestRunner::DoGet, 70}, + {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50}, + {&IcingMonkeyTestRunner::DoDelete, 50}, + {&IcingMonkeyTestRunner::DoDeleteByNamespace, 50}, + {&IcingMonkeyTestRunner::DoDeleteBySchemaType, 50}, + {&IcingMonkeyTestRunner::DoDeleteByQuery, 20}, + {&IcingMonkeyTestRunner::DoOptimize, 5}, + {&IcingMonkeyTestRunner::ReloadFromDisk, 5}}; uint32_t num_iterations = IsAndroidArm() ? 
1000 : 5000; - IcingMonkeyTestRunner runner; + IcingMonkeyTestRunner runner(config); ASSERT_NO_FATAL_FAILURE(runner.CreateIcingSearchEngineWithSchema()); ASSERT_NO_FATAL_FAILURE(runner.Run(num_iterations)); } +TEST(DISABLED_IcingSearchEngineMonkeyTest, MonkeyManyDocTest) { + IcingMonkeyTestRunnerConfiguration config( + /*seed=*/std::random_device()(), + /*num_types=*/30, + /*num_namespaces=*/200, + /*num_uris=*/100000, + /*index_merge_size=*/1024 * 1024); + + // Due to the large amount of documents, we need to make each document smaller + // to finish the test. + config.possible_num_properties = {0, 1, 2}; + config.possible_num_tokens_ = {0, 1, 4}; + + // No deletion is performed to preserve a large number of documents. + config.monkey_api_schedules = { + {&IcingMonkeyTestRunner::DoPut, 500}, + {&IcingMonkeyTestRunner::DoSearch, 200}, + {&IcingMonkeyTestRunner::DoGet, 70}, + {&IcingMonkeyTestRunner::DoGetAllNamespaces, 50}, + {&IcingMonkeyTestRunner::DoOptimize, 5}, + {&IcingMonkeyTestRunner::ReloadFromDisk, 5}}; + IcingMonkeyTestRunner runner(config); + ASSERT_NO_FATAL_FAILURE(runner.CreateIcingSearchEngineWithSchema()); + // Pre-fill with 4 million documents + SetLoggingLevel(LogSeverity::WARNING); + for (int i = 0; i < 4000000; i++) { + ASSERT_NO_FATAL_FAILURE(runner.DoPut()); + } + SetLoggingLevel(LogSeverity::INFO); + ASSERT_NO_FATAL_FAILURE(runner.Run(1000)); +} + } // namespace lib } // namespace icing diff --git a/icing/monkey_test/monkey-test-generators.cc b/icing/monkey_test/monkey-test-generators.cc index 88fc0b6..7b2ff56 100644 --- a/icing/monkey_test/monkey-test-generators.cc +++ b/icing/monkey_test/monkey-test-generators.cc @@ -106,19 +106,8 @@ std::string MonkeyDocumentGenerator::GetUri() const { } int MonkeyDocumentGenerator::GetNumTokens() const { - std::uniform_int_distribution<> int_dist(-1, 4); - int n = int_dist(*random_); - if (n == -1) { - // 1/6 chance of getting zero token for a property - return 0; - } - if (n == 0) { - // 1/6 chance of 
getting one token for a property - return 1; - } - // 1/6 chance of getting one of 4, 16, 64, 256 - n = 1 << (2 * n); - + std::uniform_int_distribution<> dist(0, possible_num_tokens_.size() - 1); + int n = possible_num_tokens_[dist(*random_)]; // Add some noise std::uniform_real_distribution<> real_dist(0.5, 1); float p = real_dist(*random_); diff --git a/icing/monkey_test/monkey-test-generators.h b/icing/monkey_test/monkey-test-generators.h index 68c5e92..6349918 100644 --- a/icing/monkey_test/monkey-test-generators.h +++ b/icing/monkey_test/monkey-test-generators.h @@ -70,10 +70,12 @@ class MonkeyDocumentGenerator { public: explicit MonkeyDocumentGenerator(MonkeyTestRandomEngine* random, const SchemaProto* schema, + std::vector<int> possible_num_tokens, uint32_t num_namespaces, uint32_t num_uris = 0) : random_(random), schema_(schema), + possible_num_tokens_(std::move(possible_num_tokens)), num_namespaces_(num_namespaces), num_uris_(num_uris) {} @@ -104,6 +106,11 @@ class MonkeyDocumentGenerator { private: MonkeyTestRandomEngine* random_; // Does not own. const SchemaProto* schema_; // Does not own. + + // The possible number of tokens that may appear in generated documents, with + // a noise factor from 0.5 to 1 applied. 
+ std::vector<int> possible_num_tokens_; + uint32_t num_namespaces_; uint32_t num_uris_; uint32_t num_docs_generated_ = 0; diff --git a/icing/portable/platform.h b/icing/portable/platform.h index 150eede..b68c026 100644 --- a/icing/portable/platform.h +++ b/icing/portable/platform.h @@ -15,6 +15,8 @@ #ifndef ICING_PORTABLE_PLATFORM_H_ #define ICING_PORTABLE_PLATFORM_H_ +#include "unicode/uvernum.h" + namespace icing { namespace lib { @@ -34,6 +36,14 @@ inline bool IsReverseJniTokenization() { return false; } +inline bool IsIcuTokenization() { + return !IsReverseJniTokenization() && !IsCfStringTokenization(); +} + +inline bool IsIcu72PlusTokenization() { + return IsIcuTokenization() && U_ICU_VERSION_MAJOR_NUM >= 72; +} + // Whether we're running on android_x86 inline bool IsAndroidX86() { #if defined(__ANDROID__) && defined(__i386__) @@ -58,6 +68,15 @@ inline bool IsIosPlatform() { return false; } +// TODO(b/259129263): verify the flag works for different platforms. +#if defined(__arm__) || defined(__i386__) +#define ICING_ARCH_BIT_32 +#elif defined(__aarch64__) || defined(__x86_64__) +#define ICING_ARCH_BIT_64 +#else +#define ICING_ARCH_BIT_UNKNOWN +#endif + enum Architecture { UNKNOWN, BIT_32, @@ -69,9 +88,9 @@ enum Architecture { // Architecture macros pulled from // https://developer.android.com/ndk/guides/cpu-features inline Architecture GetArchitecture() { -#if defined(__arm__) || defined(__i386__) +#if defined(ICING_ARCH_BIT_32) return BIT_32; -#elif defined(__aarch64__) || defined(__x86_64__) +#elif defined(ICING_ARCH_BIT_64) return BIT_64; #else return UNKNOWN; diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h b/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h new file mode 100644 index 0000000..42be07d --- /dev/null +++ b/icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h @@ -0,0 +1,108 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_ +#define ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_ + +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" + +namespace icing { +namespace lib { + +// A visitor that simply collects the nodes and flattens them in left-side +// depth-first order. +enum class NodeType { + kFunctionName, + kString, + kText, + kMember, + kFunction, + kUnaryOperator, + kNaryOperator +}; + +struct NodeInfo { + std::string value; + NodeType type; + + bool operator==(const NodeInfo& rhs) const { + return value == rhs.value && type == rhs.type; + } +}; + +MATCHER_P2(EqualsNodeInfo, value, type, "") { + if (arg.value != value || arg.type != type) { + *result_listener << "(Expected: value=\"" << value + << "\", type=" << static_cast<int>(type) + << ". 
Actual: value=\"" << arg.value + << "\", type=" << static_cast<int>(arg.type) << ")"; + return false; + } + return true; +} + +class SimpleVisitor : public AbstractSyntaxTreeVisitor { + public: + void VisitFunctionName(const FunctionNameNode* node) override { + nodes_.push_back({node->value(), NodeType::kFunctionName}); + } + void VisitString(const StringNode* node) override { + nodes_.push_back({node->value(), NodeType::kString}); + } + void VisitText(const TextNode* node) override { + nodes_.push_back({node->value(), NodeType::kText}); + } + void VisitMember(const MemberNode* node) override { + for (const std::unique_ptr<TextNode>& child : node->children()) { + child->Accept(this); + } + if (node->function() != nullptr) { + node->function()->Accept(this); + } + nodes_.push_back({"", NodeType::kMember}); + } + void VisitFunction(const FunctionNode* node) override { + node->function_name()->Accept(this); + for (const std::unique_ptr<Node>& arg : node->args()) { + arg->Accept(this); + } + nodes_.push_back({"", NodeType::kFunction}); + } + void VisitUnaryOperator(const UnaryOperatorNode* node) override { + node->child()->Accept(this); + nodes_.push_back({node->operator_text(), NodeType::kUnaryOperator}); + } + void VisitNaryOperator(const NaryOperatorNode* node) override { + for (const std::unique_ptr<Node>& child : node->children()) { + child->Accept(this); + } + nodes_.push_back({node->operator_text(), NodeType::kNaryOperator}); + } + + const std::vector<NodeInfo>& nodes() const { return nodes_; } + + private: + std::vector<NodeInfo> nodes_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_TEST_UTILS_H_ diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree.h b/icing/query/advanced_query_parser/abstract-syntax-tree.h new file mode 100644 index 0000000..dc28ab6 --- /dev/null +++ b/icing/query/advanced_query_parser/abstract-syntax-tree.h @@ -0,0 +1,168 @@ +// Copyright (C) 2022 Google LLC 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_ +#define ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +namespace icing { +namespace lib { + +class FunctionNameNode; +class StringNode; +class TextNode; +class MemberNode; +class FunctionNode; +class UnaryOperatorNode; +class NaryOperatorNode; + +class AbstractSyntaxTreeVisitor { + public: + virtual ~AbstractSyntaxTreeVisitor() = default; + + virtual void VisitFunctionName(const FunctionNameNode* node) = 0; + virtual void VisitString(const StringNode* node) = 0; + virtual void VisitText(const TextNode* node) = 0; + virtual void VisitMember(const MemberNode* node) = 0; + virtual void VisitFunction(const FunctionNode* node) = 0; + virtual void VisitUnaryOperator(const UnaryOperatorNode* node) = 0; + virtual void VisitNaryOperator(const NaryOperatorNode* node) = 0; +}; + +class Node { + public: + virtual ~Node() = default; + virtual void Accept(AbstractSyntaxTreeVisitor* visitor) const = 0; +}; + +class TerminalNode : public Node { + public: + explicit TerminalNode(std::string value) : value_(std::move(value)) {} + + const std::string& value() const { return value_; } + + private: + std::string value_; +}; + +class FunctionNameNode : public TerminalNode { + public: + explicit FunctionNameNode(std::string value) + : TerminalNode(std::move(value)) {} + 
void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitFunctionName(this); + } +}; + +class StringNode : public TerminalNode { + public: + explicit StringNode(std::string value) : TerminalNode(std::move(value)) {} + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitString(this); + } +}; + +class TextNode : public TerminalNode { + public: + explicit TextNode(std::string value) : TerminalNode(std::move(value)) {} + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitText(this); + } +}; + +class MemberNode : public Node { + public: + explicit MemberNode(std::vector<std::unique_ptr<TextNode>> children, + std::unique_ptr<FunctionNode> function) + : children_(std::move(children)), function_(std::move(function)) {} + + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitMember(this); + } + const std::vector<std::unique_ptr<TextNode>>& children() const { + return children_; + } + const FunctionNode* function() const { return function_.get(); } + + private: + std::vector<std::unique_ptr<TextNode>> children_; + // This is nullable. When it is not nullptr, this class will represent a + // function call. 
+ std::unique_ptr<FunctionNode> function_; +}; + +class FunctionNode : public Node { + public: + explicit FunctionNode(std::unique_ptr<FunctionNameNode> function_name) + : FunctionNode(std::move(function_name), {}) {} + explicit FunctionNode(std::unique_ptr<FunctionNameNode> function_name, + std::vector<std::unique_ptr<Node>> args) + : function_name_(std::move(function_name)), args_(std::move(args)) {} + + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitFunction(this); + } + const FunctionNameNode* function_name() const { return function_name_.get(); } + const std::vector<std::unique_ptr<Node>>& args() const { return args_; } + + private: + std::unique_ptr<FunctionNameNode> function_name_; + std::vector<std::unique_ptr<Node>> args_; +}; + +class UnaryOperatorNode : public Node { + public: + explicit UnaryOperatorNode(std::string operator_text, + std::unique_ptr<Node> child) + : operator_text_(std::move(operator_text)), child_(std::move(child)) {} + + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitUnaryOperator(this); + } + const std::string& operator_text() const { return operator_text_; } + const Node* child() const { return child_.get(); } + + private: + std::string operator_text_; + std::unique_ptr<Node> child_; +}; + +class NaryOperatorNode : public Node { + public: + explicit NaryOperatorNode(std::string operator_text, + std::vector<std::unique_ptr<Node>> children) + : operator_text_(std::move(operator_text)), + children_(std::move(children)) {} + + void Accept(AbstractSyntaxTreeVisitor* visitor) const override { + visitor->VisitNaryOperator(this); + } + const std::string& operator_text() const { return operator_text_; } + const std::vector<std::unique_ptr<Node>>& children() const { + return children_; + } + + private: + std::string operator_text_; + std::vector<std::unique_ptr<Node>> children_; +}; + +} // namespace lib +} // namespace icing + +#endif // 
ICING_QUERY_ADVANCED_QUERY_PARSER_ABSTRACT_SYNTAX_TREE_H_ diff --git a/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc b/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc new file mode 100644 index 0000000..a8599fd --- /dev/null +++ b/icing/query/advanced_query_parser/abstract-syntax-tree_test.cc @@ -0,0 +1,141 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" + +#include <memory> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h" + +namespace icing { +namespace lib { +namespace { + +using ::testing::ElementsAre; + +TEST(AbstractSyntaxTreeTest, Simple) { + // foo + std::unique_ptr<Node> root = std::make_unique<TextNode>("foo"); + SimpleVisitor visitor; + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText))); +} + +TEST(AbstractSyntaxTreeTest, Composite) { + // (foo bar) OR baz + std::vector<std::unique_ptr<Node>> and_args; + and_args.push_back(std::make_unique<TextNode>("foo")); + and_args.push_back(std::make_unique<TextNode>("bar")); + auto and_node = + std::make_unique<NaryOperatorNode>("AND", std::move(and_args)); + + std::vector<std::unique_ptr<Node>> or_args; + or_args.push_back(std::move(and_node)); + or_args.push_back(std::make_unique<TextNode>("baz")); + std::unique_ptr<Node> 
root = + std::make_unique<NaryOperatorNode>("OR", std::move(or_args)); + + SimpleVisitor visitor; + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("AND", NodeType::kNaryOperator), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(AbstractSyntaxTreeTest, Function) { + // foo() + std::unique_ptr<Node> root = + std::make_unique<FunctionNode>(std::make_unique<FunctionNameNode>("foo")); + SimpleVisitor visitor; + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction))); + + // foo("bar") + std::vector<std::unique_ptr<Node>> args; + args.push_back(std::make_unique<StringNode>("bar")); + root = std::make_unique<FunctionNode>( + std::make_unique<FunctionNameNode>("foo"), std::move(args)); + visitor = SimpleVisitor(); + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction))); + + // foo(bar("baz")) + std::vector<std::unique_ptr<Node>> inner_args; + inner_args.push_back(std::make_unique<StringNode>("baz")); + args.clear(); + args.push_back(std::make_unique<FunctionNode>( + std::make_unique<FunctionNameNode>("bar"), std::move(inner_args))); + root = std::make_unique<FunctionNode>( + std::make_unique<FunctionNameNode>("foo"), std::move(args)); + visitor = SimpleVisitor(); + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kFunctionName), + EqualsNodeInfo("baz", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(AbstractSyntaxTreeTest, Restriction) { + // sender.name:(IMPORTANT OR 
URGENT) + std::vector<std::unique_ptr<TextNode>> member_args; + member_args.push_back(std::make_unique<TextNode>("sender")); + member_args.push_back(std::make_unique<TextNode>("name")); + + std::vector<std::unique_ptr<Node>> or_args; + or_args.push_back(std::make_unique<TextNode>("IMPORTANT")); + or_args.push_back(std::make_unique<TextNode>("URGENT")); + + std::vector<std::unique_ptr<Node>> has_args; + has_args.push_back(std::make_unique<MemberNode>(std::move(member_args), + /*function=*/nullptr)); + has_args.push_back( + std::make_unique<NaryOperatorNode>("OR", std::move(or_args))); + + std::unique_ptr<Node> root = + std::make_unique<NaryOperatorNode>(":", std::move(has_args)); + + SimpleVisitor visitor; + root->Accept(&visitor); + + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("sender", NodeType::kText), + EqualsNodeInfo("name", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("IMPORTANT", NodeType::kText), + EqualsNodeInfo("URGENT", NodeType::kText), + EqualsNodeInfo("OR", NodeType::kNaryOperator), + EqualsNodeInfo(":", NodeType::kNaryOperator))); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/lexer.cc b/icing/query/advanced_query_parser/lexer.cc new file mode 100644 index 0000000..18932f6 --- /dev/null +++ b/icing/query/advanced_query_parser/lexer.cc @@ -0,0 +1,228 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/query/advanced_query_parser/lexer.h" + +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/util/i18n-utils.h" + +namespace icing { +namespace lib { + +bool Lexer::ConsumeWhitespace() { + if (current_char_ == '\0') { + return false; + } + if (i18n_utils::IsWhitespaceAt(query_, current_index_)) { + UChar32 uchar32 = i18n_utils::GetUChar32At(query_.data(), query_.length(), + current_index_); + int length = i18n_utils::GetUtf8Length(uchar32); + Advance(length); + return true; + } + return false; +} + +bool Lexer::ConsumeQuerySingleChar() { + if (current_char_ != ':') { + return false; + } + tokens_.push_back({":", TokenType::COMPARATOR}); + Advance(); + return true; +} + +bool Lexer::ConsumeScoringSingleChar() { + switch (current_char_) { + case '+': + tokens_.push_back({"", TokenType::PLUS}); + break; + case '*': + tokens_.push_back({"", TokenType::TIMES}); + break; + case '/': + tokens_.push_back({"", TokenType::DIV}); + break; + default: + return false; + } + Advance(); + return true; +} + +bool Lexer::ConsumeGeneralSingleChar() { + switch (current_char_) { + case ',': + tokens_.push_back({"", TokenType::COMMA}); + break; + case '.': + tokens_.push_back({"", TokenType::DOT}); + break; + case '-': + tokens_.push_back({"", TokenType::MINUS}); + break; + case '(': + tokens_.push_back({"", TokenType::LPAREN}); + break; + case ')': + tokens_.push_back({"", TokenType::RPAREN}); + break; + default: + return false; + } + Advance(); + return true; +} + +bool Lexer::ConsumeSingleChar() { + if (language_ == Language::QUERY) { + if (ConsumeQuerySingleChar()) { + return true; + } + } else if (language_ == Language::SCORING) { + if (ConsumeScoringSingleChar()) { + return true; + } + } + return ConsumeGeneralSingleChar(); +} + +bool Lexer::ConsumeComparator() { + if (current_char_ != '<' && current_char_ != '>' && current_char_ != '!' 
&& + current_char_ != '=') { + return false; + } + // Now, current_char_ must be one of '<', '>', '!', or '='. + // Matching for '<=', '>=', '!=', or '=='. + char next_char = PeekNext(1); + if (next_char == '=') { + tokens_.push_back({{current_char_, next_char}, TokenType::COMPARATOR}); + Advance(2); + return true; + } + // Now, next_char must not be '='. Let's match for '<' and '>'. + if (current_char_ == '<' || current_char_ == '>') { + tokens_.push_back({{current_char_}, TokenType::COMPARATOR}); + Advance(); + return true; + } + return false; +} + +bool Lexer::ConsumeAndOr() { + if (current_char_ != '&' && current_char_ != '|') { + return false; + } + char next_char = PeekNext(1); + if (current_char_ != next_char) { + return false; + } + if (current_char_ == '&') { + tokens_.push_back({"", TokenType::AND}); + } else { + tokens_.push_back({"", TokenType::OR}); + } + Advance(2); + return true; +} + +bool Lexer::ConsumeStringLiteral() { + if (current_char_ != '"') { + return false; + } + std::string text; + Advance(); + while (current_char_ != '\0' && current_char_ != '"') { + // When getting a backslash, we will always match the next character, even + // if the next character is a quotation mark + if (current_char_ == '\\') { + text.push_back(current_char_); + Advance(); + if (current_char_ == '\0') { + // In this case, we are missing a terminating quotation mark. 
+ break; + } + } + text.push_back(current_char_); + Advance(); + } + if (current_char_ == '\0') { + SyntaxError("missing terminating \" character"); + return false; + } + tokens_.push_back({text, TokenType::STRING}); + Advance(); + return true; +} + +bool Lexer::Text() { + if (current_char_ == '\0') { + return false; + } + tokens_.push_back({"", TokenType::TEXT}); + int token_index = tokens_.size() - 1; + while (!ConsumeNonText() && current_char_ != '\0') { + // When getting a backslash in TEXT, unescape it by accepting its following + // character no matter which character it is, including white spaces, + // operator symbols, parentheses, etc. + if (current_char_ == '\\') { + Advance(); + if (current_char_ == '\0') { + SyntaxError("missing a escaping character after \\"); + break; + } + } + tokens_[token_index].text.push_back(current_char_); + Advance(); + if (current_char_ == '(') { + // A TEXT followed by a LPAREN is a FUNCTION_NAME. + tokens_.back().type = TokenType::FUNCTION_NAME; + // No need to break, since NonText() must be true at this point. + } + } + if (language_ == Lexer::Language::QUERY) { + std::string &text = tokens_[token_index].text; + TokenType &type = tokens_[token_index].type; + if (text == "AND") { + text.clear(); + type = TokenType::AND; + } else if (text == "OR") { + text.clear(); + type = TokenType::OR; + } else if (text == "NOT") { + text.clear(); + type = TokenType::NOT; + } + } + return true; +} + +libtextclassifier3::StatusOr<std::vector<Lexer::LexerToken>> +Lexer::ExtractTokens() { + while (current_char_ != '\0') { + // Clear out any non-text before matching a Text. 
+ while (ConsumeNonText()) { + } + Text(); + } + if (!error_.empty()) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("Syntax Error: ", error_)); + } + return tokens_; +} + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/lexer.h b/icing/query/advanced_query_parser/lexer.h new file mode 100644 index 0000000..f72affb --- /dev/null +++ b/icing/query/advanced_query_parser/lexer.h @@ -0,0 +1,153 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_ +#define ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_ + +#include <cstdint> +#include <string> +#include <string_view> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" + +namespace icing { +namespace lib { + +class Lexer { + public: + enum class Language { QUERY, SCORING }; + + enum class TokenType { + COMMA, // ',' + DOT, // '.' + PLUS, // '+' Not allowed in QUERY language. + MINUS, // '-' + TIMES, // '*' Not allowed in QUERY language. + DIV, // '/' Not allowed in QUERY language. + LPAREN, // '(' + RPAREN, // ')' + COMPARATOR, // '<=' | '<' | '>=' | '>' | '!=' | '==' | ':' + // Not allowed in SCORING language. + AND, // 'AND' | '&&' Not allowed in SCORING language. + OR, // 'OR' | '||' Not allowed in SCORING language. + NOT, // 'NOT' Not allowed in SCORING language. 
    STRING,         // String literal surrounded by quotation marks
    TEXT,           // A sequence of chars that are not any above-listed
                    // operator
    FUNCTION_NAME,  // A TEXT followed by LPAREN.
    // Whitespaces not inside a string literal will be skipped.
    // WS: " " | "\t" | "\n" | "\r" | "\f" -> skip ;
  };

  struct LexerToken {
    // For STRING, text will contain the raw original text of the token
    // in between quotation marks, without unescaping.
    //
    // For TEXT, text will contain the text of the token after unescaping all
    // escaped characters.
    //
    // For FUNCTION_NAME, this field will contain the name of the function.
    //
    // For COMPARATOR, this field will contain the comparator.
    //
    // For other types, this field will be empty.
    std::string text;

    // The type of the token.
    TokenType type;
  };

  // The constructor primes the cursor: current_index_ starts at -1, so this
  // first Advance() positions it on index 0 (or the '\0' sentinel for an
  // empty query).
  explicit Lexer(std::string_view query, Language language)
      : query_(query), language_(language) {
    Advance();
  }

  // Get a vector of LexerToken after lexing the query given in the
  // constructor.
  //
  // Returns:
  //   A vector of LexerToken on success
  //   INVALID_ARGUMENT on syntax error.
  libtextclassifier3::StatusOr<std::vector<LexerToken>> ExtractTokens();

 private:
  // Advance to current_index_ + n. Clamps at the end of the query and sets
  // current_char_ to the '\0' sentinel, so callers can test current_char_
  // instead of bounds-checking.
  void Advance(uint32_t n = 1) {
    if (current_index_ + n >= query_.size()) {
      current_index_ = query_.size();
      current_char_ = '\0';
    } else {
      current_index_ += n;
      current_char_ = query_[current_index_];
    }
  }

  // Get the character at current_index_ + n without moving the cursor;
  // returns '\0' when peeking past the end.
  char PeekNext(uint32_t n = 1) {
    if (current_index_ + n >= query_.size()) {
      return '\0';
    } else {
      return query_[current_index_ + n];
    }
  }

  // Records the error and jumps the cursor to end-of-input so lexing
  // terminates; ExtractTokens() reports the stored error afterwards.
  void SyntaxError(std::string error) {
    current_index_ = query_.size();
    current_char_ = '\0';
    error_ = std::move(error);
  }

  // Try to match a whitespace token and skip it.
  bool ConsumeWhitespace();

  // Try to match a single-char token other than '<' and '>'.
  bool ConsumeSingleChar();
  bool ConsumeQuerySingleChar();
  bool ConsumeScoringSingleChar();
  bool ConsumeGeneralSingleChar();

  // Try to match a comparator token other than ':'.
  bool ConsumeComparator();

  // Try to match '&&' and '||'.
  // 'AND' and 'OR' will be handled in Text() instead, so that 'ANDfoo' and
  // 'fooOR' is a TEXT, instead of an 'AND' or 'OR'.
  bool ConsumeAndOr();

  // Try to match a string literal.
  bool ConsumeStringLiteral();

  // Try to match a non-text. Short-circuit order matters: comparators and
  // '&&'/'||' are only attempted in the QUERY language.
  bool ConsumeNonText() {
    return ConsumeWhitespace() || ConsumeSingleChar() ||
           (language_ == Language::QUERY && ConsumeComparator()) ||
           (language_ == Language::QUERY && ConsumeAndOr()) ||
           ConsumeStringLiteral();
  }

  // Try to match TEXT, FUNCTION_NAME, 'AND', 'OR' and 'NOT'.
  // Should make sure that NonText() is false before calling into this method.
  bool Text();

  std::string_view query_;
  std::string error_;
  Language language_;
  // -1 before the constructor's initial Advance(); query_.size() at EOF.
  int32_t current_index_ = -1;
  // Cached query_[current_index_]; '\0' is the end-of-input sentinel.
  char current_char_ = '\0';
  std::vector<LexerToken> tokens_;
};

}  // namespace lib
}  // namespace icing

#endif  // ICING_QUERY_ADVANCED_QUERY_PARSER_LEXER_H_

// Copyright (C) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstdint> +#include <memory> +#include <string_view> + +#include "icing/query/advanced_query_parser/lexer.h" + +namespace icing { +namespace lib { + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + std::string_view text(reinterpret_cast<const char*>(data), size); + + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>(text, Lexer::Language::QUERY); + lexer->ExtractTokens(); + + lexer = std::make_unique<Lexer>(text, Lexer::Language::SCORING); + lexer->ExtractTokens(); + return 0; +} + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/lexer_test.cc b/icing/query/advanced_query_parser/lexer_test.cc new file mode 100644 index 0000000..41e78fe --- /dev/null +++ b/icing/query/advanced_query_parser/lexer_test.cc @@ -0,0 +1,613 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/query/advanced_query_parser/lexer.h" + +#include <memory> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +using ::testing::ElementsAre; + +MATCHER_P2(EqualsLexerToken, text, type, "") { + const Lexer::LexerToken& actual = arg; + *result_listener << "actual is {text=" << actual.text + << ", type=" << static_cast<int>(actual.type) + << "}, but expected was {text=" << text + << ", type=" << static_cast<int>(type) << "}."; + return actual.text == text && actual.type == type; +} + +MATCHER_P(EqualsLexerToken, type, "") { + const Lexer::LexerToken& actual = arg; + *result_listener << "actual is {text=" << actual.text + << ", type=" << static_cast<int>(actual.type) + << "}, but expected was {text=(empty), type=" + << static_cast<int>(type) << "}."; + return actual.text.empty() && actual.type == type; +} + +TEST(LexerTest, SimpleQuery) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("foo", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("fooAND", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("fooAND", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("ORfoo", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("ORfoo", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("fooANDbar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar", + Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, PrefixQuery) { + std::unique_ptr<Lexer> lexer = + 
std::make_unique<Lexer>("foo*", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo*", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("fooAND*", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("fooAND*", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("*ORfoo", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("*ORfoo", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("fooANDbar*", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar*", + Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, SimpleStringQuery) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("\"foo\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::STRING))); + + lexer = std::make_unique<Lexer>("\"fooAND\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooAND", + Lexer::TokenType::STRING))); + + lexer = std::make_unique<Lexer>("\"ORfoo\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("ORfoo", Lexer::TokenType::STRING))); + + lexer = std::make_unique<Lexer>("\"fooANDbar\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("fooANDbar", + Lexer::TokenType::STRING))); +} + +TEST(LexerTest, TwoTermQuery) { + std::unique_ptr<Lexer> 
lexer = + std::make_unique<Lexer>("foo AND bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("foo && bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("foo&&bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("foo OR \"bar\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("bar", Lexer::TokenType::STRING))); +} + +TEST(LexerTest, QueryWithSpecialSymbol) { + // With escaping + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("foo\\ \\&\\&bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo &&bar", + Lexer::TokenType::TEXT))); + lexer = std::make_unique<Lexer>("foo\\&\\&bar&&baz", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo&&bar", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("baz", Lexer::TokenType::TEXT))); + 
lexer = std::make_unique<Lexer>("foo\\\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo\"", Lexer::TokenType::TEXT))); + + // With quotation marks + lexer = std::make_unique<Lexer>("\"foo &&bar\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo &&bar", + Lexer::TokenType::STRING))); + lexer = std::make_unique<Lexer>("\"foo&&bar\"&&baz", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT( + tokens, + ElementsAre(EqualsLexerToken("foo&&bar", Lexer::TokenType::STRING), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("baz", Lexer::TokenType::TEXT))); + lexer = std::make_unique<Lexer>("\"foo\\\"\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo\\\"", + Lexer::TokenType::STRING))); +} + +TEST(LexerTest, TextInStringShouldBeOriginal) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("\"foo\\nbar\"", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, ElementsAre(EqualsLexerToken("foo\\nbar", + Lexer::TokenType::STRING))); +} + +TEST(LexerTest, QueryWithFunctionCalls) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("foo AND fun(bar)", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT( + tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("fun", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("bar", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN))); + + // Not a function call + 
lexer = std::make_unique<Lexer>("foo AND fun (bar)", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("fun", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("bar", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN))); +} + +TEST(LexerTest, QueryWithComparator) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("name: foo", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("name", Lexer::TokenType::TEXT), + EqualsLexerToken(":", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("foo", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("email.name:foo", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("email", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::DOT), + EqualsLexerToken("name", Lexer::TokenType::TEXT), + EqualsLexerToken(":", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("foo", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age > 20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken(">", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age>=20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken(">=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); + + lexer = 
std::make_unique<Lexer>("age <20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken("<", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age<= 20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age == 20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken("==", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age != 20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken("!=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("20", Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, ComplexQuery) { + std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>( + "email.sender: (foo* AND bar OR pow(age, 2)>100) || (-baz foo) && " + "NOT verbatimSearch(\"hello world\")", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT( + tokens, + ElementsAre( + EqualsLexerToken("email", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::DOT), + EqualsLexerToken("sender", Lexer::TokenType::TEXT), + EqualsLexerToken(":", Lexer::TokenType::COMPARATOR), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("foo*", Lexer::TokenType::TEXT), + 
EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("bar", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("pow", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::COMMA), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(">", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("100", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken(Lexer::TokenType::MINUS), + EqualsLexerToken("baz", Lexer::TokenType::TEXT), + EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken(Lexer::TokenType::NOT), + EqualsLexerToken("verbatimSearch", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("hello world", Lexer::TokenType::STRING), + EqualsLexerToken(Lexer::TokenType::RPAREN))); +} + +TEST(LexerTest, UTF8WhiteSpace) { + std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>( + "\xe2\x80\x88" + "foo" + "\xe2\x80\x89" + "\xe2\x80\x89" + "bar" + "\xe2\x80\x8a", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, CJKT) { + std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>( + "我 && 每天 || 走路 OR 去 -上班", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("我", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + 
EqualsLexerToken("每天", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("走路", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("去", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::MINUS), + EqualsLexerToken("上班", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("私&& は ||毎日 AND 仕事 -に 歩い て い ます", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("私", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("は", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("毎日", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("仕事", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::MINUS), + EqualsLexerToken("に", Lexer::TokenType::TEXT), + EqualsLexerToken("歩い", Lexer::TokenType::TEXT), + EqualsLexerToken("て", Lexer::TokenType::TEXT), + EqualsLexerToken("い", Lexer::TokenType::TEXT), + EqualsLexerToken("ます", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("ញុំ&&ដើរទៅ||ធ្វើការ-រាល់ថ្ងៃ", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("ញុំ", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::AND), + EqualsLexerToken("ដើរទៅ", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::OR), + EqualsLexerToken("ធ្វើការ", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::MINUS), + EqualsLexerToken("រាល់ថ្ងៃ", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>( + "나는" + "\xe2\x80\x88" // White Space + "매일" + "\xe2\x80\x89" // White Space + "출근합니다", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT( + tokens, + ElementsAre(EqualsLexerToken("나는", Lexer::TokenType::TEXT), + 
EqualsLexerToken("매일", Lexer::TokenType::TEXT), + EqualsLexerToken("출근합니다", Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, SyntaxError) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("\"foo", Lexer::Language::QUERY); + EXPECT_THAT(lexer->ExtractTokens(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + lexer = std::make_unique<Lexer>("\"foo\\", Lexer::Language::QUERY); + EXPECT_THAT(lexer->ExtractTokens(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + lexer = std::make_unique<Lexer>("foo\\", Lexer::Language::QUERY); + EXPECT_THAT(lexer->ExtractTokens(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +// "!", "=", "&" and "|" should be treated as valid symbols in TEXT, if not +// matched as "!=", "==", "&&", or "||". +TEST(LexerTest, SpecialSymbolAsText) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("age=20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age=20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("age !20", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("age", Lexer::TokenType::TEXT), + EqualsLexerToken("!20", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("foo& bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo&", Lexer::TokenType::TEXT), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("foo | bar", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("foo", Lexer::TokenType::TEXT), + EqualsLexerToken("|", Lexer::TokenType::TEXT), + EqualsLexerToken("bar", Lexer::TokenType::TEXT))); +} + 
+TEST(LexerTest, ScoringArithmetic) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("1 + 2", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::PLUS), + EqualsLexerToken("2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1+2*3/4", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::PLUS), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::TIMES), + EqualsLexerToken("3", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::DIV), + EqualsLexerToken("4", Lexer::TokenType::TEXT))); + + // Arithmetic operators will not be produced in query language. + lexer = std::make_unique<Lexer>("1 + 2", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken("+", Lexer::TokenType::TEXT), + EqualsLexerToken("2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1+2*3/4", Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1+2*3/4", Lexer::TokenType::TEXT))); +} + +// Currently, in scoring language, the lexer will view these logic operators as +// TEXTs. In the future, they may be rejected instead. 
+TEST(LexerTest, LogicOperatorNotInScoring) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("1 && 2", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken("&&", Lexer::TokenType::TEXT), + EqualsLexerToken("2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1&&2", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1&&2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1&&2 ||3", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1&&2", Lexer::TokenType::TEXT), + EqualsLexerToken("||3", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1 AND 2 OR 3 AND NOT 4", + Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken("AND", Lexer::TokenType::TEXT), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken("OR", Lexer::TokenType::TEXT), + EqualsLexerToken("3", Lexer::TokenType::TEXT), + EqualsLexerToken("AND", Lexer::TokenType::TEXT), + EqualsLexerToken("NOT", Lexer::TokenType::TEXT), + EqualsLexerToken("4", Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, ComparatorNotInScoring) { + std::unique_ptr<Lexer> lexer = + std::make_unique<Lexer>("1 > 2", Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken(">", Lexer::TokenType::TEXT), + EqualsLexerToken("2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1>2", Lexer::Language::SCORING); + 
ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1>2", Lexer::TokenType::TEXT))); + + lexer = std::make_unique<Lexer>("1>2>=3 <= 4:5== 6<7<=8!= 9", + Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1>2>=3", Lexer::TokenType::TEXT), + EqualsLexerToken("<=", Lexer::TokenType::TEXT), + EqualsLexerToken("4:5==", Lexer::TokenType::TEXT), + EqualsLexerToken("6<7<=8!=", Lexer::TokenType::TEXT), + EqualsLexerToken("9", Lexer::TokenType::TEXT))); + + // Comparator should be produced in query language. + lexer = std::make_unique<Lexer>("1>2>=3 <= 4:5== 6<7<=8!= 9", + Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(tokens, lexer->ExtractTokens()); + EXPECT_THAT(tokens, + ElementsAre(EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken(">", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken(">=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("3", Lexer::TokenType::TEXT), + EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("4", Lexer::TokenType::TEXT), + EqualsLexerToken(":", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("5", Lexer::TokenType::TEXT), + EqualsLexerToken("==", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("6", Lexer::TokenType::TEXT), + EqualsLexerToken("<", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("7", Lexer::TokenType::TEXT), + EqualsLexerToken("<=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("8", Lexer::TokenType::TEXT), + EqualsLexerToken("!=", Lexer::TokenType::COMPARATOR), + EqualsLexerToken("9", Lexer::TokenType::TEXT))); +} + +TEST(LexerTest, ComplexScoring) { + std::unique_ptr<Lexer> lexer = std::make_unique<Lexer>( + "1/log( (CreationTimestamp(document) + LastUsedTimestamp(document)) / 2 " + ") * pow(2.3, DocumentScore())", + Lexer::Language::SCORING); + 
ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> tokens, + lexer->ExtractTokens()); + EXPECT_THAT( + tokens, + ElementsAre( + EqualsLexerToken("1", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::DIV), + EqualsLexerToken("log", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("CreationTimestamp", + Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("document", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::PLUS), + EqualsLexerToken("LastUsedTimestamp", + Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("document", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::DIV), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::TIMES), + EqualsLexerToken("pow", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken("2", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::DOT), + EqualsLexerToken("3", Lexer::TokenType::TEXT), + EqualsLexerToken(Lexer::TokenType::COMMA), + EqualsLexerToken("DocumentScore", Lexer::TokenType::FUNCTION_NAME), + EqualsLexerToken(Lexer::TokenType::LPAREN), + EqualsLexerToken(Lexer::TokenType::RPAREN), + EqualsLexerToken(Lexer::TokenType::RPAREN))); +} + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/parser.cc b/icing/query/advanced_query_parser/parser.cc new file mode 100644 index 0000000..086f038 --- /dev/null +++ b/icing/query/advanced_query_parser/parser.cc @@ -0,0 +1,414 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not 
use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/advanced_query_parser/parser.h" + +#include <memory> +#include <string_view> + +#include "icing/absl_ports/canonical_errors.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +std::unique_ptr<Node> CreateNaryNode( + std::string_view operator_text, + std::vector<std::unique_ptr<Node>>&& operands) { + if (operands.empty()) { + return nullptr; + } + if (operands.size() == 1) { + return std::move(operands.at(0)); + } + return std::make_unique<NaryOperatorNode>(std::string(operator_text), + std::move(operands)); +} + +} // namespace + +libtextclassifier3::Status Parser::Consume(Lexer::TokenType token_type) { + if (!Match(token_type)) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Unable to consume token %d.", static_cast<int>(token_type))); + } + ++current_token_; + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::StatusOr<std::unique_ptr<TextNode>> Parser::ConsumeText() { + if (!Match(Lexer::TokenType::TEXT)) { + return absl_ports::InvalidArgumentError("Unable to consume token as TEXT."); + } + auto text_node = std::make_unique<TextNode>(std::move(current_token_->text)); + ++current_token_; + return text_node; +} + +libtextclassifier3::StatusOr<std::unique_ptr<FunctionNameNode>> +Parser::ConsumeFunctionName() { + if (!Match(Lexer::TokenType::FUNCTION_NAME)) { + return 
absl_ports::InvalidArgumentError( + "Unable to consume token as FUNCTION_NAME."); + } + auto function_name_node = + std::make_unique<FunctionNameNode>(std::move(current_token_->text)); + ++current_token_; + return function_name_node; +} + +libtextclassifier3::StatusOr<std::unique_ptr<StringNode>> +Parser::ConsumeString() { + if (!Match(Lexer::TokenType::STRING)) { + return absl_ports::InvalidArgumentError( + "Unable to consume token as STRING."); + } + auto node = std::make_unique<StringNode>(std::move(current_token_->text)); + ++current_token_; + return node; +} + +libtextclassifier3::StatusOr<std::string> Parser::ConsumeComparator() { + if (!Match(Lexer::TokenType::COMPARATOR)) { + return absl_ports::InvalidArgumentError( + "Unable to consume token as COMPARATOR."); + } + std::string comparator = std::move(current_token_->text); + ++current_token_; + return comparator; +} + +// member +// : TEXT (DOT TEXT)* (DOT function)? +// ; +libtextclassifier3::StatusOr<std::unique_ptr<MemberNode>> +Parser::ConsumeMember() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<TextNode> text_node, ConsumeText()); + std::vector<std::unique_ptr<TextNode>> children; + children.push_back(std::move(text_node)); + + while (Match(Lexer::TokenType::DOT)) { + Consume(Lexer::TokenType::DOT); + if (MatchFunction()) { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FunctionNode> function_node, + ConsumeFunction()); + // Once a function is matched, we should exit the current rule based on + // the grammar. + return std::make_unique<MemberNode>(std::move(children), + std::move(function_node)); + } + ICING_ASSIGN_OR_RETURN(text_node, ConsumeText()); + children.push_back(std::move(text_node)); + } + return std::make_unique<MemberNode>(std::move(children), + /*function=*/nullptr); +} + +// function +// : FUNCTION_NAME LPAREN argList? 
RPAREN +// ; +libtextclassifier3::StatusOr<std::unique_ptr<FunctionNode>> +Parser::ConsumeFunction() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FunctionNameNode> function_name, + ConsumeFunctionName()); + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::LPAREN)); + + std::vector<std::unique_ptr<Node>> args; + if (Match(Lexer::TokenType::RPAREN)) { + // Got empty argument. + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN)); + } else { + ICING_ASSIGN_OR_RETURN(args, ConsumeArgs()); + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN)); + } + return std::make_unique<FunctionNode>(std::move(function_name), + std::move(args)); +} + +// comparable +// : STRING +// | member +// | function +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> +Parser::ConsumeComparable() { + if (Match(Lexer::TokenType::STRING)) { + return ConsumeString(); + } else if (MatchMember()) { + return ConsumeMember(); + } + // The current token sequence isn't a STRING or member. Therefore, it must be + // a function. + return ConsumeFunction(); +} + +// composite +// : LPAREN expression RPAREN +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeComposite() { + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::LPAREN)); + + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> expression, ConsumeExpression()); + + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::RPAREN)); + return expression; +} + +// argList +// : expression (COMMA expression)* +// ; +libtextclassifier3::StatusOr<std::vector<std::unique_ptr<Node>>> +Parser::ConsumeArgs() { + std::vector<std::unique_ptr<Node>> args; + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> arg, ConsumeExpression()); + args.push_back(std::move(arg)); + while (Match(Lexer::TokenType::COMMA)) { + Consume(Lexer::TokenType::COMMA); + ICING_ASSIGN_OR_RETURN(arg, ConsumeExpression()); + args.push_back(std::move(arg)); + } + return args; +} + +// restriction +// : comparable (COMPARATOR (comparable | composite))? 
+// ; +// COMPARATOR will not be produced in Scoring Lexer. +libtextclassifier3::StatusOr<std::unique_ptr<Node>> +Parser::ConsumeRestriction() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> comparable, ConsumeComparable()); + + if (!Match(Lexer::TokenType::COMPARATOR)) { + return comparable; + } + ICING_ASSIGN_OR_RETURN(std::string operator_text, ConsumeComparator()); + std::unique_ptr<Node> arg; + if (MatchComposite()) { + ICING_ASSIGN_OR_RETURN(arg, ConsumeComposite()); + } else if (MatchComparable()) { + ICING_ASSIGN_OR_RETURN(arg, ConsumeComparable()); + } else { + return absl_ports::InvalidArgumentError( + "ARG: must begin with LPAREN or FIRST(comparable)"); + } + std::vector<std::unique_ptr<Node>> args; + args.push_back(std::move(comparable)); + args.push_back(std::move(arg)); + return std::make_unique<NaryOperatorNode>(std::move(operator_text), + std::move(args)); +} + +// simple +// : restriction +// | composite +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeSimple() { + if (MatchComposite()) { + return ConsumeComposite(); + } else if (MatchRestriction()) { + return ConsumeRestriction(); + } + return absl_ports::InvalidArgumentError( + "SIMPLE: must be a restriction or composite"); +} + +// term +// : NOT? simple +// | MINUS simple +// ; +// NOT will not be produced in Scoring Lexer. 
+libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeTerm() { + if (!Match(Lexer::TokenType::NOT) && !Match(Lexer::TokenType::MINUS)) { + return ConsumeSimple(); + } + std::string operator_text; + if (language_ == Lexer::Language::SCORING) { + ICING_RETURN_IF_ERROR(Consume(Lexer::TokenType::MINUS)); + operator_text = "MINUS"; + } else { + if (Match(Lexer::TokenType::NOT)) { + Consume(Lexer::TokenType::NOT); + } else { + Consume(Lexer::TokenType::MINUS); + } + operator_text = "NOT"; + } + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> simple, ConsumeSimple()); + return std::make_unique<UnaryOperatorNode>(operator_text, std::move(simple)); +} + +// factor +// : term (OR term)* +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeFactor() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> term, ConsumeTerm()); + std::vector<std::unique_ptr<Node>> terms; + terms.push_back(std::move(term)); + + while (Match(Lexer::TokenType::OR)) { + Consume(Lexer::TokenType::OR); + ICING_ASSIGN_OR_RETURN(term, ConsumeTerm()); + terms.push_back(std::move(term)); + } + + return CreateNaryNode("OR", std::move(terms)); +} + +// sequence +// : (factor)+ +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeSequence() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> factor, ConsumeFactor()); + std::vector<std::unique_ptr<Node>> factors; + factors.push_back(std::move(factor)); + + while (MatchFactor()) { + ICING_ASSIGN_OR_RETURN(factor, ConsumeFactor()); + factors.push_back(std::move(factor)); + } + + return CreateNaryNode("AND", std::move(factors)); +} + +// expression +// : sequence (AND sequence)* +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> +Parser::ConsumeQueryExpression() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> sequence, ConsumeSequence()); + std::vector<std::unique_ptr<Node>> sequences; + sequences.push_back(std::move(sequence)); + + while (Match(Lexer::TokenType::AND)) { + Consume(Lexer::TokenType::AND); + 
ICING_ASSIGN_OR_RETURN(sequence, ConsumeSequence()); + sequences.push_back(std::move(sequence)); + } + + return CreateNaryNode("AND", std::move(sequences)); +} + +// multExpr +// : term ((TIMES | DIV) term)* +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeMultExpr() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> node, ConsumeTerm()); + std::vector<std::unique_ptr<Node>> stack; + stack.push_back(std::move(node)); + + while (Match(Lexer::TokenType::TIMES) || Match(Lexer::TokenType::DIV)) { + while (Match(Lexer::TokenType::TIMES)) { + Consume(Lexer::TokenType::TIMES); + ICING_ASSIGN_OR_RETURN(node, ConsumeTerm()); + stack.push_back(std::move(node)); + } + node = CreateNaryNode("TIMES", std::move(stack)); + stack.clear(); + stack.push_back(std::move(node)); + + while (Match(Lexer::TokenType::DIV)) { + Consume(Lexer::TokenType::DIV); + ICING_ASSIGN_OR_RETURN(node, ConsumeTerm()); + stack.push_back(std::move(node)); + } + node = CreateNaryNode("DIV", std::move(stack)); + stack.clear(); + stack.push_back(std::move(node)); + } + + return std::move(stack[0]); +} + +// expression +// : multExpr ((PLUS | MINUS) multExpr)* +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> +Parser::ConsumeScoringExpression() { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> node, ConsumeMultExpr()); + std::vector<std::unique_ptr<Node>> stack; + stack.push_back(std::move(node)); + + while (Match(Lexer::TokenType::PLUS) || Match(Lexer::TokenType::MINUS)) { + while (Match(Lexer::TokenType::PLUS)) { + Consume(Lexer::TokenType::PLUS); + ICING_ASSIGN_OR_RETURN(node, ConsumeMultExpr()); + stack.push_back(std::move(node)); + } + node = CreateNaryNode("PLUS", std::move(stack)); + stack.clear(); + stack.push_back(std::move(node)); + + while (Match(Lexer::TokenType::MINUS)) { + Consume(Lexer::TokenType::MINUS); + ICING_ASSIGN_OR_RETURN(node, ConsumeMultExpr()); + stack.push_back(std::move(node)); + } + node = CreateNaryNode("MINUS", std::move(stack)); + stack.clear(); + 
stack.push_back(std::move(node)); + } + + return std::move(stack[0]); +} + +libtextclassifier3::StatusOr<std::unique_ptr<Node>> +Parser::ConsumeExpression() { + switch (language_) { + case Lexer::Language::QUERY: + return ConsumeQueryExpression(); + case Lexer::Language::SCORING: + return ConsumeScoringExpression(); + } +} + +// query +// : expression? EOF +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeQuery() { + language_ = Lexer::Language::QUERY; + std::unique_ptr<Node> node; + if (current_token_ != lexer_tokens_.end()) { + ICING_ASSIGN_OR_RETURN(node, ConsumeExpression()); + } + if (current_token_ != lexer_tokens_.end()) { + return absl_ports::InvalidArgumentError( + "Error parsing Query. Must reach EOF after parsing Expression!"); + } + return node; +} + +// scoring +// : expression EOF +// ; +libtextclassifier3::StatusOr<std::unique_ptr<Node>> Parser::ConsumeScoring() { + language_ = Lexer::Language::SCORING; + std::unique_ptr<Node> node; + if (current_token_ == lexer_tokens_.end()) { + return absl_ports::InvalidArgumentError("Got empty scoring expression!"); + } + ICING_ASSIGN_OR_RETURN(node, ConsumeExpression()); + if (current_token_ != lexer_tokens_.end()) { + return absl_ports::InvalidArgumentError( + "Error parsing the scoring expression. Must reach EOF after parsing " + "Expression!"); + } + return node; +} + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/parser.h b/icing/query/advanced_query_parser/parser.h new file mode 100644 index 0000000..330b8b9 --- /dev/null +++ b/icing/query/advanced_query_parser/parser.h @@ -0,0 +1,140 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_
#define ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_

#include <memory>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
#include "icing/query/advanced_query_parser/lexer.h"

namespace icing {
namespace lib {

// Recursive-descent parser that turns a Lexer token stream into an abstract
// syntax tree. A Parser owns its token stream and is single-pass: each
// ConsumeQuery/ConsumeScoring call advances current_token_, so an instance
// should be used for at most one parse.
class Parser {
 public:
  // NOTE(review): returning Parser by value relies on C++17 guaranteed copy
  // elision (and vector move semantics) to keep current_token_ pointing into
  // lexer_tokens_ — confirm callers never copy a constructed Parser.
  static Parser Create(std::vector<Lexer::LexerToken>&& lexer_tokens) {
    return Parser(std::move(lexer_tokens));
  }

  // Returns:
  //   On success, pointer to the root node of the AST
  //   INVALID_ARGUMENT for input that does not conform to the grammar
  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeQuery();

  // Returns:
  //   On success, pointer to the root node of the AST
  //   INVALID_ARGUMENT for input that does not conform to the grammar
  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeScoring();

 private:
  explicit Parser(std::vector<Lexer::LexerToken>&& lexer_tokens)
      : lexer_tokens_(std::move(lexer_tokens)),
        current_token_(lexer_tokens_.begin()) {}

  // Match Functions
  // These functions are used to test whether the current_token matches a
  // member of the FIRST set of a particular symbol in our grammar. None of
  // them advance current_token_.
  bool Match(Lexer::TokenType token_type) const {
    return current_token_ != lexer_tokens_.end() &&
           current_token_->type == token_type;
  }

  bool MatchMember() const { return Match(Lexer::TokenType::TEXT); }

  bool MatchFunction() const { return Match(Lexer::TokenType::FUNCTION_NAME); }

  bool MatchComparable() const {
    return Match(Lexer::TokenType::STRING) || MatchMember() || MatchFunction();
  }

  bool MatchComposite() const { return Match(Lexer::TokenType::LPAREN); }

  bool MatchRestriction() const { return MatchComparable(); }

  bool MatchSimple() const { return MatchRestriction() || MatchComposite(); }

  bool MatchTerm() const {
    return MatchSimple() || Match(Lexer::TokenType::NOT) ||
           Match(Lexer::TokenType::MINUS);
  }

  bool MatchFactor() const { return MatchTerm(); }

  // Consume Functions
  // These functions attempt to parse the token sequence starting at
  // current_token_.
  // Returns INVALID_ARGUMENT if unable to parse the token sequence starting at
  // current_token_ as that particular grammar symbol. There are no guarantees
  // about what state current_token and lexer_tokens_ are in when returning an
  // error.
  //
  // Consume functions for terminal symbols. These are the only Consume
  // functions that will directly modify current_token_.
  // The Consume functions for terminals will guarantee not to modify
  // current_token_ and lexer_tokens_ when returning an error.
  libtextclassifier3::Status Consume(Lexer::TokenType token_type);

  libtextclassifier3::StatusOr<std::unique_ptr<TextNode>> ConsumeText();

  libtextclassifier3::StatusOr<std::unique_ptr<FunctionNameNode>>
  ConsumeFunctionName();

  libtextclassifier3::StatusOr<std::unique_ptr<StringNode>> ConsumeString();

  libtextclassifier3::StatusOr<std::string> ConsumeComparator();

  // Consume functions for non-terminal symbols.
  libtextclassifier3::StatusOr<std::unique_ptr<MemberNode>> ConsumeMember();

  libtextclassifier3::StatusOr<std::unique_ptr<FunctionNode>> ConsumeFunction();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeComparable();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeComposite();

  libtextclassifier3::StatusOr<std::vector<std::unique_ptr<Node>>>
  ConsumeArgs();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeRestriction();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeSimple();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeTerm();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeFactor();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeSequence();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeQueryExpression();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeMultExpr();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>>
  ConsumeScoringExpression();

  libtextclassifier3::StatusOr<std::unique_ptr<Node>> ConsumeExpression();

  // The token stream being parsed; owned by this Parser.
  std::vector<Lexer::LexerToken> lexer_tokens_;
  // Cursor into lexer_tokens_; advanced only by the terminal Consume
  // functions.
  std::vector<Lexer::LexerToken>::const_iterator current_token_;
  // Selects which grammar ConsumeExpression dispatches to. Set by
  // ConsumeQuery()/ConsumeScoring() before parsing begins.
  Lexer::Language language_ = Lexer::Language::QUERY;
};

}  // namespace lib
}  // namespace icing

#endif  // ICING_QUERY_ADVANCED_QUERY_PARSER_PARSER_H_
diff --git a/icing/query/advanced_query_parser/parser_integration_test.cc b/icing/query/advanced_query_parser/parser_integration_test.cc
new file mode 100644
index 0000000..75be15b
--- /dev/null
+++ b/icing/query/advanced_query_parser/parser_integration_test.cc
@@ -0,0 +1,945 @@
// Copyright (C) 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/lexer.h" +#include "icing/query/advanced_query_parser/parser.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::IsNull; + +TEST(ParserIntegrationTest, EmptyQuery) { + std::string query = ""; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + EXPECT_THAT(tree_root, IsNull()); +} + +TEST(ParserIntegrationTest, EmptyScoring) { + std::string query = ""; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeScoring(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, SingleTerm) { + std::string query = "foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + 
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserIntegrationTest, ImplicitAnd) { + std::string query = "foo bar"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, Or) { + std::string query = "foo OR bar"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + 
EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, And) { + std::string query = "foo AND bar"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, Not) { + std::string query = "NOT foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // NOT + // | + // member + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, NOT } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("NOT", NodeType::kUnaryOperator))); +} + +TEST(ParserIntegrationTest, Minus) { + std::string query = "-foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + 
// NOT + // | + // member + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, NOT } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("NOT", NodeType::kUnaryOperator))); +} + +TEST(ParserIntegrationTest, Has) { + std::string query = "subject:foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // : + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, binaryOp } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("subject", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, HasNested) { + std::string query = "sender.name:foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // : + // / \ + // member member + // / \ | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, text, member, text, member, binaryOp } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("sender", NodeType::kText), + EqualsNodeInfo("name", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("foo", NodeType::kText), + 
EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, EmptyFunction) { + std::string query = "foo()"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserIntegrationTest, FunctionSingleArg) { + std::string query = "foo(\"bar\")"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / \ + // function_name string + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, string, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserIntegrationTest, FunctionMultiArg) { + std::string query = "foo(\"bar\", \"baz\")"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / | \ + // 
function_name string string + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, string, string, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kString), + EqualsNodeInfo("baz", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserIntegrationTest, FunctionNested) { + std::string query = "foo(bar())"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / \ + // function_name function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function_name, function, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserIntegrationTest, FunctionWithTrailingSequence) { + std::string query = "foo() OR bar"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // function member + // | | + // function_name text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), 
+ EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, Composite) { + std::string query = "foo OR (bar baz)"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // member AND + // | / \ + // text member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, text, member, AND, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, CompositeWithTrailingSequence) { + std::string query = "(bar baz) OR foo"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // AND member + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + 
EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator), + EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, Complex) { + std::string query = "foo bar:baz OR pal(\"bat\")"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member OR + // | / \ + // text : function + // / \ / \ + // member member function_name string + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, text, member, :, function_name, string, + // function, OR, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator), + EqualsNodeInfo("pal", NodeType::kFunctionName), + EqualsNodeInfo("bat", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("OR", NodeType::kNaryOperator), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, InvalidHas) { + std::string query = "foo:"; // No right hand operand to : + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, 
InvalidComposite) { + std::string query = "(foo bar"; // No terminating RPAREN + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidMember) { + std::string query = "foo."; // DOT must have succeeding TEXT + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidOr) { + std::string query = "foo OR"; // No right hand operand to OR + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidAnd) { + std::string query = "foo AND"; // No right hand operand to AND + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidNot) { + std::string query = "NOT"; // No right hand operand to NOT + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidMinus) { + std::string query = "-"; // No right hand operand to - + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidFunctionCallNoRparen) { + std::string query = "foo("; // No terminating RPAREN + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, InvalidFunctionArgsHangingComma) { + std::string query = "foo(\"bar\",)"; // no valid arg following COMMA + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserIntegrationTest, ScoringPlus) { + std::string scoring = "1 + 1 + 1"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", 
NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringMinus) { + std::string scoring = "1 - 1 - 1"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // MINUS + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringUnaryMinus) { + std::string scoring = "1 + -1 + 1"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / | \ + // member MINUS member + // | | | + // text member text + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kUnaryOperator), + EqualsNodeInfo("1", NodeType::kText), + 
EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringPlusMinus) { + std::string scoring = "11 + 12 - 13 + 14"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / \ + // MINUS member + // / \ | + // PLUS member text + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("11", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("12", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator), + EqualsNodeInfo("13", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kNaryOperator), + EqualsNodeInfo("14", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringTimes) { + std::string scoring = "1 * 1 * 1"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // TIMES + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", 
NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("TIMES", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringDiv) { + std::string scoring = "1 / 1 / 1"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // DIV + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ScoringTimesDiv) { + std::string scoring = "11 / 12 * 13 / 14 / 15"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // DIV + // / | \ + // TIMES member member + // / \ | | + // DIV member text text + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("11", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("12", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator), + EqualsNodeInfo("13", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("TIMES", 
NodeType::kNaryOperator), + EqualsNodeInfo("14", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("15", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator))); +} + +TEST(ParserIntegrationTest, ComplexScoring) { + // With parentheses in function arguments. + std::string scoring = "1 + pow((2 * sin(3)), 4) + -5 / 6"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + SimpleVisitor visitor; + tree_root->Accept(&visitor); + std::vector<NodeInfo> node = visitor.nodes(); + EXPECT_THAT(node, + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("pow", NodeType::kFunctionName), + EqualsNodeInfo("2", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("sin", NodeType::kFunctionName), + EqualsNodeInfo("3", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("TIMES", NodeType::kNaryOperator), + EqualsNodeInfo("4", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("5", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kUnaryOperator), + EqualsNodeInfo("6", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); + + // Without parentheses in function arguments. 
+ scoring = "1 + pow(2 * sin(3), 4) + -5 / 6"; + lexer = Lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(lexer_tokens, lexer.ExtractTokens()); + parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(tree_root, parser.ConsumeScoring()); + visitor = SimpleVisitor(); + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), ElementsAreArray(node)); +} + +TEST(ParserIntegrationTest, ScoringMemberFunction) { + std::string scoring = "this.CreationTimestamp()"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // member + // / \ + // text function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT( + visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("CreationTimestamp", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserIntegrationTest, QueryMemberFunction) { + std::string query = "this.foo()"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // / \ + // text function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserIntegrationTest, 
ScoringComplexMemberFunction) { + std::string scoring = "a.b.fun(c, d)"; + Lexer lexer(scoring, Lexer::Language::SCORING); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // member + // / | \ + // text text function + // / | \ + // function_name member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("a", NodeType::kText), + EqualsNodeInfo("b", NodeType::kText), + EqualsNodeInfo("fun", NodeType::kFunctionName), + EqualsNodeInfo("c", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("d", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, QueryComplexMemberFunction) { + std::string query = "this.abc.fun(def, ghi)"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // / | \ + // text text function + // / | \ + // function_name member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("abc", NodeType::kText), + EqualsNodeInfo("fun", NodeType::kFunctionName), + EqualsNodeInfo("def", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("ghi", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", 
NodeType::kMember))); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/parser_test.cc b/icing/query/advanced_query_parser/parser_test.cc new file mode 100644 index 0000000..f997329 --- /dev/null +++ b/icing/query/advanced_query_parser/parser_test.cc @@ -0,0 +1,1043 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/advanced_query_parser/parser.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree-test-utils.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/lexer.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; +using ::testing::IsNull; + +TEST(ParserTest, EmptyQuery) { + std::vector<Lexer::LexerToken> lexer_tokens; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + EXPECT_THAT(tree_root, IsNull()); +} + +TEST(ParserTest, EmptyScoring) { + std::vector<Lexer::LexerToken> lexer_tokens; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeScoring(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, SingleTerm) { + // Query: 
"foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, ImplicitAnd) { + // Query: "foo bar" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {"bar", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserTest, Or) { + // Query: "foo OR bar" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::OR}, + {"bar", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), 
+ EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserTest, And) { + // Query: "foo AND bar" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::AND}, + {"bar", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserTest, Not) { + // Query: "NOT foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"", Lexer::TokenType::NOT}, {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // NOT + // | + // member + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, NOT } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("NOT", NodeType::kUnaryOperator))); +} + +TEST(ParserTest, Minus) { + // Query: "-foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"", Lexer::TokenType::MINUS}, {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // NOT + // | + // member + // | + 
// text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, NOT } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("NOT", NodeType::kUnaryOperator))); +} + +TEST(ParserTest, Has) { + // Query: "subject:foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"subject", Lexer::TokenType::TEXT}, + {":", Lexer::TokenType::COMPARATOR}, + {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // : + // / \ + // member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, binaryOp } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("subject", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator))); +} + +TEST(ParserTest, HasNested) { + // Query: "sender.name:foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"sender", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DOT}, + {"name", Lexer::TokenType::TEXT}, + {":", Lexer::TokenType::COMPARATOR}, + {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // : + // / \ + // member member + // / \ | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, text, member, text, member, binaryOp } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("sender", NodeType::kText), + EqualsNodeInfo("name", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("foo", 
NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator))); +} + +TEST(ParserTest, EmptyFunction) { + // Query: "foo()" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserTest, FunctionSingleArg) { + // Query: "foo("bar")" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"bar", Lexer::TokenType::STRING}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / \ + // function_name string + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, string, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserTest, FunctionMultiArg) { + // Query: "foo("bar", "baz")" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}, + {"bar", Lexer::TokenType::STRING}, {"", Lexer::TokenType::COMMA}, + {"baz", Lexer::TokenType::STRING}, {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + 
ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / | \ + // function_name string string + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, string, string, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kString), + EqualsNodeInfo("baz", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserTest, FunctionNested) { + // Query: "foo(bar())" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}, + {"bar", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::RPAREN}, {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // function + // / \ + // function_name function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function_name, function, function } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("bar", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kFunction))); +} + +TEST(ParserTest, FunctionWithTrailingSequence) { + // Query: "foo() OR bar" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::OR}, + {"bar", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // function member + // | 
| + // function_name text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { function_name, function, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserTest, Composite) { + // Query: "foo OR (bar baz)" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::OR}, + {"", Lexer::TokenType::LPAREN}, {"bar", Lexer::TokenType::TEXT}, + {"baz", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // member AND + // | / \ + // text member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, text, member, AND, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserTest, CompositeWithTrailingSequence) { + // Query: "(bar baz) OR foo" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"", Lexer::TokenType::LPAREN}, {"bar", Lexer::TokenType::TEXT}, + {"baz", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::OR}, {"foo", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + 
parser.ConsumeQuery()); + + // Expected AST: + // OR + // / \ + // AND member + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, AND, text, member, OR } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("AND", NodeType::kNaryOperator), + EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("OR", NodeType::kNaryOperator))); +} + +TEST(ParserTest, Complex) { + // Query: "foo bar:baz OR pal("bat")" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, + {"bar", Lexer::TokenType::TEXT}, + {":", Lexer::TokenType::COMPARATOR}, + {"baz", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::OR}, + {"pal", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"bat", Lexer::TokenType::STRING}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // AND + // / \ + // member OR + // | / \ + // text : function + // / \ / \ + // member member function_name string + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + // SimpleVisitor ordering + // { text, member, text, member, text, member, :, function_name, string, + // function, OR, AND } + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("foo", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("bar", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("baz", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo(":", NodeType::kNaryOperator), + EqualsNodeInfo("pal", 
NodeType::kFunctionName), + EqualsNodeInfo("bat", NodeType::kString), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("OR", NodeType::kNaryOperator), + EqualsNodeInfo("AND", NodeType::kNaryOperator))); +} + +TEST(ParserTest, InvalidHas) { + // Query: "foo:" No right hand operand to : + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {":", Lexer::TokenType::COMPARATOR}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidComposite) { + // Query: "(foo bar" No terminating RPAREN + std::vector<Lexer::LexerToken> lexer_tokens = { + {"", Lexer::TokenType::LPAREN}, + {"foo", Lexer::TokenType::TEXT}, + {"bar", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidMember) { + // Query: "foo." 
DOT must have succeeding TEXT + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidOr) { + // Query: "foo OR" No right hand operand to OR + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::OR}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidAnd) { + // Query: "foo AND" No right hand operand to AND + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::AND}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidNot) { + // Query: "NOT" No right hand operand to NOT + std::vector<Lexer::LexerToken> lexer_tokens = {{"", Lexer::TokenType::NOT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidMinus) { + // Query: "-" No right hand operand to - + std::vector<Lexer::LexerToken> lexer_tokens = {{"", Lexer::TokenType::MINUS}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidFunctionCallNoRparen) { + // Query: "foo(" No terminating RPAREN + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidFunctionCallNoLparen) { + // Query: "foo bar" foo labeled FUNCTION_NAME despite no LPAREN + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"bar", Lexer::TokenType::FUNCTION_NAME}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, InvalidFunctionArgsHangingComma) { + // Query: "foo("bar",)" no valid arg following COMMA + std::vector<Lexer::LexerToken> lexer_tokens = { + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"bar", Lexer::TokenType::STRING}, + {"", Lexer::TokenType::COMMA}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeQuery(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(ParserTest, ScoringPlus) { + // Scoring: "1 + 1 + 1" + std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, + {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringMinus) { + // Scoring: "1 - 1 - 1" + std::vector<Lexer::LexerToken> 
lexer_tokens = {{"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::MINUS}, + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::MINUS}, + {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // MINUS + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringUnaryMinus) { + // Scoring: "1 + -1 + 1" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"1", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS}, + {"", Lexer::TokenType::MINUS}, {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / | \ + // member MINUS member + // | | | + // text member text + // | + // text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kUnaryOperator), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringPlusMinus) { + // Scoring: "11 + 12 - 13 + 14" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"11", 
Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS}, + {"12", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::MINUS}, + {"13", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::PLUS}, + {"14", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // PLUS + // / \ + // MINUS member + // / \ | + // PLUS member text + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("11", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("12", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator), + EqualsNodeInfo("13", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kNaryOperator), + EqualsNodeInfo("14", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringTimes) { + // Scoring: "1 * 1 * 1" + std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::TIMES}, + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::TIMES}, + {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // TIMES + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("TIMES", 
NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringDiv) { + // Scoring: "1 / 1 / 1" + std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DIV}, + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DIV}, + {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // DIV + // / | \ + // member member member + // | | | + // text text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ScoringTimesDiv) { + // Scoring: "11 / 12 * 13 / 14 / 15" + std::vector<Lexer::LexerToken> lexer_tokens = { + {"11", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV}, + {"12", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::TIMES}, + {"13", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV}, + {"14", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DIV}, + {"15", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // DIV + // / | \ + // TIMES member member + // / \ | | + // DIV member text text + // / \ | + // member member text + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("11", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("12", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator), + 
EqualsNodeInfo("13", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("TIMES", NodeType::kNaryOperator), + EqualsNodeInfo("14", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("15", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator))); +} + +TEST(ParserTest, ComplexScoring) { + // Scoring: "1 + pow((2 * sin(3)), 4) + -5 / 6" + // With parentheses in function arguments. + std::vector<Lexer::LexerToken> lexer_tokens = { + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, + {"pow", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::LPAREN}, + {"2", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::TIMES}, + {"sin", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"3", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::COMMA}, + {"4", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::PLUS}, + {"", Lexer::TokenType::MINUS}, + {"5", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DIV}, + {"6", Lexer::TokenType::TEXT}, + }; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + SimpleVisitor visitor; + tree_root->Accept(&visitor); + std::vector<NodeInfo> node = visitor.nodes(); + EXPECT_THAT(node, + ElementsAre(EqualsNodeInfo("1", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("pow", NodeType::kFunctionName), + EqualsNodeInfo("2", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("sin", NodeType::kFunctionName), + EqualsNodeInfo("3", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("TIMES", NodeType::kNaryOperator), + EqualsNodeInfo("4", NodeType::kText), + 
EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("5", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("MINUS", NodeType::kUnaryOperator), + EqualsNodeInfo("6", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("DIV", NodeType::kNaryOperator), + EqualsNodeInfo("PLUS", NodeType::kNaryOperator))); + + // Scoring: "1 + pow(2 * sin(3), 4) + -5 / 6" + // Without parentheses in function arguments. + lexer_tokens = { + {"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, + {"pow", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"2", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::TIMES}, + {"sin", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"3", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::COMMA}, + {"4", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::RPAREN}, + {"", Lexer::TokenType::PLUS}, + {"", Lexer::TokenType::MINUS}, + {"5", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DIV}, + {"6", Lexer::TokenType::TEXT}, + }; + parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(tree_root, parser.ConsumeScoring()); + visitor = SimpleVisitor(); + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), ElementsAreArray(node)); +} + +TEST(ParserTest, ScoringMemberFunction) { + // Scoring: this.CreationTimestamp() + std::vector<Lexer::LexerToken> lexer_tokens = { + {"this", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DOT}, + {"CreationTimestamp", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // member + // / \ + // text function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + 
EXPECT_THAT( + visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("CreationTimestamp", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, QueryMemberFunction) { + // Query: this.foo() + std::vector<Lexer::LexerToken> lexer_tokens = { + {"this", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DOT}, + {"foo", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // / \ + // text function + // | + // function_name + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("foo", NodeType::kFunctionName), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, ScoringComplexMemberFunction) { + // Scoring: a.b.fun(c, d) + std::vector<Lexer::LexerToken> lexer_tokens = { + {"a", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DOT}, + {"b", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::DOT}, + {"fun", Lexer::TokenType::FUNCTION_NAME}, + {"", Lexer::TokenType::LPAREN}, + {"c", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::COMMA}, + {"d", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + // Expected AST: + // member + // / | \ + // text text function + // / | \ + // function_name member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("a", NodeType::kText), + EqualsNodeInfo("b", NodeType::kText), + 
EqualsNodeInfo("fun", NodeType::kFunctionName), + EqualsNodeInfo("c", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("d", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, QueryComplexMemberFunction) { + // Query: this.abc.fun(def, ghi) + std::vector<Lexer::LexerToken> lexer_tokens = { + {"this", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT}, + {"abc", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::DOT}, + {"fun", Lexer::TokenType::FUNCTION_NAME}, {"", Lexer::TokenType::LPAREN}, + {"def", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::COMMA}, + {"ghi", Lexer::TokenType::TEXT}, {"", Lexer::TokenType::RPAREN}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + // Expected AST: + // member + // / | \ + // text text function + // / | \ + // function_name member member + // | | + // text text + SimpleVisitor visitor; + tree_root->Accept(&visitor); + EXPECT_THAT(visitor.nodes(), + ElementsAre(EqualsNodeInfo("this", NodeType::kText), + EqualsNodeInfo("abc", NodeType::kText), + EqualsNodeInfo("fun", NodeType::kFunctionName), + EqualsNodeInfo("def", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("ghi", NodeType::kText), + EqualsNodeInfo("", NodeType::kMember), + EqualsNodeInfo("", NodeType::kFunction), + EqualsNodeInfo("", NodeType::kMember))); +} + +TEST(ParserTest, InvalidScoringToken) { + // Scoring: "1 + NOT 1" + std::vector<Lexer::LexerToken> lexer_tokens = {{"1", Lexer::TokenType::TEXT}, + {"", Lexer::TokenType::PLUS}, + {"", Lexer::TokenType::NOT}, + {"1", Lexer::TokenType::TEXT}}; + Parser parser = Parser::Create(std::move(lexer_tokens)); + EXPECT_THAT(parser.ConsumeScoring(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} 
// namespace icing diff --git a/icing/query/advanced_query_parser/query-visitor.cc b/icing/query/advanced_query_parser/query-visitor.cc new file mode 100644 index 0000000..21ce55b --- /dev/null +++ b/icing/query/advanced_query_parser/query-visitor.cc @@ -0,0 +1,228 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/advanced_query_parser/query-visitor.h" + +#include <cstdint> +#include <cstdlib> +#include <limits> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/schema/section-manager.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +bool IsNumericComparator(std::string_view operator_text) { + if (operator_text.length() < 1 || operator_text.length() > 2) { + return false; + } + // TODO(tjbarron) decide how/if to support != + return operator_text == "<" || operator_text == ">" || + operator_text == "==" || operator_text == "<=" || + operator_text == ">="; +} + +bool IsSupportedOperator(std::string_view operator_text) { + return IsNumericComparator(operator_text); +} + +} // namespace + +libtextclassifier3::StatusOr<int64_t> QueryVisitor::RetrieveIntValue() { + if (pending_values_.empty() || !pending_values_.top().holds_text()) { + return absl_ports::InvalidArgumentError("Unable to retrieve int value."); + } + std::string& value = 
pending_values_.top().text; + char* value_end; + int64_t int_value = std::strtoll(value.c_str(), &value_end, /*base=*/10); + if (value_end != value.c_str() + value.length()) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unable to parse \"", value, "\" as number.")); + } + pending_values_.pop(); + return int_value; +} + +libtextclassifier3::StatusOr<std::string> QueryVisitor::RetrieveStringValue() { + if (pending_values_.empty() || !pending_values_.top().holds_text()) { + return absl_ports::InvalidArgumentError("Unable to retrieve string value."); + } + std::string string_value = std::move(pending_values_.top().text); + pending_values_.pop(); + return string_value; +} + +struct Int64Range { + int64_t low; + int64_t high; +}; + +libtextclassifier3::StatusOr<Int64Range> GetInt64Range( + std::string_view operator_text, int64_t int_value) { + Int64Range range = {std::numeric_limits<int64_t>::min(), + std::numeric_limits<int64_t>::max()}; + if (operator_text == "<") { + if (int_value == std::numeric_limits<int64_t>::min()) { + return absl_ports::InvalidArgumentError( + "Cannot specify < INT64_MIN in query expression."); + } + range.high = int_value - 1; + } else if (operator_text == "<=") { + range.high = int_value; + } else if (operator_text == "==") { + range.high = int_value; + range.low = int_value; + } else if (operator_text == ">=") { + range.low = int_value; + } else if (operator_text == ">") { + if (int_value == std::numeric_limits<int64_t>::max()) { + return absl_ports::InvalidArgumentError( + "Cannot specify > INT64_MAX in query expression."); + } + range.low = int_value + 1; + } + return range; +} + +libtextclassifier3::StatusOr<QueryVisitor::PendingValue> +QueryVisitor::ProcessNumericComparator(const NaryOperatorNode* node) { + // 1. The children should have been processed and added their outputs to + // pending_values_. Time to process them. + // The first two pending values should be the int value and the property. 
+ ICING_ASSIGN_OR_RETURN(int64_t int_value, RetrieveIntValue()); + ICING_ASSIGN_OR_RETURN(std::string property, RetrieveStringValue()); + + // 2. Create the iterator. + ICING_ASSIGN_OR_RETURN(Int64Range range, + GetInt64Range(node->operator_text(), int_value)); + auto iterator_or = + numeric_index_.GetIterator(property, range.low, range.high); + if (!iterator_or.ok()) { + return std::move(iterator_or).status(); + } + std::unique_ptr<DocHitInfoIterator> iterator = + std::move(iterator_or).ValueOrDie(); + return PendingValue(std::move(iterator)); +} + +void QueryVisitor::VisitFunctionName(const FunctionNameNode* node) { + pending_error_ = absl_ports::UnimplementedError( + "Function Name node visiting not implemented yet."); +} + +void QueryVisitor::VisitString(const StringNode* node) { + pending_error_ = absl_ports::UnimplementedError( + "String node visiting not implemented yet."); +} + +void QueryVisitor::VisitText(const TextNode* node) { + pending_values_.push(PendingValue(node->value())); +} + +void QueryVisitor::VisitMember(const MemberNode* node) { + // 1. Put in a placeholder PendingValue + pending_values_.push(PendingValue()); + + // 2. Visit the children. + for (const std::unique_ptr<TextNode>& child : node->children()) { + child->Accept(this); + if (has_pending_error()) { + return; + } + } + + // 3. The children should have been processed and added their outputs to + // pending_values_. Time to process them. + std::string member = std::move(pending_values_.top().text); + pending_values_.pop(); + while (!pending_values_.empty() && !pending_values_.top().is_placeholder()) { + member = absl_ports::StrCat(pending_values_.top().text, kPropertySeparator, + member); + pending_values_.pop(); + } + + // 4. If pending_values_ is empty somehow, then our placeholder disappeared + // somehow. 
+ if (pending_values_.empty()) { + pending_error_ = absl_ports::InvalidArgumentError( + "Error processing member node: placeholder is missing."); + return; + } + pending_values_.pop(); + + pending_values_.push(PendingValue(std::move(member))); +} + +void QueryVisitor::VisitFunction(const FunctionNode* node) { + pending_error_ = absl_ports::UnimplementedError( + "Function node visiting not implemented yet."); +} + +void QueryVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) { + pending_error_ = + absl_ports::UnimplementedError("Not node visiting not implemented yet."); +} + +void QueryVisitor::VisitNaryOperator(const NaryOperatorNode* node) { + if (has_pending_error()) { + return; + } + + if (!IsSupportedOperator(node->operator_text())) { + pending_error_ = absl_ports::UnimplementedError( + "No support for any non-numeric operators."); + return; + } + + // 1. Put in a placeholder PendingValue + pending_values_.push(PendingValue()); + + // 2. Visit the children. + for (const std::unique_ptr<Node>& child : node->children()) { + child->Accept(this); + if (has_pending_error()) { + return; + } + } + + // 3. Retrieve the pending value for this node. + PendingValue pending_value; + if (IsNumericComparator(node->operator_text())) { + auto pending_value_or = ProcessNumericComparator(node); + if (!pending_value_or.ok()) { + pending_error_ = std::move(pending_value_or).status(); + return; + } + pending_value = std::move(pending_value_or).ValueOrDie(); + } + + // 4. Check for the placeholder. 
+ if (!pending_values_.top().is_placeholder()) { + pending_error_ = absl_ports::InvalidArgumentError( + "Error processing arguments for node."); + return; + } + pending_values_.pop(); + + pending_values_.push(std::move(pending_value)); +} + +} // namespace lib +} // namespace icing diff --git a/icing/query/advanced_query_parser/query-visitor.h b/icing/query/advanced_query_parser/query-visitor.h new file mode 100644 index 0000000..e834606 --- /dev/null +++ b/icing/query/advanced_query_parser/query-visitor.h @@ -0,0 +1,119 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_ +#define ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_ + +#include <cstdint> +#include <memory> +#include <stack> +#include <string> + +#include "icing/absl_ports/canonical_errors.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" + +namespace icing { +namespace lib { + +// The Visitor used to create the DocHitInfoIterator tree from the AST output by +// the parser. 
+class QueryVisitor : public AbstractSyntaxTreeVisitor { + public: + explicit QueryVisitor(const NumericIndex<int64_t>* numeric_index) + : numeric_index_(*numeric_index) {} + + void VisitFunctionName(const FunctionNameNode* node) override; + void VisitString(const StringNode* node) override; + void VisitText(const TextNode* node) override; + void VisitMember(const MemberNode* node) override; + void VisitFunction(const FunctionNode* node) override; + void VisitUnaryOperator(const UnaryOperatorNode* node) override; + void VisitNaryOperator(const NaryOperatorNode* node) override; + + // RETURNS: + // - the DocHitInfoIterator that is the root of the query iterator tree + // - INVALID_ARGUMENT if the AST does not conform to supported expressions + libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> root() && { + if (has_pending_error()) { + return pending_error_; + } + if (pending_values_.size() != 1 || + !pending_values_.top().holds_iterator()) { + return absl_ports::InvalidArgumentError( + "Visitor does not contain a single root iterator."); + } + return std::move(pending_values_.top().iterator); + } + + private: + // A holder for intermediate results when processing child nodes. + struct PendingValue { + PendingValue() = default; + + explicit PendingValue(std::unique_ptr<DocHitInfoIterator> iterator) + : iterator(std::move(iterator)) {} + + explicit PendingValue(std::string text) : text(std::move(text)) {} + + // Placeholder is used to indicate where the children of a particular node + // begin. 
+ bool is_placeholder() const { return iterator == nullptr && text.empty(); } + + bool holds_text() const { return iterator == nullptr && !text.empty(); } + + bool holds_iterator() const { return iterator != nullptr && text.empty(); } + + std::unique_ptr<DocHitInfoIterator> iterator; + std::string text; + }; + + bool has_pending_error() const { return !pending_error_.ok(); } + + // Processes the PendingValue at the top of pending_values_, parses it into a + // int64_t and pops the top. + // Returns: + // - On success, the int value stored in the text at the top + // - INVALID_ARGUMENT if pending_values_ is empty, doesn't hold a text or + // can't be parsed as an int. + libtextclassifier3::StatusOr<int64_t> RetrieveIntValue(); + + // Processes the PendingValue at the top of pending_values_ and pops the top. + // Returns: + // - On success, the string value stored in the text at the top + // - INVALID_ARGUMENT if pending_values_ is empty or doesn't hold a text. + libtextclassifier3::StatusOr<std::string> RetrieveStringValue(); + + // Processes the NumericComparator represented by node. This must be called + // *after* this node's children have been visited. The PendingValues added by + // this node's children will be consumed by this function and the PendingValue + // for this node will be returned. + // Returns: + // - On success, the PendingValue representing this node and its children. + // - INVALID_ARGUMENT if unable to retrieve string value or int value + // - NOT_FOUND if there is no entry in the numeric index for the property + libtextclassifier3::StatusOr<PendingValue> ProcessNumericComparator( + const NaryOperatorNode* node); + + std::stack<PendingValue> pending_values_; + libtextclassifier3::Status pending_error_; + + const NumericIndex<int64_t>& numeric_index_; // Does not own! 
+}; + +} // namespace lib +} // namespace icing + +#endif // ICING_QUERY_ADVANCED_QUERY_PARSER_QUERY_VISITOR_H_ diff --git a/icing/query/advanced_query_parser/query-visitor_test.cc b/icing/query/advanced_query_parser/query-visitor_test.cc new file mode 100644 index 0000000..1e456fe --- /dev/null +++ b/icing/query/advanced_query_parser/query-visitor_test.cc @@ -0,0 +1,557 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/query/advanced_query_parser/query-visitor.h" + +#include <cstdint> +#include <limits> +#include <memory> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/index/iterator/doc-hit-info-iterator-test-util.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/lexer.h" +#include "icing/query/advanced_query_parser/parser.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; + +constexpr DocumentId kDocumentId0 = 0; +constexpr DocumentId kDocumentId1 = 1; +constexpr DocumentId kDocumentId2 = 2; + +constexpr SectionId kSectionId0 = 0; +constexpr SectionId kSectionId1 = 1; +constexpr SectionId kSectionId2 = 2; + +TEST(QueryVisitorTest, SimpleLessThan) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // 
respectively. + DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "price < 2"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); +} + +TEST(QueryVisitorTest, SimpleLessThanEq) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "price <= 1"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); +} + +TEST(QueryVisitorTest, SimpleEqual) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "price == 2"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); +} + +TEST(QueryVisitorTest, SimpleGreaterThanEq) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "price >= 1"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ElementsAre(kDocumentId2, kDocumentId1)); +} + +TEST(QueryVisitorTest, SimpleGreaterThan) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "price > 1"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId2)); +} + +// TODO(b/208654892) Properly handle negative numbers in query expressions. +TEST(QueryVisitorTest, DISABLED_IntMinLessThanEqual) { + // Setup the numeric index with docs 0, 1 and 2 holding the values INT_MIN, + // INT_MAX and INT_MIN + 1 respectively. 
+ int64_t int_min = std::numeric_limits<int64_t>::min(); + DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(int_min); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(std::numeric_limits<int64_t>::max()); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(int_min + 1); + editor->IndexAllBufferedKeys(); + + std::string query = "price <= " + std::to_string(int_min); + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId0)); +} + +TEST(QueryVisitorTest, IntMaxGreaterThanEqual) { + // Setup the numeric index with docs 0, 1 and 2 holding the values INT_MIN, + // INT_MAX and INT_MAX - 1 respectively. 
+ int64_t int_max = std::numeric_limits<int64_t>::max(); + DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(std::numeric_limits<int64_t>::min()); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(int_max); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(int_max - 1); + editor->IndexAllBufferedKeys(); + + std::string query = "price >= " + std::to_string(int_max); + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), ElementsAre(kDocumentId1)); +} + +TEST(QueryVisitorTest, NestedPropertyLessThan) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + std::string query = "subscription.price < 2"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> root_iterator, + std::move(query_visitor).root()); + + EXPECT_THAT(GetDocumentIds(root_iterator.get()), + ElementsAre(kDocumentId1, kDocumentId0)); +} + +TEST(QueryVisitorTest, IntParsingError) { + DummyNumericIndex<int64_t> numeric_index; + + std::string query = "subscription.price < fruit"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. 
+ QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QueryVisitorTest, NotEqualsUnsupported) { + DummyNumericIndex<int64_t> numeric_index; + + std::string query = "subscription.price != 3"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED)); +} + +TEST(QueryVisitorTest, UnrecognizedOperatorTooLongUnsupported) { + DummyNumericIndex<int64_t> numeric_index; + + // Create an AST for the query 'subscription.price !<= 3' + std::string query = "subscription.price !<= 3"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // There is no support for the 'not less than or equal to' operator. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED)); +} + +TEST(QueryVisitorTest, LessThanTooManyOperandsInvalid) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + // Create an invalid AST for the query '3 < subscription.price 25' where '<' + // has three operands + auto property_node = std::make_unique<TextNode>("subscription"); + auto subproperty_node = std::make_unique<TextNode>("price"); + std::vector<std::unique_ptr<TextNode>> member_args; + member_args.push_back(std::move(property_node)); + member_args.push_back(std::move(subproperty_node)); + auto member_node = std::make_unique<MemberNode>(std::move(member_args), + /*function=*/nullptr); + + auto value_node = std::make_unique<TextNode>("3"); + auto extra_value_node = std::make_unique<TextNode>("25"); + std::vector<std::unique_ptr<Node>> args; + args.push_back(std::move(value_node)); + args.push_back(std::move(member_node)); + args.push_back(std::move(extra_value_node)); + auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args)); + + // Retrieve the root_iterator from the visitor. 
+ QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QueryVisitorTest, LessThanTooFewOperandsInvalid) { + DummyNumericIndex<int64_t> numeric_index; + + // Create an invalid AST for the query 'subscription.price <' where '<' + // has a single operand + auto property_node = std::make_unique<TextNode>("subscription"); + auto subproperty_node = std::make_unique<TextNode>("price"); + std::vector<std::unique_ptr<TextNode>> member_args; + member_args.push_back(std::move(property_node)); + member_args.push_back(std::move(subproperty_node)); + auto member_node = std::make_unique<MemberNode>(std::move(member_args), + /*function=*/nullptr); + + std::vector<std::unique_ptr<Node>> args; + args.push_back(std::move(member_node)); + auto root_node = std::make_unique<NaryOperatorNode>("<", std::move(args)); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QueryVisitorTest, LessThanNonExistentPropertyNotFound) { + // Setup the numeric index with docs 0, 1 and 2 holding the values 0, 1 and 2 + // respectively. 
+ DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("subscription.price", kDocumentId0, kSectionId0); + editor->BufferKey(0); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId1, kSectionId1); + editor->BufferKey(1); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("subscription.price", kDocumentId2, kSectionId2); + editor->BufferKey(2); + editor->IndexAllBufferedKeys(); + + // Create an invalid AST for the query 'time < 25' where '<' + // has three operands + std::string query = "time < 25"; + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + +TEST(QueryVisitorTest, NeverVisitedReturnsInvalid) { + DummyNumericIndex<int64_t> numeric_index; + QueryVisitor query_visitor(&numeric_index); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +// TODO(b/208654892) Properly handle negative numbers in query expressions. +TEST(QueryVisitorTest, DISABLED_IntMinLessThanInvalid) { + // Setup the numeric index with docs 0, 1 and 2 holding the values INT_MIN, + // INT_MAX and INT_MIN + 1 respectively. 
+ int64_t int_min = std::numeric_limits<int64_t>::min(); + DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(int_min); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(std::numeric_limits<int64_t>::max()); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(int_min + 1); + editor->IndexAllBufferedKeys(); + + std::string query = "price <" + std::to_string(int_min); + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST(QueryVisitorTest, IntMaxGreaterThanInvalid) { + // Setup the numeric index with docs 0, 1 and 2 holding the values INT_MIN, + // INT_MAX and INT_MAX - 1 respectively. 
+ int64_t int_max = std::numeric_limits<int64_t>::max(); + DummyNumericIndex<int64_t> numeric_index; + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index.Edit("price", kDocumentId0, kSectionId0); + editor->BufferKey(std::numeric_limits<int64_t>::min()); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId1, kSectionId1); + editor->BufferKey(int_max); + editor->IndexAllBufferedKeys(); + + editor = numeric_index.Edit("price", kDocumentId2, kSectionId2); + editor->BufferKey(int_max - 1); + editor->IndexAllBufferedKeys(); + + std::string query = "price >" + std::to_string(int_max); + Lexer lexer(query, Lexer::Language::QUERY); + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Node> root_node, + parser.ConsumeQuery()); + + // Retrieve the root_iterator from the visitor. + QueryVisitor query_visitor(&numeric_index); + root_node->Accept(&query_visitor); + EXPECT_THAT(std::move(query_visitor).root(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 90587aa..abef9e4 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -35,12 +35,12 @@ #include "icing/index/iterator/doc-hit-info-iterator-section-restrict.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/search.pb.h" - -#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY -#include "icing/query/advanced-query-processor.h" -#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY - +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/query/advanced_query_parser/lexer.h" +#include "icing/query/advanced_query_parser/parser.h" +#include 
"icing/query/advanced_query_parser/query-visitor.h" #include "icing/query/query-processor.h" +#include "icing/query/query-results.h" #include "icing/query/query-terms.h" #include "icing/query/query-utils.h" #include "icing/schema/schema-store.h" @@ -107,27 +107,31 @@ std::unique_ptr<DocHitInfoIterator> ProcessParserStateFrame( } // namespace libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> -QueryProcessor::Create(Index* index, +QueryProcessor::Create(Index* index, const NumericIndex<int64_t>* numeric_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(index); + ICING_RETURN_ERROR_IF_NULL(numeric_index); ICING_RETURN_ERROR_IF_NULL(language_segmenter); ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); - return std::unique_ptr<QueryProcessor>(new QueryProcessor( - index, language_segmenter, normalizer, document_store, schema_store)); + return std::unique_ptr<QueryProcessor>( + new QueryProcessor(index, numeric_index, language_segmenter, normalizer, + document_store, schema_store)); } QueryProcessor::QueryProcessor(Index* index, + const NumericIndex<int64_t>* numeric_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store) : index_(*index), + numeric_index_(*numeric_index), language_segmenter_(*language_segmenter), normalizer_(*normalizer), document_store_(*document_store), @@ -145,22 +149,19 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( QueryResults results; if (search_spec.search_type() == SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { -#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY ICING_VLOG(1) << "Using EXPERIMENTAL_ICING_ADVANCED_QUERY parser!"; - ICING_ASSIGN_OR_RETURN( - 
std::unique_ptr<AdvancedQueryProcessor> advanced_query_processor, - AdvancedQueryProcessor::Create(&index_, &language_segmenter_, - &normalizer_, &document_store_, - &schema_store_)); - ICING_ASSIGN_OR_RETURN(results, advanced_query_processor->ParseSearch( - search_spec, ranking_strategy)); -#else // !ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY - ICING_LOG(ERROR) << "Requested EXPERIMENTAL_ICING_ADVANCED_QUERY search " - "type, but advanced query is not compiled in. Falling " - "back to ICING_RAW_QUERY."; - ICING_ASSIGN_OR_RETURN(results, - ParseRawQuery(search_spec, ranking_strategy)); -#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY + libtextclassifier3::StatusOr<QueryResults> results_or = + ParseAdvancedQuery(search_spec); + if (results_or.ok()) { + results = std::move(results_or).ValueOrDie(); + } else { + ICING_VLOG(1) + << "Unable to parse query using advanced query parser. Error: " + << results_or.status().error_message() + << ". Falling back to old query parser."; + ICING_ASSIGN_OR_RETURN(results, + ParseRawQuery(search_spec, ranking_strategy)); + } } else { ICING_ASSIGN_OR_RETURN(results, ParseRawQuery(search_spec, ranking_strategy)); @@ -173,6 +174,30 @@ libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseSearch( return results; } +libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseAdvancedQuery( + const SearchSpecProto& search_spec) const { + QueryResults results; + Lexer lexer(search_spec.query(), Lexer::Language::QUERY); + ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root, + parser.ConsumeQuery()); + + if (tree_root == nullptr) { + results.root_iterator = std::make_unique<DocHitInfoIteratorAllDocumentId>( + document_store_.last_added_document_id()); + return results; + } + + QueryVisitor query_visitor(&numeric_index_); + tree_root->Accept(&query_visitor); + 
ICING_ASSIGN_OR_RETURN(results.root_iterator, + std::move(query_visitor).root()); + return results; +} + // TODO(cassiewang): Collect query stats to populate the SearchResultsProto libtextclassifier3::StatusOr<QueryResults> QueryProcessor::ParseRawQuery( const SearchSpecProto& search_spec, diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h index f544a7a..a4f8973 100644 --- a/icing/query/query-processor.h +++ b/icing/query/query-processor.h @@ -15,12 +15,14 @@ #ifndef ICING_QUERY_QUERY_PROCESSOR_H_ #define ICING_QUERY_QUERY_PROCESSOR_H_ +#include <cstdint> #include <memory> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-filter.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/proto/search.pb.h" #include "icing/query/query-results.h" #include "icing/query/query-terms.h" @@ -45,9 +47,9 @@ class QueryProcessor { // An QueryProcessor on success // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<QueryProcessor>> Create( - Index* index, const LanguageSegmenter* language_segmenter, - const Normalizer* normalizer, const DocumentStore* document_store, - const SchemaStore* schema_store); + Index* index, const NumericIndex<int64_t>* numeric_index, + const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, + const DocumentStore* document_store, const SchemaStore* schema_store); // Parse the search configurations (including the query, any additional // filters, etc.) in the SearchSpecProto into one DocHitInfoIterator. 
@@ -69,12 +71,23 @@ class QueryProcessor { private: explicit QueryProcessor(Index* index, + const NumericIndex<int64_t>* numeric_index, const LanguageSegmenter* language_segmenter, const Normalizer* normalizer, const DocumentStore* document_store, const SchemaStore* schema_store); // Parse the query into a one DocHitInfoIterator that represents the root of a + // query tree in our new Advanced Query Language. + // + // Returns: + // On success, + // - One iterator that represents the entire query + // INVALID_ARGUMENT if query syntax is incorrect and cannot be tokenized + libtextclassifier3::StatusOr<QueryResults> ParseAdvancedQuery( + const SearchSpecProto& search_spec) const; + + // Parse the query into a one DocHitInfoIterator that represents the root of a // query tree. // // Returns: @@ -90,6 +103,7 @@ class QueryProcessor { // Not const because we could modify/sort the hit buffer in the lite index at // query time. Index& index_; + const NumericIndex<int64_t>& numeric_index_; const LanguageSegmenter& language_segmenter_; const Normalizer& normalizer_; const DocumentStore& document_store_; diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc index 2015d81..6d776ce 100644 --- a/icing/query/query-processor_benchmark.cc +++ b/icing/query/query-processor_benchmark.cc @@ -17,6 +17,8 @@ #include "third_party/absl/flags/flag.h" #include "icing/document-builder.h" #include "icing/index/index.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/proto/schema.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/term.pb.h" @@ -113,6 +115,9 @@ void BM_QueryOneTerm(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + // TODO(b/249829533): switch to use persistent numeric index. 
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>(); + language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -147,9 +152,9 @@ void BM_QueryOneTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index.get(), language_segmenter.get(), - normalizer.get(), document_store.get(), - schema_store.get())); + QueryProcessor::Create(index.get(), numeric_index.get(), + language_segmenter.get(), normalizer.get(), + document_store.get(), schema_store.get())); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -233,6 +238,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + // TODO(b/249829533): switch to use persistent numeric index. + auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>(); + language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -281,9 +289,9 @@ void BM_QueryFiveTerms(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index.get(), language_segmenter.get(), - normalizer.get(), document_store.get(), - schema_store.get())); + QueryProcessor::Create(index.get(), numeric_index.get(), + language_segmenter.get(), normalizer.get(), + document_store.get(), schema_store.get())); const std::string query_string = absl_ports::StrCat( input_string_a, " ", input_string_b, " ", input_string_c, " ", @@ -371,6 +379,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + // TODO(b/249829533): switch to use persistent numeric index. 
+ auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>(); + language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -408,9 +419,9 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index.get(), language_segmenter.get(), - normalizer.get(), document_store.get(), - schema_store.get())); + QueryProcessor::Create(index.get(), numeric_index.get(), + language_segmenter.get(), normalizer.get(), + document_store.get(), schema_store.get())); SearchSpecProto search_spec; search_spec.set_query(input_string); @@ -494,6 +505,9 @@ void BM_QueryHiragana(benchmark::State& state) { std::unique_ptr<Index> index = CreateIndex(icing_filesystem, filesystem, index_dir); + // TODO(b/249829533): switch to use persistent numeric index. + auto numeric_index = std::make_unique<DummyNumericIndex<int64_t>>(); + language_segmenter_factory::SegmenterOptions options(ULOC_US); std::unique_ptr<LanguageSegmenter> language_segmenter = language_segmenter_factory::Create(std::move(options)).ValueOrDie(); @@ -531,9 +545,9 @@ void BM_QueryHiragana(benchmark::State& state) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index.get(), language_segmenter.get(), - normalizer.get(), document_store.get(), - schema_store.get())); + QueryProcessor::Create(index.get(), numeric_index.get(), + language_segmenter.get(), normalizer.get(), + document_store.get(), schema_store.get())); SearchSpecProto search_spec; search_spec.set_query(input_string); diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index da35df8..9f1386c 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -14,6 +14,7 @@ #include "icing/query/query-processor.h" 
+#include <cstdint> #include <memory> #include <string> @@ -26,13 +27,14 @@ #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/index/numeric/dummy-numeric-index.h" +#include "icing/index/numeric/numeric-index.h" #include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" #include "icing/proto/schema.pb.h" #include "icing/proto/search.pb.h" #include "icing/proto/term.pb.h" -#include "icing/query/query-processor.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" @@ -59,7 +61,6 @@ using ::testing::ElementsAre; using ::testing::ElementsAreArray; using ::testing::IsEmpty; using ::testing::SizeIs; -using ::testing::Test; using ::testing::UnorderedElementsAre; class QueryProcessorTest @@ -88,11 +89,22 @@ class QueryProcessorTest icu_data_file_helper::SetUpICUDataFile( GetTestFilePath("icing/icu.dat"))); } + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); Index::Options options(index_dir_, /*index_merge_size=*/1024 * 1024); ICING_ASSERT_OK_AND_ASSIGN( index_, Index::Create(options, &filesystem_, &icing_filesystem_)); + // TODO(b/249829533): switch to use persistent numeric index. 
+ numeric_index_ = std::make_unique<DummyNumericIndex<int64_t>>(); language_segmenter_factory::SegmenterOptions segmenter_options( ULOC_US, jni_cache_.get()); @@ -102,6 +114,12 @@ class QueryProcessorTest ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( /*max_term_byte_size=*/1000)); + + ICING_ASSERT_OK_AND_ASSIGN( + query_processor_, + QueryProcessor::Create(index_.get(), numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), + document_store_.get(), schema_store_.get())); } libtextclassifier3::Status AddTokenToIndex( @@ -113,6 +131,16 @@ class QueryProcessorTest return status.ok() ? editor.IndexAllBufferedTerms() : status; } + libtextclassifier3::Status AddToNumericIndex(DocumentId document_id, + const std::string& property, + SectionId section_id, + int64_t value) { + std::unique_ptr<NumericIndex<int64_t>::Editor> editor = + numeric_index_->Edit(property, document_id, section_id); + ICING_RETURN_IF_ERROR(editor->BufferKey(value)); + return editor->IndexAllBufferedKeys(); + } + void TearDown() override { document_store_.reset(); schema_store_.reset(); @@ -129,35 +157,44 @@ class QueryProcessorTest protected: std::unique_ptr<Index> index_; + std::unique_ptr<NumericIndex<int64_t>> numeric_index_; std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<Normalizer> normalizer_; FakeClock fake_clock_; std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache(); std::unique_ptr<SchemaStore> schema_store_; std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<QueryProcessor> query_processor_; }; TEST_P(QueryProcessorTest, CreationWithNullPointerShouldFail) { EXPECT_THAT( - QueryProcessor::Create(/*index=*/nullptr, language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get()), + QueryProcessor::Create(/*index=*/nullptr, numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), + document_store_.get(), schema_store_.get()), 
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( - QueryProcessor::Create(index_.get(), /*language_segmenter=*/nullptr, - normalizer_.get(), document_store_.get(), - schema_store_.get()), + QueryProcessor::Create(index_.get(), /*numeric_index_=*/nullptr, + language_segmenter_.get(), normalizer_.get(), + document_store_.get(), schema_store_.get()), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT( - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - /*normalizer=*/nullptr, document_store_.get(), - schema_store_.get()), + QueryProcessor::Create(index_.get(), numeric_index_.get(), + /*language_segmenter=*/nullptr, normalizer_.get(), + document_store_.get(), schema_store_.get()), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create( - index_.get(), language_segmenter_.get(), normalizer_.get(), - /*document_store=*/nullptr, schema_store_.get()), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); - EXPECT_THAT(QueryProcessor::Create(index_.get(), language_segmenter_.get(), + EXPECT_THAT( + QueryProcessor::Create( + index_.get(), numeric_index_.get(), language_segmenter_.get(), + /*normalizer=*/nullptr, document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + QueryProcessor::Create(index_.get(), numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), + /*document_store=*/nullptr, schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(QueryProcessor::Create(index_.get(), numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), document_store_.get(), /*schema_store=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); @@ -168,18 +205,8 @@ TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) { SchemaProto schema = SchemaBuilder() 
.AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, document_store_->Put(DocumentBuilder() .SetKey("namespace", "1") @@ -193,22 +220,14 @@ TEST_P(QueryProcessorTest, EmptyGroupMatchAllDocuments) { // We don't need to insert anything in the index since the empty query will // match all DocumentIds from the DocumentStore - - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("()"); search_spec.set_search_type(GetParam()); ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), @@ -222,18 +241,8 @@ TEST_P(QueryProcessorTest, EmptyQueryMatchAllDocuments) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = 
std::move(create_result.document_store); - ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, document_store_->Put(DocumentBuilder() .SetKey("namespace", "1") @@ -247,22 +256,14 @@ TEST_P(QueryProcessorTest, EmptyQueryMatchAllDocuments) { // We don't need to insert anything in the index since the empty query will // match all DocumentIds from the DocumentStore - - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query(""); search_spec.set_search_type(GetParam()); ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), @@ -276,18 +277,8 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -310,13 +301,6 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) { AddTokenToIndex(document_id, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hElLo WORLD"); search_spec.set_term_match_type(term_match_type); @@ -324,7 +308,7 @@ TEST_P(QueryProcessorTest, QueryTermNormalized) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -354,18 +338,8 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -385,13 +359,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "hello"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("he"); search_spec.set_term_match_type(term_match_type); @@ -399,7 +366,7 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -424,18 +391,8 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -456,13 +413,6 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) { AddTokenToIndex(document_id, section_id, term_match_type, "hello"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("he"); search_spec.set_term_match_type(term_match_type); @@ -470,7 +420,7 @@ TEST_P(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -495,18 +445,8 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -526,13 +466,6 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "hello"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello"); search_spec.set_term_match_type(term_match_type); @@ -540,7 +473,7 @@ TEST_P(QueryProcessorTest, OneTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -565,18 +498,8 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -596,13 +519,6 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "hello"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello hello"); search_spec.set_term_match_type(term_match_type); @@ -610,7 +526,7 @@ TEST_P(QueryProcessorTest, AndSameTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -637,18 +553,8 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -671,13 +577,6 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello world"); search_spec.set_term_match_type(term_match_type); @@ -685,7 +584,7 @@ TEST_P(QueryProcessorTest, AndTwoTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -714,18 +613,8 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -745,13 +634,6 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "hello"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("he he"); search_spec.set_term_match_type(term_match_type); @@ -759,7 +641,7 @@ TEST_P(QueryProcessorTest, AndSameTermPrefixMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; @@ -786,18 +668,8 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -820,13 +692,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("he wo"); search_spec.set_term_match_type(term_match_type); @@ -834,7 +699,7 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -864,18 +729,8 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -898,13 +753,6 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { AddTokenToIndex(document_id, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello wo"); search_spec.set_term_match_type(term_match_type); @@ -912,7 +760,7 @@ TEST_P(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -942,18 +790,8 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -981,13 +819,6 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) { AddTokenToIndex(document_id2, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello OR world"); search_spec.set_term_match_type(term_match_type); @@ -995,7 +826,7 @@ TEST_P(QueryProcessorTest, OrTwoTermExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1033,18 +864,8 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1072,13 +893,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { AddTokenToIndex(document_id2, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("he OR wo"); search_spec.set_term_match_type(term_match_type); @@ -1086,7 +900,7 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1123,18 +937,8 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1161,13 +965,6 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { AddTokenToIndex(document_id2, section_id, TermMatchType::PREFIX, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("hello OR wo"); search_spec.set_term_match_type(TermMatchType::PREFIX); @@ -1175,7 +972,7 @@ TEST_P(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1212,18 +1009,8 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1264,13 +1051,6 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - { // OR gets precedence over AND, this is parsed as ((puppy OR kitten) AND // dog) @@ -1281,7 +1061,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Only Document 1 matches since it has puppy AND dog @@ -1318,7 +1098,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Both Document 1 and 2 match since Document 1 has animal AND puppy, and @@ -1374,7 +1154,7 @@ TEST_P(QueryProcessorTest, CombinedAndOrTerms) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Only Document 2 matches since it has both kitten and cat @@ -1407,18 +1187,8 @@ TEST_P(QueryProcessorTest, OneGroup) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These 
documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. @@ -1451,13 +1221,6 @@ TEST_P(QueryProcessorTest, OneGroup) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - // Without grouping, this would be parsed as ((puppy OR kitten) AND foo) and // no documents would match. But with grouping, Document 1 matches puppy SearchSpecProto search_spec; @@ -1467,7 +1230,7 @@ TEST_P(QueryProcessorTest, OneGroup) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1486,18 +1249,8 @@ TEST_P(QueryProcessorTest, TwoGroups) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1530,12 +1283,6 @@ TEST_P(QueryProcessorTest, TwoGroups) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - // Without grouping, this would be parsed as (puppy AND (dog OR kitten) AND // cat) and wouldn't match any documents. But with grouping, Document 1 // matches (puppy AND dog) and Document 2 matches (kitten and cat). @@ -1546,7 +1293,7 @@ TEST_P(QueryProcessorTest, TwoGroups) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1567,18 +1314,8 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1611,13 +1348,6 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - // Without grouping, this would be parsed as ((puppy OR kitten) AND foo) and // no documents would match. But with grouping, Document 1 matches puppy SearchSpecProto search_spec; @@ -1627,7 +1357,7 @@ TEST_P(QueryProcessorTest, ManyLevelNestedGrouping) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1646,18 +1376,8 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -1690,13 +1410,6 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - // Document 1 will match puppy and Document 2 matches (kitten AND (cat)) SearchSpecProto search_spec; // TODO(b/208654892) decide how we want to handle queries of the form foo(...) @@ -1706,7 +1419,7 @@ TEST_P(QueryProcessorTest, OneLevelNestedGrouping) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -1727,18 +1440,8 @@ TEST_P(QueryProcessorTest, ExcludeTerm) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. 
We're // just inserting the documents so that they'll bump the // last_added_document_id, which will give us the proper exclusion results @@ -1764,13 +1467,6 @@ TEST_P(QueryProcessorTest, ExcludeTerm) { AddTokenToIndex(document_id2, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("-hello"); search_spec.set_term_match_type(term_match_type); @@ -1778,8 +1474,8 @@ TEST_P(QueryProcessorTest, ExcludeTerm) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // We don't know have the section mask to indicate what section "world" // came. It doesn't matter which section it was in since the query doesn't @@ -1795,18 +1491,8 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. 
We're // just inserting the documents so that they'll bump the // last_added_document_id, which will give us the proper exclusion results @@ -1831,13 +1517,6 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) { AddTokenToIndex(document_id2, section_id, term_match_type, "world"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("-foo"); search_spec.set_term_match_type(term_match_type); @@ -1845,8 +1524,8 @@ TEST_P(QueryProcessorTest, ExcludeNonexistentTerm) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), @@ -1861,18 +1540,8 @@ TEST_P(QueryProcessorTest, ExcludeAnd) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. 
We're // just inserting the documents so that they'll bump the // last_added_document_id, which will give us the proper exclusion results @@ -1905,13 +1574,6 @@ TEST_P(QueryProcessorTest, ExcludeAnd) { ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - { SearchSpecProto search_spec; search_spec.set_query("-dog -cat"); @@ -1920,7 +1582,7 @@ TEST_P(QueryProcessorTest, ExcludeAnd) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // The query is interpreted as "exclude all documents that have animal, @@ -1939,7 +1601,7 @@ TEST_P(QueryProcessorTest, ExcludeAnd) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // The query is interpreted as "exclude all documents that have animal, @@ -1957,18 +1619,8 @@ TEST_P(QueryProcessorTest, ExcludeOr) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. 
We're // just inserting the documents so that they'll bump the // last_added_document_id, which will give us the proper exclusion results @@ -2001,13 +1653,6 @@ TEST_P(QueryProcessorTest, ExcludeOr) { ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - { SearchSpecProto search_spec; search_spec.set_query("-animal OR -cat"); @@ -2016,7 +1661,7 @@ TEST_P(QueryProcessorTest, ExcludeOr) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // We don't have a section mask indicating which sections in this document @@ -2036,7 +1681,7 @@ TEST_P(QueryProcessorTest, ExcludeOr) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -2056,18 +1701,8 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // just inserting the documents so that the DocHitInfoIterators will see // that the document exists and not filter out the DocumentId as deleted. 
@@ -2110,12 +1745,6 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - // OR gets precedence over AND, this is parsed as (animal AND (puppy OR // kitten)) SearchSpecProto search_spec; @@ -2125,8 +1754,8 @@ TEST_P(QueryProcessorTest, WithoutTermFrequency) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Since need_hit_term_frequency is false, the expected term frequencies // should all be 0. Hit::TermFrequencyArray exp_term_frequencies{0}; @@ -2175,18 +1804,8 @@ TEST_P(QueryProcessorTest, DeletedFilter) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -2220,13 +1839,6 @@ TEST_P(QueryProcessorTest, DeletedFilter) { ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("animal"); search_spec.set_term_match_type(term_match_type); @@ -2234,7 +1846,7 @@ TEST_P(QueryProcessorTest, DeletedFilter) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -2252,18 +1864,8 @@ TEST_P(QueryProcessorTest, NamespaceFilter) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // namespaces populated. 
@@ -2296,13 +1898,6 @@ TEST_P(QueryProcessorTest, NamespaceFilter) { ASSERT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("animal"); search_spec.set_term_match_type(term_match_type); @@ -2311,7 +1906,7 @@ TEST_P(QueryProcessorTest, NamespaceFilter) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -2331,18 +1926,8 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) { .AddType(SchemaTypeConfigBuilder().SetType("email")) .AddType(SchemaTypeConfigBuilder().SetType("message")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2371,13 +1956,6 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) { AddTokenToIndex(document_id2, section_id, term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; search_spec.set_query("animal"); search_spec.set_term_match_type(term_match_type); @@ -2386,7 +1964,7 @@ TEST_P(QueryProcessorTest, SchemaTypeFilter) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -2411,18 +1989,8 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) { .Build(); // First and only indexed property, so it gets a section_id of 0 int subject_section_id = 0; - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2440,13 +2008,6 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) { "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>' search_spec.set_query("subject:animal"); @@ -2455,7 +2016,7 @@ TEST_P(QueryProcessorTest, PropertyFilterForOneDocument) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds @@ -2496,18 +2057,8 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) { // alphabetically. int email_foo_section_id = 1; int message_foo_section_id = 0; - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2535,13 +2086,6 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>' search_spec.set_query("foo:animal"); @@ -2550,7 +2094,7 @@ TEST_P(QueryProcessorTest, PropertyFilterAcrossSchemaTypes) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Ordered by descending DocumentId, so message comes first since it was @@ -2582,18 +2126,8 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) { .Build(); int email_foo_section_id = 0; int message_foo_section_id = 0; - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2621,13 +2155,6 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>', but only look // within documents of email schema @@ -2638,7 +2165,7 @@ TEST_P(QueryProcessorTest, PropertyFilterWithinSchemaType) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Shouldn't include the message document since we're only looking at email @@ -2686,18 +2213,8 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) { TOKENIZER_PLAIN) .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2715,13 +2232,6 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>', but only look // within documents of email schema @@ -2731,7 +2241,7 @@ TEST_P(QueryProcessorTest, NestedPropertyFilter) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it @@ -2763,18 +2273,8 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) { .Build(); int email_foo_section_id = 0; int message_foo_section_id = 0; - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2804,13 +2304,6 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>', but only look // within documents of email schema @@ -2820,7 +2313,7 @@ TEST_P(QueryProcessorTest, PropertyFilterRespectsDifferentSectionIds) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it @@ -2839,18 +2332,8 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2868,13 +2351,6 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>', but only look // within documents of email schema @@ -2884,7 +2360,7 @@ TEST_P(QueryProcessorTest, NonexistentPropertyFilterReturnsEmptyResults) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it @@ -2909,18 +2385,8 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) { .SetDataType(TYPE_STRING) .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -2938,13 +2404,6 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) { term_match_type, "animal"), IsOk()); - // Perform query - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>', but only look // within documents of email schema @@ -2954,7 +2413,7 @@ TEST_P(QueryProcessorTest, UnindexedPropertyFilterReturnsEmptyResults) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it @@ -2982,18 +2441,8 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) { .Build(); int email_foo_section_id = 0; int message_foo_section_id = 0; - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - DocumentStore::CreateResult create_result, - DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, - schema_store_.get())); - document_store_ = std::move(create_result.document_store); - // These documents don't actually match to the tokens in the index. We're // inserting the documents to get the appropriate number of documents and // schema types populated. 
@@ -3024,12 +2473,6 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) { term_match_type, "animal"), IsOk()); - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); - SearchSpecProto search_spec; // Create a section filter '<section name>:<query term>' search_spec.set_query("cat OR foo:animal"); @@ -3038,7 +2481,7 @@ TEST_P(QueryProcessorTest, PropertyFilterTermAndUnrestrictedTerm) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch( + query_processor_->ParseSearch( search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Ordered by descending DocumentId, so message comes first since it was @@ -3060,10 +2503,6 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); // Arbitrary value, just has to be less than the document's creation @@ -3096,10 +2535,10 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { // Perform query ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); + std::unique_ptr<QueryProcessor> local_query_processor, + QueryProcessor::Create(index_.get(), numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), + document_store_.get(), schema_store_.get())); SearchSpecProto search_spec; search_spec.set_query("hello"); @@ -3108,8 +2547,8 @@ TEST_P(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - 
query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + local_query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); DocHitInfo expectedDocHitInfo(document_id); expectedDocHitInfo.UpdateSection(/*section_id=*/0); @@ -3122,10 +2561,6 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) { SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email")) .Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, - SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); // Arbitrary value, just has to be greater than the document's creation @@ -3158,10 +2593,10 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) { // Perform query ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<QueryProcessor> query_processor, - QueryProcessor::Create(index_.get(), language_segmenter_.get(), - normalizer_.get(), document_store_.get(), - schema_store_.get())); + std::unique_ptr<QueryProcessor> local_query_processor, + QueryProcessor::Create(index_.get(), numeric_index_.get(), + language_segmenter_.get(), normalizer_.get(), + document_store_.get(), schema_store_.get())); SearchSpecProto search_spec; search_spec.set_query("hello"); @@ -3170,21 +2605,117 @@ TEST_P(QueryProcessorTest, DocumentPastTtlFilteredOut) { ICING_ASSERT_OK_AND_ASSIGN( QueryResults results, - query_processor->ParseSearch(search_spec, - ScoringSpecProto::RankingStrategy::NONE)); + local_query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); } +TEST_P(QueryProcessorTest, NumericFilter) { + if (GetParam() != + SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY) { + GTEST_SKIP() << "Numeric filter is only supported in advanced query."; + } + + // Create the schema and document store + SchemaProto schema = + SchemaBuilder() + 
.AddType(SchemaTypeConfigBuilder() + .SetType("transaction") + .AddProperty(PropertyConfigBuilder() + .SetName("price") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("cost") + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + // SectionIds are assigned alphabetically + SectionId cost_section_id = 0; + SectionId price_section_id = 1; + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_one_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("transaction") + .AddInt64Property("price", 10) + .Build())); + ICING_ASSERT_OK( + AddToNumericIndex(document_one_id, "price", price_section_id, 10)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_two_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "2") + .SetSchema("transaction") + .AddInt64Property("price", 25) + .Build())); + ICING_ASSERT_OK( + AddToNumericIndex(document_two_id, "price", price_section_id, 25)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_three_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "3") + .SetSchema("transaction") + .AddInt64Property("cost", 2) + .Build())); + ICING_ASSERT_OK( + AddToNumericIndex(document_three_id, "cost", cost_section_id, 2)); + + SearchSpecProto search_spec; + search_spec.set_query("price < 20"); + search_spec.set_search_type(GetParam()); + ICING_ASSERT_OK_AND_ASSIGN( + QueryResults results, + query_processor_->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(EqualsDocHitInfo( + document_one_id, std::vector<SectionId>{price_section_id}))); + + search_spec.set_query("price == 25"); + ICING_ASSERT_OK_AND_ASSIGN( + results, query_processor_->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); + 
EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(EqualsDocHitInfo( + document_two_id, std::vector<SectionId>{price_section_id}))); + + search_spec.set_query("cost > 2"); + ICING_ASSERT_OK_AND_ASSIGN( + results, query_processor_->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); + + search_spec.set_query("cost >= 2"); + ICING_ASSERT_OK_AND_ASSIGN( + results, query_processor_->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); + EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), + ElementsAre(EqualsDocHitInfo( + document_three_id, std::vector<SectionId>{cost_section_id}))); + + search_spec.set_query("price <= 25"); + ICING_ASSERT_OK_AND_ASSIGN( + results, query_processor_->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE)); + EXPECT_THAT( + GetDocHitInfos(results.root_iterator.get()), + ElementsAre(EqualsDocHitInfo(document_two_id, + std::vector<SectionId>{price_section_id}), + EqualsDocHitInfo(document_one_id, + std::vector<SectionId>{price_section_id}))); +} + INSTANTIATE_TEST_SUITE_P( QueryProcessorTest, QueryProcessorTest, -#ifdef ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY testing::Values( SearchSpecProto::SearchType::ICING_RAW_QUERY, SearchSpecProto::SearchType::EXPERIMENTAL_ICING_ADVANCED_QUERY)); -#else // !ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY - testing::Values(SearchSpecProto::SearchType::ICING_RAW_QUERY)); -#endif // ENABLE_EXPERIMENTAL_ICING_ADVANCED_QUERY } // namespace diff --git a/icing/result/result-retriever-v2.cc b/icing/result/result-retriever-v2.cc index c10c4e8..a51a8e6 100644 --- a/icing/result/result-retriever-v2.cc +++ b/icing/result/result-retriever-v2.cc @@ -113,16 +113,17 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage( int32_t num_total_bytes = 0; while (results.size() < result_state.num_per_page() && !result_state.scored_document_hits_ranker->empty()) { - 
ScoredDocumentHit next_best_document_hit = + JoinedScoredDocumentHit next_best_document_hit = result_state.scored_document_hits_ranker->PopNext(); if (group_result_limiter_->ShouldBeRemoved( - next_best_document_hit, result_state.namespace_group_id_map(), - doc_store_, result_state.group_result_limits)) { + next_best_document_hit.parent_scored_document_hit(), + result_state.namespace_group_id_map(), doc_store_, + result_state.group_result_limits)) { continue; } - libtextclassifier3::StatusOr<DocumentProto> document_or = - doc_store_.Get(next_best_document_hit.document_id()); + libtextclassifier3::StatusOr<DocumentProto> document_or = doc_store_.Get( + next_best_document_hit.parent_scored_document_hit().document_id()); if (!document_or.ok()) { // Skip the document if getting errors. ICING_LOG(WARNING) << "Fail to fetch document from document store: " @@ -147,14 +148,38 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage( SnippetProto snippet_proto = snippet_retriever_->RetrieveSnippet( snippet_context.query_terms, snippet_context.match_type, snippet_context.snippet_spec, document, - next_best_document_hit.hit_section_id_mask()); + next_best_document_hit.parent_scored_document_hit() + .hit_section_id_mask()); *result.mutable_snippet() = std::move(snippet_proto); ++num_results_with_snippets; } // Add the document, itself. *result.mutable_document() = std::move(document); - result.set_score(next_best_document_hit.score()); + result.set_score(next_best_document_hit.final_score()); + + // Retrieve child documents + for (const ScoredDocumentHit& child_scored_document_hit : + next_best_document_hit.child_scored_document_hits()) { + libtextclassifier3::StatusOr<DocumentProto> child_document_or = + doc_store_.Get(child_scored_document_hit.document_id()); + if (!child_document_or.ok()) { + // Skip the document if getting errors. 
+ ICING_LOG(WARNING) + << "Fail to fetch child document from document store: " + << child_document_or.status().error_message(); + continue; + } + + DocumentProto child_document = std::move(child_document_or).ValueOrDie(); + // TODO(b/256022027): apply projection and add snippet for child doc + + SearchResultProto::ResultProto* child_result = + result.add_joined_results(); + *child_result->mutable_document() = std::move(child_document); + child_result->set_score(child_scored_document_hit.score()); + } + size_t result_bytes = result.ByteSizeLong(); results.push_back(std::move(result)); diff --git a/icing/result/result-retriever-v2_group-result-limiter_test.cc b/icing/result/result-retriever-v2_group-result-limiter_test.cc index e0a6c79..02c9bc6 100644 --- a/icing/result/result-retriever-v2_group-result-limiter_test.cc +++ b/icing/result/result-retriever-v2_group-result-limiter_test.cc @@ -162,7 +162,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 2 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -219,7 +220,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 2 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -292,7 +294,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 4 ScoredDocumentHits. 
ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -376,7 +379,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 4 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -433,7 +437,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 2 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -534,7 +539,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 6 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -592,7 +598,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 2 ScoredDocumentHits. 
ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, @@ -687,7 +694,8 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, // Creates a ResultState with 5 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), result_spec, diff --git a/icing/result/result-retriever-v2_projection_test.cc b/icing/result/result-retriever-v2_projection_test.cc index ec67caa..d093d1f 100644 --- a/icing/result/result-retriever-v2_projection_test.cc +++ b/icing/result/result-retriever-v2_projection_test.cc @@ -222,7 +222,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionTopLevelLeadNodeFieldPath) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -317,7 +318,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionNestedLeafNodeFieldPath) { // 4. Create ResultState with custom ResultSpec. 
ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -423,7 +425,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionIntermediateNodeFieldPath) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -533,7 +536,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleNestedFieldPaths) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -626,7 +630,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionEmptyFieldPath) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -702,7 +707,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionInvalidFieldPath) { // 4. Create ResultState with custom ResultSpec. 
ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -779,7 +785,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionValidAndInvalidFieldPath) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -858,7 +865,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesNoWildcards) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -941,7 +949,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, ProjectionMultipleTypesWildcard) { // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -1028,7 +1037,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, // 4. Create ResultState with custom ResultSpec. 
ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -1124,7 +1134,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -1224,7 +1235,8 @@ TEST_F(ResultRetrieverV2ProjectionTest, // 4. Create ResultState with custom ResultSpec. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/SectionRestrictQueryTermsMap{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), diff --git a/icing/result/result-retriever-v2_snippet_test.cc b/icing/result/result-retriever-v2_snippet_test.cc index 9384d6b..6123bf4 100644 --- a/icing/result/result-retriever-v2_snippet_test.cc +++ b/icing/result/result-retriever-v2_snippet_test.cc @@ -37,12 +37,12 @@ #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/icu-data-file-helper.h" -#include "icing/testing/snippet-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/snippet-helpers.h" #include "unicode/uloc.h" namespace icing { @@ -225,7 +225,8 @@ TEST_F(ResultRetrieverV2SnippetTest, 
language_segmenter_.get(), normalizer_.get())); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), @@ -267,7 +268,8 @@ TEST_F(ResultRetrieverV2SnippetTest, SimpleSnippeted) { *result_spec.mutable_snippet_spec() = CreateSnippetSpec(); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{{"", {"foo", "bar"}}}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -368,7 +370,8 @@ TEST_F(ResultRetrieverV2SnippetTest, OnlyOneDocumentSnippeted) { *result_spec.mutable_snippet_spec() = std::move(snippet_spec); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{{"", {"foo", "bar"}}}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -437,7 +440,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetAllResults) { *result_spec.mutable_snippet_spec() = std::move(snippet_spec); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{{"", {"foo", "bar"}}}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -483,7 +487,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldSnippetSomeResults) { *result_spec.mutable_snippet_spec() = std::move(snippet_spec); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + 
std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{{"", {"foo", "bar"}}}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -534,7 +539,8 @@ TEST_F(ResultRetrieverV2SnippetTest, ShouldNotSnippetAnyResults) { *result_spec.mutable_snippet_spec() = std::move(snippet_spec); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{{"", {"foo", "bar"}}}, CreateSearchSpec(TermMatchType::EXACT_ONLY), diff --git a/icing/result/result-retriever-v2_test.cc b/icing/result/result-retriever-v2_test.cc index 0fb2ba0..6171688 100644 --- a/icing/result/result-retriever-v2_test.cc +++ b/icing/result/result-retriever-v2_test.cc @@ -289,7 +289,8 @@ TEST_F(ResultRetrieverV2Test, ShouldRetrieveSimpleResults) { result5.set_score(1); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), CreateScoringSpec(/*is_descending_order=*/true), @@ -366,7 +367,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreNonInternalErrors) { result2.set_score(4); ResultStateV2 result_state1( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -383,7 +385,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreNonInternalErrors) { {document_id1, hit_section_id_mask, /*score=*/12}, {document_id2, hit_section_id_mask, /*score=*/4}}; ResultStateV2 result_state2( - 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -432,7 +435,8 @@ TEST_F(ResultRetrieverV2Test, ShouldIgnoreInternalErrors) { result1.set_score(0); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -479,7 +483,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateResultState) { language_segmenter_.get(), normalizer_.get())); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -550,7 +555,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) { {document_id2, hit_section_id_mask, /*score=*/0}}; std::shared_ptr<ResultStateV2> result_state1 = std::make_shared<ResultStateV2>( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/SectionRestrictQueryTermsMap{}, @@ -576,7 +582,8 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) { {document_id5, hit_section_id_mask, /*score=*/0}}; std::shared_ptr<ResultStateV2> result_state2 = std::make_shared<ResultStateV2>( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/SectionRestrictQueryTermsMap{}, @@ -662,7 +669,8 
@@ TEST_F(ResultRetrieverV2Test, ShouldLimitNumTotalBytesPerPage) { ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); result_spec.set_num_total_bytes_per_page_threshold(result1.ByteSizeLong()); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -723,7 +731,8 @@ TEST_F(ResultRetrieverV2Test, ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); result_spec.set_num_total_bytes_per_page_threshold(threshold); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -783,7 +792,8 @@ TEST_F(ResultRetrieverV2Test, ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); result_spec.set_num_total_bytes_per_page_threshold(threshold); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), diff --git a/icing/result/result-retriever_test.cc b/icing/result/result-retriever_test.cc index e0b4875..044e0f2 100644 --- a/icing/result/result-retriever_test.cc +++ b/icing/result/result-retriever_test.cc @@ -36,12 +36,12 @@ #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" #include "icing/testing/icu-data-file-helper.h" -#include "icing/testing/snippet-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" #include 
"icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/snippet-helpers.h" #include "unicode/uloc.h" namespace icing { diff --git a/icing/result/result-state-manager_test.cc b/icing/result/result-state-manager_test.cc index 7025c63..e7acc31 100644 --- a/icing/result/result-state-manager_test.cc +++ b/icing/result/result-state-manager_test.cc @@ -183,9 +183,9 @@ TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageOnePage) { {document_id1, kSectionIdMaskNone, /*score=*/1}, {document_id2, kSectionIdMaskNone, /*score=*/1}, {document_id3, kSectionIdMaskNone, /*score=*/1}}; - std::unique_ptr<ScoredDocumentHitsRanker> ranker = - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( - std::move(scored_document_hits), /*is_descending=*/true); + std::unique_ptr<ScoredDocumentHitsRanker> ranker = std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), @@ -228,9 +228,9 @@ TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageMultiplePages) { {document_id3, kSectionIdMaskNone, /*score=*/1}, {document_id4, kSectionIdMaskNone, /*score=*/1}, {document_id5, kSectionIdMaskNone, /*score=*/1}}; - std::unique_ptr<ScoredDocumentHitsRanker> ranker = - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( - std::move(scored_document_hits), /*is_descending=*/true); + std::unique_ptr<ScoredDocumentHitsRanker> ranker = std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), @@ -299,7 +299,8 @@ TEST_F(ResultStateManagerTest, EmptyRankerShouldReturnEmptyFirstPage) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info, 
result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -333,7 +334,8 @@ TEST_F(ResultStateManagerTest, ShouldAllowEmptyFirstPage) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), result_spec, document_store(), @@ -373,7 +375,8 @@ TEST_F(ResultStateManagerTest, ShouldAllowEmptyLastPage) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), result_spec, document_store(), @@ -417,7 +420,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), query_terms, search_spec, scoring_spec, result_spec, document_store(), result_retriever())); @@ -428,7 +432,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), query_terms, search_spec, scoring_spec, result_spec, document_store(), result_retriever())); @@ -462,7 +467,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -474,7 +480,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -516,7 +523,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -561,7 +569,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + 
PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -570,7 +579,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -607,7 +617,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -616,7 +627,8 @@ TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -649,7 +661,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -658,7 +671,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -668,7 +682,8 @@ TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -713,7 +728,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -722,7 +738,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -731,7 +748,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -749,7 +767,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info4, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits4), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -801,7 +820,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -810,7 +830,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + 
PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -819,7 +840,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -841,7 +863,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info4, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits4), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -850,7 +873,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info5, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits5), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -859,7 +883,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info6, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits6), 
/*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -920,7 +945,8 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -929,7 +955,8 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -938,7 +965,8 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -956,7 +984,8 @@ TEST_F( ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info4, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits4), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -970,7 +999,8 @@ TEST_F( 
ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info5, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits5), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1025,7 +1055,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1034,7 +1065,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1043,7 +1075,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ 
-1066,7 +1099,8 @@ TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info4, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits4), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1118,7 +1152,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1127,7 +1162,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1136,7 +1172,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1159,7 +1196,8 @@ 
TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info4, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits4), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1173,7 +1211,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info5, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits5), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1226,7 +1265,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1235,7 +1275,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1251,7 +1292,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo 
page_result_info3, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits3), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1324,7 +1366,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1337,7 +1380,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info2, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), @@ -1374,7 +1418,8 @@ TEST_F(ResultStateManagerTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2), diff --git a/icing/result/result-state-manager_thread-safety_test.cc b/icing/result/result-state-manager_thread-safety_test.cc index 523f84a..0da37d8 100644 --- 
a/icing/result/result-state-manager_thread-safety_test.cc +++ b/icing/result/result-state-manager_thread-safety_test.cc @@ -160,7 +160,8 @@ TEST_F(ResultStateManagerThreadSafetyTest, ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_, @@ -260,7 +261,8 @@ TEST_F(ResultStateManagerThreadSafetyTest, InvalidateResultStateWhileUsing) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/false), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_, @@ -389,7 +391,8 @@ TEST_F(ResultStateManagerThreadSafetyTest, MultipleResultStates) { ICING_ASSERT_OK_AND_ASSIGN( PageResultInfo page_result_info1, result_state_manager.CacheAndRetrieveFirstPage( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits_copy), /*is_descending=*/false), /*query_terms=*/{}, SearchSpecProto::default_instance(), CreateScoringSpec(), CreateResultSpec(kNumPerPage), diff --git a/icing/result/result-state-v2_test.cc b/icing/result/result-state-v2_test.cc index 7255958..f32546a 100644 --- a/icing/result/result-state-v2_test.cc +++ b/icing/result/result-state-v2_test.cc @@ -130,7 +130,8 @@ TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToSpecs) { 
result_spec.set_num_total_bytes_per_page_threshold(4096); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -152,7 +153,8 @@ TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToDefaultSpecs) { Eq(std::numeric_limits<int32_t>::max())); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -178,7 +180,8 @@ TEST_F(ResultStateV2Test, ShouldReturnSnippetContextAccordingToSpecs) { query_terms_map.emplace("term1", std::unordered_set<std::string>()); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -217,7 +220,8 @@ TEST_F(ResultStateV2Test, NoSnippetingShouldReturnNull) { query_terms_map.emplace("term1", std::unordered_set<std::string>()); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), query_terms_map, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -253,7 +257,8 @@ TEST_F(ResultStateV2Test, ShouldConstructProjectionTreeMapAccordingToSpecs) { wildcard_type_property_mask->add_paths("wild.card"); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( 
std::vector<ScoredDocumentHit>(), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -319,7 +324,8 @@ TEST_F(ResultStateV2Test, document_store().GetNamespaceId("namespace3")); ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::vector<ScoredDocumentHit>(), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -352,7 +358,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHits) { // Creates a ResultState with 5 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -384,7 +391,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHitsWhenDestructed) { { // Creates a ResultState with 5 ScoredDocumentHits. ResultStateV2 result_state1( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits1), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -399,7 +407,8 @@ TEST_F(ResultStateV2Test, ShouldUpdateNumTotalHitsWhenDestructed) { { // Creates another ResultState with 2 ScoredDocumentHits. ResultStateV2 result_state2( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits2), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -428,7 +437,8 @@ TEST_F(ResultStateV2Test, ShouldNotUpdateNumTotalHitsWhenNotRegistered) { // Creates a ResultState with 5 ScoredDocumentHits. 
{ ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), @@ -458,7 +468,8 @@ TEST_F(ResultStateV2Test, ShouldDecrementOriginalNumTotalHitsWhenReregister) { // Creates a ResultState with 5 ScoredDocumentHits. ResultStateV2 result_state( - std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::make_unique< + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>>( std::move(scored_document_hits), /*is_descending=*/true), /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), diff --git a/icing/result/snippet-retriever_test.cc b/icing/result/snippet-retriever_test.cc index 0940b51..80d00d5 100644 --- a/icing/result/snippet-retriever_test.cc +++ b/icing/result/snippet-retriever_test.cc @@ -38,7 +38,6 @@ #include "icing/testing/fake-clock.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/jni-test-helpers.h" -#include "icing/testing/snippet-helpers.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" @@ -46,6 +45,7 @@ #include "icing/transform/map/map-normalizer.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/snippet-helpers.h" #include "unicode/uloc.h" namespace icing { diff --git a/icing/scoring/advanced_scoring/advanced-scorer.cc b/icing/scoring/advanced_scoring/advanced-scorer.cc new file mode 100644 index 0000000..9d52fde --- /dev/null +++ b/icing/scoring/advanced_scoring/advanced-scorer.cc @@ -0,0 +1,58 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/advanced_scoring/advanced-scorer.h" + +#include <memory> + +#include "icing/query/advanced_query_parser/lexer.h" +#include "icing/query/advanced_query_parser/parser.h" +#include "icing/scoring/advanced_scoring/score-expression.h" +#include "icing/scoring/advanced_scoring/scoring-visitor.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> +AdvancedScorer::Create(const ScoringSpecProto& scoring_spec, + double default_score, + const DocumentStore* document_store, + const SchemaStore* schema_store) { + ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); + + Lexer lexer(scoring_spec.advanced_scoring_expression(), + Lexer::Language::SCORING); + ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens, + lexer.ExtractTokens()); + Parser parser = Parser::Create(std::move(lexer_tokens)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root, + parser.ConsumeScoring()); + + ScoringVisitor visitor(default_score); + tree_root->Accept(&visitor); + + ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression, + std::move(visitor).Expression()); + if (expression->is_document_type()) { + return absl_ports::InvalidArgumentError( + "The root scoring expression will always be evaluated to a document, " + "but a number is expected."); + } + return std::unique_ptr<AdvancedScorer>( + new AdvancedScorer(std::move(expression), default_score)); +} + +} // namespace lib +} // namespace icing diff --git 
a/icing/scoring/advanced_scoring/advanced-scorer.h b/icing/scoring/advanced_scoring/advanced-scorer.h new file mode 100644 index 0000000..6557ba6 --- /dev/null +++ b/icing/scoring/advanced_scoring/advanced-scorer.h @@ -0,0 +1,66 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ +#define ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ + +#include <memory> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/schema/schema-store.h" +#include "icing/scoring/advanced_scoring/score-expression.h" +#include "icing/scoring/scorer.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +class AdvancedScorer : public Scorer { + public: + // Returns: + // A AdvancedScorer instance on success + // FAILED_PRECONDITION on any null pointer input + // INVALID_ARGUMENT if fails to create an instance + static libtextclassifier3::StatusOr<std::unique_ptr<AdvancedScorer>> Create( + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store); + + double GetScore(const DocHitInfo& hit_info, + const DocHitInfoIterator* query_it) override { + libtextclassifier3::StatusOr<double> result = + score_expression_->eval(hit_info, query_it); + if (!result.ok()) { + ICING_LOG(ERROR) << "Got an 
error when scoring a document:\n" + << result.status().error_message(); + return default_score_; + } + return std::move(result).ValueOrDie(); + } + + private: + explicit AdvancedScorer(std::unique_ptr<ScoreExpression> score_expression, + double default_score) + : score_expression_(std::move(score_expression)), + default_score_(default_score) {} + + std::unique_ptr<ScoreExpression> score_expression_; + double default_score_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_ADVANCED_SCORING_ADVANCED_SCORER_H_ diff --git a/icing/scoring/advanced_scoring/advanced-scorer_test.cc b/icing/scoring/advanced_scoring/advanced-scorer_test.cc new file mode 100644 index 0000000..0d3a05c --- /dev/null +++ b/icing/scoring/advanced_scoring/advanced-scorer_test.cc @@ -0,0 +1,404 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/scoring/advanced_scoring/advanced-scorer.h" + +#include <cmath> +#include <memory> +#include <string> +#include <string_view> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/scoring.pb.h" +#include "icing/proto/usage.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/scoring/scorer-factory.h" +#include "icing/scoring/scorer.h" +#include "icing/store/document-id.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { +using ::testing::DoubleNear; +using ::testing::Eq; + +class AdvancedScorerTest : public testing::Test { + protected: + AdvancedScorerTest() + : test_dir_(GetTestTempDir() + "/icing"), + doc_store_dir_(test_dir_ + "/doc_store"), + schema_store_dir_(test_dir_ + "/schema_store") {} + + void SetUp() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(doc_store_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, doc_store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + + // Creates a simple email schema + SchemaProto test_email_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + 
StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema)); + } + + void TearDown() override { + document_store_.reset(); + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + const std::string test_dir_; + const std::string doc_store_dir_; + const std::string schema_store_dir_; + Filesystem filesystem_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<DocumentStore> document_store_; + FakeClock fake_clock_; +}; + +constexpr double kEps = 0.0000000001; +constexpr int kDefaultScore = 0; +constexpr int64_t kDefaultCreationTimestampMs = 1571100001111; + +DocumentProto CreateDocument( + const std::string& name_space, const std::string& uri, + int score = kDefaultScore, + int64_t creation_timestamp_ms = kDefaultCreationTimestampMs) { + return DocumentBuilder() + .SetKey(name_space, uri) + .SetSchema("email") + .SetScore(score) + .SetCreationTimestampMs(creation_timestamp_ms) + .Build(); +} + +ScoringSpecProto CreateAdvancedScoringSpec( + const std::string& advanced_scoring_expression) { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + scoring_spec.set_advanced_scoring_expression(advanced_scoring_expression); + return scoring_spec; +} + +TEST_F(AdvancedScorerTest, InvalidAdvancedScoringSpec) { + // Empty scoring expression for advanced scoring + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by( + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION); + EXPECT_THAT( + scorer_factory::Create(scoring_spec, /*default_score=*/10, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + // Non-empty scoring expression for normal scoring + scoring_spec = ScoringSpecProto::default_instance(); + 
scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + scoring_spec.set_advanced_scoring_expression("1"); + EXPECT_THAT( + scorer_factory::Create(scoring_spec, /*default_score=*/10, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(AdvancedScorerTest, SimpleExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri"))); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("123"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + + DocHitInfo docHitInfo = DocHitInfo(document_id); + + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(123)); +} + +TEST_F(AdvancedScorerTest, BasicPureArithmeticExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri"))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + 2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(3)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("-1 + 2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + -2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 - 2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(-1)); + + ICING_ASSERT_OK_AND_ASSIGN( + 
scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 * 2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(2)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 2"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0.5)); +} + +TEST_F(AdvancedScorerTest, BasicMathFunctionExpression) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri"))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("log(10, 1000)"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(3, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("log(2.718281828459045)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(2, 10)"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(1024)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("max(10, 11, 12, 13, 14)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(14)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("min(10, 11, 12, 13, 14)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + 
AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(2)"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(sqrt(2), kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(-2) + abs(2)"), + /*default_score=*/10, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(4)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("sin(3.141592653589793)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(0, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("cos(3.141592653589793)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(-1, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create( + CreateAdvancedScoringSpec("tan(3.141592653589793 / 4)"), + /*default_score=*/10, document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(1, kEps)); +} + +// Should be a parsing Error +TEST_F(AdvancedScorerTest, EmptyExpression) { + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec(""), + /*default_score=*/10, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(AdvancedScorerTest, EvaluationErrorShouldReturnDefaultScore) { + const double default_score = 123; + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store_->Put(CreateDocument("namespace", "uri"))); + DocHitInfo docHitInfo = DocHitInfo(document_id); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Scorer> scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("log(0)"), default_score, + document_store_.get(), schema_store_.get())); + 
EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, + AdvancedScorer::Create(CreateAdvancedScoringSpec("1 / 0"), default_score, + document_store_.get(), schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(-1)"), + default_score, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); + + ICING_ASSERT_OK_AND_ASSIGN( + scorer, AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(-1, 0.5)"), + default_score, document_store_.get(), + schema_store_.get())); + EXPECT_THAT(scorer->GetScore(docHitInfo), DoubleNear(default_score, kEps)); +} + +// The following tests should trigger a type error while the visitor tries to +// build a ScoreExpression object. +TEST_F(AdvancedScorerTest, MathTypeError) { + const double default_score = 0; + + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec("test"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec("log()"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, 2, 3)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("log(1, this)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec("pow(1)"), default_score, + document_store_.get(), schema_store_.get()), + 
StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sqrt(1, 2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("abs(1, 2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("sin(1, 2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("cos(1, 2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("tan(1, 2)"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec("this"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT( + AdvancedScorer::Create(CreateAdvancedScoringSpec("-this"), default_score, + document_store_.get(), schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + EXPECT_THAT(AdvancedScorer::Create(CreateAdvancedScoringSpec("1 + this"), + default_score, document_store_.get(), + schema_store_.get()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression.cc b/icing/scoring/advanced_scoring/score-expression.cc new file mode 100644 index 0000000..cd77046 --- /dev/null +++ b/icing/scoring/advanced_scoring/score-expression.cc 
@@ -0,0 +1,203 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/advanced_scoring/score-expression.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> +OperatorScoreExpression::Create( + OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) { + if (children.empty()) { + return absl_ports::InvalidArgumentError( + "OperatorScoreExpression must have at least one argument."); + } + for (const auto& child : children) { + ICING_RETURN_ERROR_IF_NULL(child); + if (child->is_document_type()) { + return absl_ports::InvalidArgumentError( + "Operators are not supported for document type."); + } + } + if (op == OperatorType::kNegative) { + if (children.size() != 1) { + return absl_ports::InvalidArgumentError( + "Negative operator must have only 1 argument."); + } + } + return std::unique_ptr<OperatorScoreExpression>( + new OperatorScoreExpression(op, std::move(children))); +} + +libtextclassifier3::StatusOr<double> OperatorScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) { + // The Create factory guarantees that an operator will have at least one + // child. 
+ ICING_ASSIGN_OR_RETURN(double res, children_.at(0)->eval(hit_info, query_it)); + + if (op_ == OperatorType::kNegative) { + return -res; + } + + for (int i = 1; i < children_.size(); ++i) { + ICING_ASSIGN_OR_RETURN(double v, children_.at(i)->eval(hit_info, query_it)); + switch (op_) { + case OperatorType::kPlus: + res += v; + break; + case OperatorType::kMinus: + res -= v; + break; + case OperatorType::kTimes: + res *= v; + break; + case OperatorType::kDiv: + res /= v; + break; + case OperatorType::kNegative: + return absl_ports::InternalError("Should never reach here."); + } + if (!std::isfinite(res)) { + return absl_ports::InvalidArgumentError( + "Got a non-finite value while evaluating operator score expression."); + } + } + return res; +} + +const std::unordered_map<std::string, MathFunctionScoreExpression::FunctionType> + MathFunctionScoreExpression::kFunctionNames = { + {"log", FunctionType::kLog}, {"pow", FunctionType::kPow}, + {"max", FunctionType::kMax}, {"min", FunctionType::kMin}, + {"sqrt", FunctionType::kSqrt}, {"abs", FunctionType::kAbs}, + {"sin", FunctionType::kSin}, {"cos", FunctionType::kCos}, + {"tan", FunctionType::kTan}}; + +libtextclassifier3::StatusOr<std::unique_ptr<MathFunctionScoreExpression>> +MathFunctionScoreExpression::Create( + FunctionType function_type, + std::vector<std::unique_ptr<ScoreExpression>> children) { + if (children.empty()) { + return absl_ports::InvalidArgumentError( + "Math functions must have at least one argument."); + } + for (const auto& child : children) { + ICING_RETURN_ERROR_IF_NULL(child); + if (child->is_document_type()) { + return absl_ports::InvalidArgumentError( + "Math functions are not supported for document type."); + } + } + switch (function_type) { + case FunctionType::kLog: + if (children.size() != 1 && children.size() != 2) { + return absl_ports::InvalidArgumentError( + "log must have 1 or 2 arguments."); + } + break; + case FunctionType::kPow: + if (children.size() != 2) { + return 
absl_ports::InvalidArgumentError("pow must have 2 arguments."); + } + break; + case FunctionType::kSqrt: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError("sqrt must have 1 argument."); + } + break; + case FunctionType::kAbs: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError("abs must have 1 argument."); + } + break; + case FunctionType::kSin: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError("sin must have 1 argument."); + } + break; + case FunctionType::kCos: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError("cos must have 1 argument."); + } + break; + case FunctionType::kTan: + if (children.size() != 1) { + return absl_ports::InvalidArgumentError("tan must have 1 argument."); + } + break; + // max and min support variable length arguments + case FunctionType::kMax: + [[fallthrough]]; + case FunctionType::kMin: + break; + } + return std::unique_ptr<MathFunctionScoreExpression>( + new MathFunctionScoreExpression(function_type, std::move(children))); +} + +libtextclassifier3::StatusOr<double> MathFunctionScoreExpression::eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) { + std::vector<double> values; + for (const auto& child : children_) { + ICING_ASSIGN_OR_RETURN(double v, child->eval(hit_info, query_it)); + values.push_back(v); + } + + double res = 0; + switch (function_type_) { + case FunctionType::kLog: + if (values.size() == 1) { + res = log(values[0]); + } else { + // argument 0 is log base + // argument 1 is the value + res = log(values[1]) / log(values[0]); + } + break; + case FunctionType::kPow: + res = pow(values[0], values[1]); + break; + case FunctionType::kMax: + res = *std::max_element(values.begin(), values.end()); + break; + case FunctionType::kMin: + res = *std::min_element(values.begin(), values.end()); + break; + case FunctionType::kSqrt: + res = sqrt(values[0]); + break; + case FunctionType::kAbs: + res = abs(values[0]); + break; + case 
FunctionType::kSin: + res = sin(values[0]); + break; + case FunctionType::kCos: + res = cos(values[0]); + break; + case FunctionType::kTan: + res = tan(values[0]); + break; + } + if (!std::isfinite(res)) { + return absl_ports::InvalidArgumentError( + "Got a non-finite value while evaluating math function score " + "expression."); + } + return res; +} + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/advanced_scoring/score-expression.h b/icing/scoring/advanced_scoring/score-expression.h new file mode 100644 index 0000000..0e0c538 --- /dev/null +++ b/icing/scoring/advanced_scoring/score-expression.h @@ -0,0 +1,154 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ +#define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ + +#include <algorithm> +#include <cmath> +#include <memory> +#include <unordered_map> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +// TODO(b/261474063) Simplify every ScoreExpression node to +// ConstantScoreExpression if its evaluation does not depend on a document. 
+class ScoreExpression { + public: + virtual ~ScoreExpression() = default; + + // Evaluate the score expression to double with the current document. + // + // RETURNS: + // - The evaluated result as a double on success. + // - INVALID_ARGUMENT if a non-finite value is reached while evaluating the + // expression. + // - INTERNAL if there are inconsistencies. + virtual libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) = 0; + + // Indicate whether the current expression is of document type + virtual bool is_document_type() const { return false; } +}; + +class ThisExpression : public ScoreExpression { + public: + static std::unique_ptr<ThisExpression> Create() { + return std::unique_ptr<ThisExpression>(new ThisExpression()); + } + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override { + return absl_ports::InternalError( + "Should never reach here to evaluate a document type as double. " + "There must be inconsistencies."); + } + + bool is_document_type() const override { return true; } + + private: + ThisExpression() = default; +}; + +class ConstantScoreExpression : public ScoreExpression { + public: + static std::unique_ptr<ConstantScoreExpression> Create(double c) { + return std::unique_ptr<ConstantScoreExpression>( + new ConstantScoreExpression(c)); + } + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo&, const DocHitInfoIterator*) override { + return c_; + } + + private: + explicit ConstantScoreExpression(double c) : c_(c) {} + + double c_; +}; + +class OperatorScoreExpression : public ScoreExpression { + public: + enum class OperatorType { kPlus, kMinus, kNegative, kTimes, kDiv }; + + // RETURNS: + // - An OperatorScoreExpression instance on success. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. 
+ static libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> + Create(OperatorType op, + std::vector<std::unique_ptr<ScoreExpression>> children); + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override; + + private: + explicit OperatorScoreExpression( + OperatorType op, std::vector<std::unique_ptr<ScoreExpression>> children) + : op_(op), children_(std::move(children)) {} + + OperatorType op_; + std::vector<std::unique_ptr<ScoreExpression>> children_; +}; + +class MathFunctionScoreExpression : public ScoreExpression { + public: + enum class FunctionType { + kLog, + kPow, + kMax, + kMin, + kSqrt, + kAbs, + kSin, + kCos, + kTan + }; + + static const std::unordered_map<std::string, FunctionType> kFunctionNames; + + // RETURNS: + // - A MathFunctionScoreExpression instance on success. + // - FAILED_PRECONDITION on any null pointer in children. + // - INVALID_ARGUMENT on type errors. + static libtextclassifier3::StatusOr< + std::unique_ptr<MathFunctionScoreExpression>> + Create(FunctionType function_type, + std::vector<std::unique_ptr<ScoreExpression>> children); + + libtextclassifier3::StatusOr<double> eval( + const DocHitInfo& hit_info, const DocHitInfoIterator* query_it) override; + + private: + explicit MathFunctionScoreExpression( + FunctionType function_type, + std::vector<std::unique_ptr<ScoreExpression>> children) + : function_type_(function_type), children_(std::move(children)) {} + + FunctionType function_type_; + std::vector<std::unique_ptr<ScoreExpression>> children_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_H_ diff --git a/icing/scoring/advanced_scoring/scoring-visitor.cc b/icing/scoring/advanced_scoring/scoring-visitor.cc new file mode 100644 index 0000000..7737213 --- /dev/null +++ b/icing/scoring/advanced_scoring/scoring-visitor.cc @@ -0,0 +1,159 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under 
the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/advanced_scoring/scoring-visitor.h" + +#include "icing/absl_ports/str_cat.h" + +namespace icing { +namespace lib { + +void ScoringVisitor::VisitFunctionName(const FunctionNameNode* node) { + pending_error_ = absl_ports::InternalError( + "FunctionNameNode should be handled in VisitFunction!"); +} + +void ScoringVisitor::VisitString(const StringNode* node) { + pending_error_ = + absl_ports::InvalidArgumentError("Scoring does not support String!"); +} + +void ScoringVisitor::VisitText(const TextNode* node) { + pending_error_ = + absl_ports::InternalError("TextNode should be handled in VisitMember!"); +} + +void ScoringVisitor::VisitMember(const MemberNode* node) { + std::string value; + if (node->children().size() == 1) { + // If a member has only one child, then it can be a numeric literal, + // or "this" if the member is a reference to a member function. + value = node->children()[0]->value(); + if (value == "this") { + stack.push_back(ThisExpression::Create()); + return; + } + } else if (node->children().size() == 2) { + // If a member has two children, then it can only represent a floating point + // number, so we need to join them by "." to build the numeric literal. 
+ value = absl_ports::StrCat(node->children()[0]->value(), ".", + node->children()[1]->value()); + } else { + pending_error_ = absl_ports::InvalidArgumentError( + "MemberNode must have 1 or 2 children."); + return; + } + char* end; + double number = std::strtod(value.c_str(), &end); + if (end != value.c_str() + value.length()) { + // While it would be doable to support property references in the scoring + // grammar, we currently don't have an efficient way to support such a + // lookup (we'd have to read each document). As such, it's simpler to just + // restrict the scoring language to not include properties. + pending_error_ = absl_ports::InvalidArgumentError( + absl_ports::StrCat("Expect a numeric literal, but got ", value)); + return; + } + stack.push_back(ConstantScoreExpression::Create(number)); +} + +void ScoringVisitor::VisitFunction(const FunctionNode* node) { + std::vector<std::unique_ptr<ScoreExpression>> children; + for (const auto& arg : node->args()) { + arg->Accept(this); + if (has_pending_error()) { + return; + } + children.push_back(pop_stack()); + } + const std::string& function_name = node->function_name()->value(); + libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> expression = + absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unknown function: ", function_name)); + + // Math functions + if (MathFunctionScoreExpression::kFunctionNames.find(function_name) != + MathFunctionScoreExpression::kFunctionNames.end()) { + expression = MathFunctionScoreExpression::Create( + MathFunctionScoreExpression::kFunctionNames.at(function_name), + std::move(children)); + } + + if (!expression.ok()) { + pending_error_ = expression.status(); + return; + } + stack.push_back(std::move(expression).ValueOrDie()); +} + +void ScoringVisitor::VisitUnaryOperator(const UnaryOperatorNode* node) { + if (node->operator_text() != "MINUS") { + pending_error_ = absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unknown unary operator: ", 
node->operator_text())); + return; + } + node->child()->Accept(this); + if (has_pending_error()) { + return; + } + std::vector<std::unique_ptr<ScoreExpression>> children; + children.push_back(pop_stack()); + + libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> + expression = OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kNegative, + std::move(children)); + if (!expression.ok()) { + pending_error_ = expression.status(); + return; + } + stack.push_back(std::move(expression).ValueOrDie()); +} + +void ScoringVisitor::VisitNaryOperator(const NaryOperatorNode* node) { + std::vector<std::unique_ptr<ScoreExpression>> children; + for (const auto& arg : node->children()) { + arg->Accept(this); + if (has_pending_error()) { + return; + } + children.push_back(pop_stack()); + } + + libtextclassifier3::StatusOr<std::unique_ptr<OperatorScoreExpression>> + expression = absl_ports::InvalidArgumentError( + absl_ports::StrCat("Unknown Nary operator: ", node->operator_text())); + + if (node->operator_text() == "PLUS") { + expression = OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kPlus, std::move(children)); + } else if (node->operator_text() == "MINUS") { + expression = OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kMinus, std::move(children)); + } else if (node->operator_text() == "TIMES") { + expression = OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kTimes, std::move(children)); + } else if (node->operator_text() == "DIV") { + expression = OperatorScoreExpression::Create( + OperatorScoreExpression::OperatorType::kDiv, std::move(children)); + } + if (!expression.ok()) { + pending_error_ = expression.status(); + return; + } + stack.push_back(std::move(expression).ValueOrDie()); +} + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/advanced_scoring/scoring-visitor.h b/icing/scoring/advanced_scoring/scoring-visitor.h new file mode 
100644 index 0000000..47a03fd --- /dev/null +++ b/icing/scoring/advanced_scoring/scoring-visitor.h @@ -0,0 +1,77 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ +#define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/legacy/core/icing-string-util.h" +#include "icing/proto/scoring.pb.h" +#include "icing/query/advanced_query_parser/abstract-syntax-tree.h" +#include "icing/scoring/advanced_scoring/score-expression.h" + +namespace icing { +namespace lib { + +class ScoringVisitor : public AbstractSyntaxTreeVisitor { + public: + explicit ScoringVisitor(double default_score) + : default_score_(default_score) {} + + void VisitFunctionName(const FunctionNameNode* node) override; + void VisitString(const StringNode* node) override; + void VisitText(const TextNode* node) override; + void VisitMember(const MemberNode* node) override; + void VisitFunction(const FunctionNode* node) override; + void VisitUnaryOperator(const UnaryOperatorNode* node) override; + void VisitNaryOperator(const NaryOperatorNode* node) override; + + // RETURNS: + // - An ScoreExpression instance able to evaluate the expression on success. + // - INVALID_ARGUMENT if the AST does not conform to supported expressions, + // such as type errors. 
+ // - INTERNAL if there are inconsistencies. + libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> + Expression() && { + if (has_pending_error()) { + return pending_error_; + } + if (stack.size() != 1) { + return absl_ports::InternalError(IcingStringUtil::StringPrintf( + "Expect to get only one result from " + "ScoringVisitor, but got %zu. There must be inconsistencies.", + stack.size())); + } + return std::move(stack[0]); + } + + private: + bool has_pending_error() const { return !pending_error_.ok(); } + + std::unique_ptr<ScoreExpression> pop_stack() { + std::unique_ptr<ScoreExpression> result = std::move(stack.back()); + stack.pop_back(); + return result; + } + + double default_score_; + libtextclassifier3::Status pending_error_; + std::vector<std::unique_ptr<ScoreExpression>> stack; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.cc b/icing/scoring/priority-queue-scored-document-hits-ranker.cc deleted file mode 100644 index 691b088..0000000 --- a/icing/scoring/priority-queue-scored-document-hits-ranker.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (C) 2022 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" - -#include <queue> -#include <vector> - -#include "icing/scoring/scored-document-hit.h" - -namespace icing { -namespace lib { - -PriorityQueueScoredDocumentHitsRanker::PriorityQueueScoredDocumentHitsRanker( - std::vector<ScoredDocumentHit>&& scored_document_hits, bool is_descending) - : comparator_(/*is_ascending=*/!is_descending), - scored_document_hits_pq_(comparator_, std::move(scored_document_hits)) {} - -ScoredDocumentHit PriorityQueueScoredDocumentHitsRanker::PopNext() { - ScoredDocumentHit ret = scored_document_hits_pq_.top(); - scored_document_hits_pq_.pop(); - return ret; -} - -void PriorityQueueScoredDocumentHitsRanker::TruncateHitsTo(int new_size) { - if (new_size < 0 || scored_document_hits_pq_.size() <= new_size) { - return; - } - - // Copying the best new_size results. - std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>, - Comparator> - new_pq(comparator_); - for (int i = 0; i < new_size; ++i) { - new_pq.push(scored_document_hits_pq_.top()); - scored_document_hits_pq_.pop(); - } - scored_document_hits_pq_ = std::move(new_pq); -} - -} // namespace lib -} // namespace icing diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.h b/icing/scoring/priority-queue-scored-document-hits-ranker.h index 3ef2ae5..0798d7d 100644 --- a/icing/scoring/priority-queue-scored-document-hits-ranker.h +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.h @@ -26,21 +26,37 @@ namespace lib { // ScoredDocumentHitsRanker interface implementation, based on // std::priority_queue. We can get next top hit in O(lgN) time. 
+template <typename ScoredDataType, + typename Converter = typename ScoredDataType::Converter> class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { public: explicit PriorityQueueScoredDocumentHitsRanker( - std::vector<ScoredDocumentHit>&& scored_document_hits, - bool is_descending = true); + std::vector<ScoredDataType>&& scored_data_vec, bool is_descending = true); ~PriorityQueueScoredDocumentHitsRanker() override = default; - ScoredDocumentHit PopNext() override; + // Note: ranker may store ScoredDocumentHit or JoinedScoredDocumentHit, so we + // have template for scored_data_pq_. + // - JoinedScoredDocumentHit is a superset of ScoredDocumentHit, so we unify + // the return type of PopNext to use the superset type + // JoinedScoredDocumentHit in order to make it simple, and rankers storing + // ScoredDocumentHit should convert it to JoinedScoredDocumentHit before + // returning. It makes the implementation simpler, especially for + // ResultRetriever, which now only needs to deal with one single return + // format. + // - JoinedScoredDocumentHit has ~2x size of ScoredDocumentHit. Since we cache + // ranker (which contains a priority queue of data) in ResultState, if we + // store the scored hits in JoinedScoredDocumentHit format directly, then it + // doubles the memory usage. Therefore, we still keep the flexibility to + // store ScoredDocumentHit or any other types of data, but require PopNext + // to convert it to JoinedScoredDocumentHit. + JoinedScoredDocumentHit PopNext() override; void TruncateHitsTo(int new_size) override; - int size() const override { return scored_document_hits_pq_.size(); } + int size() const override { return scored_data_pq_.size(); } - bool empty() const override { return scored_document_hits_pq_.empty(); } + bool empty() const override { return scored_data_pq_.empty(); } private: // Comparator for std::priority_queue. 
Since std::priority is a max heap @@ -49,8 +65,8 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { public: explicit Comparator(bool is_ascending) : is_ascending_(is_ascending) {} - bool operator()(const ScoredDocumentHit& lhs, - const ScoredDocumentHit& rhs) const { + bool operator()(const ScoredDataType& lhs, + const ScoredDataType& rhs) const { // STL comparator requirement: equal MUST return false. // If writing `return is_ascending_ == !(lhs < rhs)`: // - When lhs == rhs, !(lhs < rhs) is true @@ -68,11 +84,44 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { Comparator comparator_; // Use priority queue to get top K hits in O(KlgN) time. - std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>, - Comparator> - scored_document_hits_pq_; + std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator> + scored_data_pq_; + + Converter converter_; }; +template <typename ScoredDataType, typename Converter> +PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>:: + PriorityQueueScoredDocumentHitsRanker( + std::vector<ScoredDataType>&& scored_data_vec, bool is_descending) + : comparator_(/*is_ascending=*/!is_descending), + scored_data_pq_(comparator_, std::move(scored_data_vec)) {} + +template <typename ScoredDataType, typename Converter> +JoinedScoredDocumentHit +PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>::PopNext() { + ScoredDataType next_scored_data = scored_data_pq_.top(); + scored_data_pq_.pop(); + return converter_(std::move(next_scored_data)); +} + +template <typename ScoredDataType, typename Converter> +void PriorityQueueScoredDocumentHitsRanker< + ScoredDataType, Converter>::TruncateHitsTo(int new_size) { + if (new_size < 0 || scored_data_pq_.size() <= new_size) { + return; + } + + // Copying the best new_size results. 
+ std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator> + new_pq(comparator_); + for (int i = 0; i < new_size; ++i) { + new_pq.push(scored_data_pq_.top()); + scored_data_pq_.pop(); + } + scored_data_pq_ = std::move(new_pq); +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc index a575eaf..ace2350 100644 --- a/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc +++ b/icing/scoring/priority-queue-scored-document-hits-ranker_test.cc @@ -31,9 +31,19 @@ using ::testing::Eq; using ::testing::IsEmpty; using ::testing::SizeIs; -std::vector<ScoredDocumentHit> PopAll( - PriorityQueueScoredDocumentHitsRanker& ranker) { - std::vector<ScoredDocumentHit> hits; +class Converter { + public: + JoinedScoredDocumentHit operator()(ScoredDocumentHit hit) const { + return converter_(std::move(hit)); + } + + private: + ScoredDocumentHit::Converter converter_; +} converter; + +std::vector<JoinedScoredDocumentHit> PopAll( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit>& ranker) { + std::vector<JoinedScoredDocumentHit> hits; while (!ranker.empty()) { hits.push_back(ranker.PopNext()); } @@ -48,7 +58,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldGetCorrectSizeAndEmpty) { ScoredDocumentHit scored_hit_2(/*document_id=*/2, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, scored_hit_2}, /*is_descending=*/true); EXPECT_THAT(ranker.size(), Eq(3)); @@ -79,18 +89,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInDescendingOrder) { ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, 
scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/true); EXPECT_THAT(ranker, SizeIs(5)); - std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); - EXPECT_THAT(scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_4), - EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_1), - EqualsScoredDocumentHit(scored_hit_0))); + std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT( + scored_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_3)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_1)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_0)))); } TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) { @@ -105,18 +116,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankInAscendingOrder) { ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/false); EXPECT_THAT(ranker, SizeIs(5)); - std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); - EXPECT_THAT(scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_0), - EqualsScoredDocumentHit(scored_hit_1), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_4))); + std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT( + scored_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_0)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_1)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + 
EqualsJoinedScoredDocumentHit(converter(scored_hit_3)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_4)))); } TEST(PriorityQueueScoredDocumentHitsRankerTest, @@ -132,28 +144,30 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_2, scored_hit_4, scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/true); EXPECT_THAT(ranker, SizeIs(8)); - std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); - EXPECT_THAT(scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_4), - EqualsScoredDocumentHit(scored_hit_4), - EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_1), - EqualsScoredDocumentHit(scored_hit_0))); + std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT( + scored_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_4)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_3)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_1)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_0)))); } TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldRankEmptyScoredDocumentHits) { - PriorityQueueScoredDocumentHitsRanker ranker(/*scored_document_hits=*/{}, - /*is_descending=*/true); + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( + /*scored_document_hits=*/{}, + /*is_descending=*/true); EXPECT_THAT(ranker, IsEmpty()); } 
@@ -169,18 +183,19 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToNewSize) { ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/true); ASSERT_THAT(ranker, SizeIs(5)); ranker.TruncateHitsTo(/*new_size=*/3); EXPECT_THAT(ranker, SizeIs(3)); - std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); - EXPECT_THAT(scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_4), - EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2))); + std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT( + scored_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_3)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)))); } TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) { @@ -195,7 +210,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldTruncateToZero) { ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/true); ASSERT_THAT(ranker, SizeIs(5)); @@ -216,7 +231,7 @@ TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) { ScoredDocumentHit scored_hit_4(/*document_id=*/4, kSectionIdMaskNone, /*score=*/1); - PriorityQueueScoredDocumentHitsRanker ranker( + PriorityQueueScoredDocumentHitsRanker<ScoredDocumentHit> ranker( {scored_hit_1, scored_hit_0, scored_hit_2, scored_hit_4, scored_hit_3}, /*is_descending=*/true); ASSERT_THAT(ranker, SizeIs(Eq(5))); @@ -224,13 +239,14 @@ 
TEST(PriorityQueueScoredDocumentHitsRankerTest, ShouldNotTruncateToNegative) { ranker.TruncateHitsTo(/*new_size=*/-1); EXPECT_THAT(ranker, SizeIs(Eq(5))); // Contents are not affected. - std::vector<ScoredDocumentHit> scored_document_hits = PopAll(ranker); - EXPECT_THAT(scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_4), - EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2), - EqualsScoredDocumentHit(scored_hit_1), - EqualsScoredDocumentHit(scored_hit_0))); + std::vector<JoinedScoredDocumentHit> scored_document_hits = PopAll(ranker); + EXPECT_THAT( + scored_document_hits, + ElementsAre(EqualsJoinedScoredDocumentHit(converter(scored_hit_4)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_3)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_2)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_1)), + EqualsJoinedScoredDocumentHit(converter(scored_hit_0)))); } } // namespace diff --git a/icing/scoring/scored-document-hit.cc b/icing/scoring/scored-document-hit.cc new file mode 100644 index 0000000..f519a16 --- /dev/null +++ b/icing/scoring/scored-document-hit.cc @@ -0,0 +1,30 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/scoring/scored-document-hit.h" + +namespace icing { +namespace lib { + +JoinedScoredDocumentHit ScoredDocumentHit::Converter::operator()( + ScoredDocumentHit&& scored_doc_hit) const { + double final_score = scored_doc_hit.score(); + return JoinedScoredDocumentHit( + final_score, + /*parent_scored_document_hit=*/std::move(scored_doc_hit), + /*child_scored_document_hits=*/{}); +} + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/scored-document-hit.h b/icing/scoring/scored-document-hit.h index 96ca6aa..141049e 100644 --- a/icing/scoring/scored-document-hit.h +++ b/icing/scoring/scored-document-hit.h @@ -24,11 +24,19 @@ namespace icing { namespace lib { +class JoinedScoredDocumentHit; + // A data class containing information about the document, hit sections, and a // score. The score is calculated against both the document and the hit // sections. class ScoredDocumentHit { public: + class Converter { + public: + JoinedScoredDocumentHit operator()( + ScoredDocumentHit&& scored_doc_hit) const; + }; + ScoredDocumentHit(DocumentId document_id, SectionIdMask hit_section_id_mask, double score) : document_id_(document_id), @@ -85,6 +93,65 @@ class ScoredDocumentHitComparator { bool is_descending_; }; +// A data class containing information about a composite document after joining, +// including final score, parent ScoredDocumentHit, and a vector of all child +// ScoredDocumentHits. The final score is calculated by the strategy specified +// in join spec/rank strategy. It could be aggregated score, raw parent doc +// score, or anything else. +// +// ScoredDocumentHitsRanker may store ScoredDocumentHit or +// JoinedScoredDocumentHit. +// - We could've created a virtual class for them and ScoredDocumentHitsRanker +// uses the abstract type. +// - However, Icing lib caches ScoredDocumentHitsRanker (which contains a list +// of (Joined)ScoredDocumentHits) in ResultState. 
Inheriting the virtual class +// makes both classes have additional 8 bytes for vtable, which increases 40% +// and 15% memory usage respectively. +// - Also since JoinedScoredDocumentHit is a super-set of ScoredDocumentHit, +// let's avoid the common virtual class and instead implement a convert +// function (original type -> JoinedScoredDocumentHit) for each class, so +// ScoredDocumentHitsRanker::PopNext can return a common type (i.e. +// JoinedScoredDocumentHit). +class JoinedScoredDocumentHit { + public: + class Converter { + public: + JoinedScoredDocumentHit operator()( + JoinedScoredDocumentHit&& scored_doc_hit) const { + return scored_doc_hit; + } + }; + + explicit JoinedScoredDocumentHit( + double final_score, ScoredDocumentHit&& parent_scored_document_hit, + std::vector<ScoredDocumentHit>&& child_scored_document_hits) + : final_score_(final_score), + parent_scored_document_hit_(std::move(parent_scored_document_hit)), + child_scored_document_hits_(std::move(child_scored_document_hits)) {} + + bool operator<(const JoinedScoredDocumentHit& other) const { + if (final_score_ != other.final_score_) { + return final_score_ < other.final_score_; + } + return parent_scored_document_hit_ < other.parent_scored_document_hit_; + } + + double final_score() const { return final_score_; } + + const ScoredDocumentHit& parent_scored_document_hit() const { + return parent_scored_document_hit_; + } + + const std::vector<ScoredDocumentHit>& child_scored_document_hits() const { + return child_scored_document_hits_; + } + + private: + double final_score_; + ScoredDocumentHit parent_scored_document_hit_; + std::vector<ScoredDocumentHit> child_scored_document_hits_; +} __attribute__((packed)); + } // namespace lib } // namespace icing diff --git a/icing/scoring/scored-document-hit_test.cc b/icing/scoring/scored-document-hit_test.cc new file mode 100644 index 0000000..cb9703b --- /dev/null +++ b/icing/scoring/scored-document-hit_test.cc @@ -0,0 +1,77 @@ +// Copyright (C) 2022 
Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/scored-document-hit.h" + +#include <cstdint> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::DoubleEq; +using ::testing::IsEmpty; + +TEST(ScoredDocumentHitTest, ScoredDocumentHitConvertToJoinedScoredDocumentHit) { + ScoredDocumentHit::Converter converter; + + double score = 2.0; + ScoredDocumentHit scored_document_hit(/*document_id=*/5, + /*section_id_mask=*/49, score); + + JoinedScoredDocumentHit joined_scored_document_hit = + converter(ScoredDocumentHit(scored_document_hit)); + EXPECT_THAT(joined_scored_document_hit.final_score(), DoubleEq(score)); + EXPECT_THAT(joined_scored_document_hit.parent_scored_document_hit(), + EqualsScoredDocumentHit(scored_document_hit)); + EXPECT_THAT(joined_scored_document_hit.child_scored_document_hits(), + IsEmpty()); +} + +TEST(ScoredDocumentHitTest, + JoinedScoredDocumentHitConvertToJoinedScoredDocumentHit) { + JoinedScoredDocumentHit::Converter converter; + + ScoredDocumentHit parent_scored_document_hit(/*document_id=*/5, + /*section_id_mask=*/49, + /*score=*/1.0); + std::vector<ScoredDocumentHit> child_scored_document_hits{ + ScoredDocumentHit(/*document_id=*/1, + /*section_id_mask=*/1, + /*score=*/2.0), + ScoredDocumentHit(/*document_id=*/2, + /*section_id_mask=*/2, + 
/*score=*/3.0), + ScoredDocumentHit(/*document_id=*/3, + /*section_id_mask=*/3, + /*score=*/4.0)}; + + JoinedScoredDocumentHit joined_scored_document_hit( + /*final_score=*/12345.6789, std::move(parent_scored_document_hit), + std::move(child_scored_document_hits)); + EXPECT_THAT(converter(JoinedScoredDocumentHit(joined_scored_document_hit)), + EqualsJoinedScoredDocumentHit(joined_scored_document_hit)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/scoring/scored-document-hits-ranker.h b/icing/scoring/scored-document-hits-ranker.h index 0287452..9b76ce7 100644 --- a/icing/scoring/scored-document-hits-ranker.h +++ b/icing/scoring/scored-document-hits-ranker.h @@ -30,10 +30,19 @@ class ScoredDocumentHitsRanker { public: virtual ~ScoredDocumentHitsRanker() = default; - // Pop the next top ScoredDocumentHit and return. It is undefined to call - // PopNext on an empty ranker, so the caller should check if it is not empty - // before calling. - virtual ScoredDocumentHit PopNext() = 0; + // Pop the next top JoinedScoredDocumentHit and return. It is undefined to + // call PopNext on an empty ranker, so the caller should check if it is not + // empty before calling. + // + // Note: ranker may store ScoredDocumentHit or JoinedScoredDocumentHit. We can + // add template for this interface, but since JoinedScoredDocumentHit is a + // superset of ScoredDocumentHit, we unify the return type of PopNext to use + // the superset type JoinedScoredDocumentHit in order to make it simple, and + // rankers storing ScoredDocumentHit should convert it to + // JoinedScoredDocumentHit before returning. It makes the implementation + // simpler, especially for ResultRetriever, which now only needs to deal with + // one single return format. + virtual JoinedScoredDocumentHit PopNext() = 0; // Truncates the remaining ScoredDocumentHits to the given size. The best // ScoredDocumentHits (according to the ranking policy) should be kept. 
diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer-factory.cc index 14a004e..600fe6b 100644 --- a/icing/scoring/scorer.cc +++ b/icing/scoring/scorer-factory.cc @@ -12,16 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/scoring/scorer.h" +#include "icing/scoring/scorer-factory.h" #include <memory> +#include <unordered_map> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" +#include "icing/scoring/advanced_scoring/advanced-scorer.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/scorer.h" #include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -156,12 +159,22 @@ class NoScorer : public Scorer { double default_score_; }; -libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( +namespace scorer_factory { + +libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( const ScoringSpecProto& scoring_spec, double default_score, const DocumentStore* document_store, const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); ICING_RETURN_ERROR_IF_NULL(schema_store); + if (!scoring_spec.advanced_scoring_expression().empty() && + scoring_spec.rank_by() != + ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION) { + return absl_ports::InvalidArgumentError( + "Advanced scoring is not enabled, but the advanced scoring expression " + "is not empty!"); + } + switch (scoring_spec.rank_by()) { case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE: return std::make_unique<DocumentScoreScorer>(document_store, @@ -192,6 +205,13 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: 
return std::make_unique<UsageScorer>( document_store, scoring_spec.rank_by(), default_score); + case ScoringSpecProto::RankingStrategy::ADVANCED_SCORING_EXPRESSION: + if (scoring_spec.advanced_scoring_expression().empty()) { + return absl_ports::InvalidArgumentError( + "Advanced scoring is enabled, but the expression is empty!"); + } + return AdvancedScorer::Create(scoring_spec, default_score, document_store, + schema_store); case ScoringSpecProto::RankingStrategy::JOIN_AGGREGATE_SCORE: ICING_LOG(WARNING) << "JOIN_AGGREGATE_SCORE not implemented, falling back to NoScorer"; @@ -201,5 +221,7 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( } } +} // namespace scorer_factory + } // namespace lib } // namespace icing diff --git a/icing/scoring/scorer-factory.h b/icing/scoring/scorer-factory.h new file mode 100644 index 0000000..8c19c75 --- /dev/null +++ b/icing/scoring/scorer-factory.h @@ -0,0 +1,46 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef ICING_SCORING_SCORER_FACTORY_H_ +#define ICING_SCORING_SCORER_FACTORY_H_ + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/scoring/scorer.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +namespace scorer_factory { + +// Factory function to create a Scorer which does not take ownership of any +// input components (DocumentStore), and all pointers must refer to valid +// objects that outlive the created Scorer instance. The default score will be +// returned only when the scorer fails to find or calculate a score for the +// document. +// +// Returns: +// A Scorer on success +// FAILED_PRECONDITION on any null pointer input +// INVALID_ARGUMENT if fails to create an instance +libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store); + +} // namespace scorer_factory + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SCORER_FACTORY_H_ diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h index abdd5ca..ec48502 100644 --- a/icing/scoring/scorer.h +++ b/icing/scoring/scorer.h @@ -16,13 +16,11 @@ #define ICING_SCORING_SCORER_H_ #include <memory> +#include <unordered_map> -#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" -#include "icing/store/document-id.h" -#include "icing/store/document-store.h" namespace icing { namespace lib { @@ -32,20 +30,6 @@ class Scorer { public: virtual ~Scorer() = default; - // Factory function to create a Scorer which does not take ownership of any - // input components (DocumentStore), and all pointers must refer to valid - // objects that outlive the created Scorer instance. 
The default score will be - // returned only when the scorer fails to find or calculate a score for the - // document. - // - // Returns: - // A Scorer on success - // FAILED_PRECONDITION on any null pointer input - // INVALID_ARGUMENT if fails to create an instance - static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( - const ScoringSpecProto& scoring_spec, double default_score, - const DocumentStore* document_store, const SchemaStore* schema_store); - // Returns a non-negative score of a document. The score can be a // document-associated score which comes from the DocumentProto directly, an // accumulated score, a relevance score, or even an inferred score. If it diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 5432cde..7bbb8b7 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -28,6 +28,7 @@ #include "icing/proto/usage.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/scorer-factory.h" #include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -128,28 +129,29 @@ ScoringSpecProto CreateScoringSpecForRankingStrategy( TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) { EXPECT_THAT( - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/0, /*document_store=*/nullptr, - schema_store()), + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, /*document_store=*/nullptr, schema_store()), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) { - EXPECT_THAT( - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/0, document_store(), - /*schema_store=*/nullptr), - 
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), + /*schema_store=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/10, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -171,9 +173,10 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/10, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -204,9 +207,10 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/10, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -233,9 +237,10 @@ 
TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/10, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -256,9 +261,10 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -281,9 +287,10 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), - /*default_score=*/10, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -313,9 +320,10 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { document_store()->Put(test_document2)); 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -340,19 +348,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); 
EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -384,19 +395,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { // Create 3 scorers for 3 different usage types. ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -428,19 +442,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -472,22 +489,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -535,22 +555,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -598,22 +621,25 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE2_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE3_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -649,9 +675,10 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE), - /*default_score=*/3, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/3, document_store(), + 
schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -661,10 +688,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy::NONE), - /*default_score=*/111, document_store(), schema_store())); + scorer, scorer_factory::Create( + CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/111, document_store(), schema_store())); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -688,10 +715,11 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(CreateScoringSpecForRankingStrategy( - ScoringSpecProto::RankingStrategy:: - USAGE_TYPE1_LAST_USED_TIMESTAMP), - /*default_score=*/0, document_store(), schema_store())); + scorer_factory::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), + schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. 
diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index d5e64d8..571a112 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -15,6 +15,7 @@ #include "icing/scoring/scoring-processor.h" #include <memory> +#include <unordered_map> #include <utility> #include <vector> @@ -25,6 +26,7 @@ #include "icing/proto/scoring.pb.h" #include "icing/scoring/ranker.h" #include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/scorer-factory.h" #include "icing/scoring/scorer.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -50,10 +52,11 @@ ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - Scorer::Create(scoring_spec, - is_descending_order ? kDefaultScoreInDescendingOrder - : kDefaultScoreInAscendingOrder, - document_store, schema_store)); + scorer_factory::Create(scoring_spec, + is_descending_order + ? kDefaultScoreInDescendingOrder + : kDefaultScoreInAscendingOrder, + document_store, schema_store)); // Using `new` to access a non-public constructor. return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc index 9a33682..62599c8 100644 --- a/icing/store/document-store.cc +++ b/icing/store/document-store.cc @@ -1678,8 +1678,8 @@ DocumentStore::OptimizeInto(const std::string& new_directory, } TokenizedDocument tokenized_document( std::move(tokenized_document_or).ValueOrDie()); - new_document_id_or = - new_doc_store->Put(document_to_keep, tokenized_document.num_tokens()); + new_document_id_or = new_doc_store->Put( + document_to_keep, tokenized_document.num_string_tokens()); } else { // TODO(b/144458732): Implement a more robust version of // TC_ASSIGN_OR_RETURN that can support error logging. 
diff --git a/icing/store/dynamic-trie-key-mapper.h b/icing/store/dynamic-trie-key-mapper.h index 35d2200..63e8488 100644 --- a/icing/store/dynamic-trie-key-mapper.h +++ b/icing/store/dynamic-trie-key-mapper.h @@ -60,12 +60,15 @@ class DynamicTrieKeyMapper : public KeyMapper<T, Formatter> { Create(const Filesystem& filesystem, std::string_view base_dir, int maximum_size_bytes); - // Deletes all the files associated with the DynamicTrieKeyMapper. Returns - // success or any encountered IO errors + // Deletes all the files associated with the DynamicTrieKeyMapper. // // base_dir : Base directory used to save all the files required to persist // DynamicTrieKeyMapper. Should be the same as passed into // Create(). + // + // Returns + // OK on success + // INTERNAL_ERROR on I/O error static libtextclassifier3::Status Delete(const Filesystem& filesystem, std::string_view base_dir); diff --git a/icing/store/dynamic-trie-key-mapper_test.cc b/icing/store/dynamic-trie-key-mapper_test.cc index add88bb..fd56170 100644 --- a/icing/store/dynamic-trie-key-mapper_test.cc +++ b/icing/store/dynamic-trie-key-mapper_test.cc @@ -14,21 +14,21 @@ #include "icing/store/dynamic-trie-key-mapper.h" +#include <string> + +#include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/file/filesystem.h" #include "icing/store/document-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" -using ::testing::_; -using ::testing::HasSubstr; -using ::testing::IsEmpty; -using ::testing::Pair; -using ::testing::UnorderedElementsAre; - namespace icing { namespace lib { + namespace { + constexpr int kMaxDynamicTrieKeyMapperSize = 3 * 1024 * 1024; // 3 MiB class DynamicTrieKeyMapperTest : public testing::Test { @@ -43,168 +43,25 @@ class DynamicTrieKeyMapperTest : public testing::Test { Filesystem filesystem_; }; -std::unordered_map<std::string, DocumentId> GetAllKeyValuePairs( - const 
DynamicTrieKeyMapper<DocumentId>* key_mapper) { - std::unordered_map<std::string, DocumentId> ret; - - std::unique_ptr<typename KeyMapper<DocumentId>::Iterator> itr = - key_mapper->GetIterator(); - while (itr->Advance()) { - ret.emplace(itr->GetKey(), itr->GetValue()); - } - return ret; -} - TEST_F(DynamicTrieKeyMapperTest, InvalidBaseDir) { - ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( - filesystem_, "/dev/null", kMaxDynamicTrieKeyMapperSize) - .status() - .error_message(), - HasSubstr("Failed to create DynamicTrieKeyMapper")); + EXPECT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, "/dev/null", kMaxDynamicTrieKeyMapperSize), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } TEST_F(DynamicTrieKeyMapperTest, NegativeMaxKeyMapperSizeReturnsInternalError) { - ASSERT_THAT( + EXPECT_THAT( DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, -1), StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } TEST_F(DynamicTrieKeyMapperTest, TooLargeMaxKeyMapperSizeReturnsInternalError) { - ASSERT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( + EXPECT_THAT(DynamicTrieKeyMapper<DocumentId>::Create( filesystem_, base_dir_, std::numeric_limits<int>::max()), StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } -TEST_F(DynamicTrieKeyMapperTest, CreateNewKeyMapper) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - EXPECT_THAT(key_mapper->num_keys(), 0); -} - -TEST_F(DynamicTrieKeyMapperTest, CanUpdateSameKeyMultipleTimes) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); - ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 50)); - - 
EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); - - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 200)); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(200)); - EXPECT_THAT(key_mapper->num_keys(), 2); - - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); - EXPECT_THAT(key_mapper->num_keys(), 2); -} - -TEST_F(DynamicTrieKeyMapperTest, GetOrPutOk) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - - EXPECT_THAT(key_mapper->Get("foo"), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - EXPECT_THAT(key_mapper->GetOrPut("foo", 1), IsOkAndHolds(1)); - EXPECT_THAT(key_mapper->Get("foo"), IsOkAndHolds(1)); -} - -TEST_F(DynamicTrieKeyMapperTest, CanPersistToDiskRegularly) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - // Can persist an empty DynamicTrieKeyMapper. - ICING_EXPECT_OK(key_mapper->PersistToDisk()); - EXPECT_THAT(key_mapper->num_keys(), 0); - - // Can persist the smallest DynamicTrieKeyMapper. - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); - ICING_EXPECT_OK(key_mapper->PersistToDisk()); - EXPECT_THAT(key_mapper->num_keys(), 1); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); - - // Can continue to add keys after PersistToDisk(). - ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200)); - EXPECT_THAT(key_mapper->num_keys(), 2); - EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200)); - - // Can continue to update the same key after PersistToDisk(). 
- ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); - EXPECT_THAT(key_mapper->num_keys(), 2); -} - -TEST_F(DynamicTrieKeyMapperTest, CanUseAcrossMultipleInstances) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); - ICING_EXPECT_OK(key_mapper->PersistToDisk()); - - key_mapper.reset(); - ICING_ASSERT_OK_AND_ASSIGN( - key_mapper, DynamicTrieKeyMapper<DocumentId>::Create( - filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize)); - EXPECT_THAT(key_mapper->num_keys(), 1); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); - - // Can continue to read/write to the KeyMapper. - ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200)); - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); - EXPECT_THAT(key_mapper->num_keys(), 2); - EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200)); - EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); -} - -TEST_F(DynamicTrieKeyMapperTest, CanDeleteAndRestartKeyMapping) { - // Can delete even if there's nothing there - ICING_EXPECT_OK( - DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); - - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); - ICING_EXPECT_OK(key_mapper->PersistToDisk()); - ICING_EXPECT_OK( - DynamicTrieKeyMapper<DocumentId>::Delete(filesystem_, base_dir_)); - - key_mapper.reset(); - ICING_ASSERT_OK_AND_ASSIGN( - key_mapper, DynamicTrieKeyMapper<DocumentId>::Create( - filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize)); - 
EXPECT_THAT(key_mapper->num_keys(), 0); - ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); - EXPECT_THAT(key_mapper->num_keys(), 1); -} - -TEST_F(DynamicTrieKeyMapperTest, Iterator) { - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DynamicTrieKeyMapper<DocumentId>> key_mapper, - DynamicTrieKeyMapper<DocumentId>::Create(filesystem_, base_dir_, - kMaxDynamicTrieKeyMapperSize)); - EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), IsEmpty()); - - ICING_EXPECT_OK(key_mapper->Put("foo", /*value=*/1)); - ICING_EXPECT_OK(key_mapper->Put("bar", /*value=*/2)); - EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), - UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2))); - - ICING_EXPECT_OK(key_mapper->Put("baz", /*value=*/3)); - EXPECT_THAT( - GetAllKeyValuePairs(key_mapper.get()), - UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2), Pair("baz", 3))); -} - } // namespace + } // namespace lib } // namespace icing diff --git a/icing/store/key-mapper_benchmark.cc b/icing/store/key-mapper_benchmark.cc new file mode 100644 index 0000000..b649bc7 --- /dev/null +++ b/icing/store/key-mapper_benchmark.cc @@ -0,0 +1,316 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <random> +#include <string> +#include <unordered_map> + +#include "testing/base/public/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" +#include "icing/file/filesystem.h" +#include "icing/store/dynamic-trie-key-mapper.h" +#include "icing/store/key-mapper.h" +#include "icing/store/persistent-hash-map-key-mapper.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/random-string.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; +using ::testing::Not; + +class KeyMapperBenchmark { + public: + static constexpr int kKeyLength = 20; + + explicit KeyMapperBenchmark() + : clock(std::make_unique<Clock>()), + base_dir(GetTestTempDir() + "/key_mapper_benchmark"), + random_engine(/*seed=*/12345) {} + + std::string GenerateUniqueRandomKeyValuePair(int val, + std::string_view prefix = "") { + std::string rand_str = absl_ports::StrCat( + prefix, RandomString(kAlNumAlphabet, kKeyLength, &random_engine)); + while (random_kvps_map.find(rand_str) != random_kvps_map.end()) { + rand_str = absl_ports::StrCat( + std::string(prefix), + RandomString(kAlNumAlphabet, kKeyLength, &random_engine)); + } + std::pair<std::string, int> entry(rand_str, val); + random_kvps.push_back(entry); + random_kvps_map.insert(entry); + return rand_str; + } + + template <typename UnknownKeyMapperType> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>> CreateKeyMapper( + int max_num_entries) { + return absl_ports::InvalidArgumentError("Unknown type"); + } + + template <> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>> + CreateKeyMapper<DynamicTrieKeyMapper<int>>(int max_num_entries) { + return DynamicTrieKeyMapper<int>::Create( + filesystem, base_dir, + /*maximum_size_bytes=*/128 * 1024 * 1024); + } + + template <> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<int>>> + 
CreateKeyMapper<PersistentHashMapKeyMapper<int>>(int max_num_entries) { + return PersistentHashMapKeyMapper<int>::Create( + filesystem, base_dir, max_num_entries, + /*average_kv_byte_size=*/kKeyLength + 1 + sizeof(int), + /*max_load_factor_percent=*/100); + } + + std::unique_ptr<Clock> clock; + + Filesystem filesystem; + std::string base_dir; + + std::default_random_engine random_engine; + std::vector<std::pair<std::string, int>> random_kvps; + std::unordered_map<std::string, int> random_kvps_map; +}; + +// Benchmark the total time of putting num_keys (specified by Arg) unique random +// key value pairs. +template <typename KeyMapperType> +void BM_PutMany(benchmark::State& state) { + int num_keys = state.range(0); + + KeyMapperBenchmark benchmark; + for (int i = 0; i < num_keys; ++i) { + benchmark.GenerateUniqueRandomKeyValuePair(i); + } + + for (auto _ : state) { + state.PauseTiming(); + benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str()); + DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<KeyMapper<int>> key_mapper, + benchmark.CreateKeyMapper<KeyMapperType>(num_keys)); + ASSERT_THAT(key_mapper->num_keys(), Eq(0)); + state.ResumeTiming(); + + for (int i = 0; i < num_keys; ++i) { + ICING_ASSERT_OK(key_mapper->Put(benchmark.random_kvps[i].first, + benchmark.random_kvps[i].second)); + } + + // Explicit calls PersistToDisk. + ICING_ASSERT_OK(key_mapper->PersistToDisk()); + + state.PauseTiming(); + ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys)); + // The destructor of IcingDynamicTrie doesn't implicitly call PersistToDisk, + // while PersistentHashMap does. Thus, we reset the unique pointer to invoke + // destructor in the pause timing block, so in this case PersistToDisk will + // be included into the benchmark only once. 
+ key_mapper.reset(); + state.ResumeTiming(); + } +} +BENCHMARK(BM_PutMany<DynamicTrieKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); +BENCHMARK(BM_PutMany<PersistentHashMapKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); + +// Benchmark the average time of putting 1 unique random key value pair. The +// result will be affected by # of iterations, so use --benchmark_max_iters=k +// and --benchmark_min_iters=k to force # of iterations to be fixed. +template <typename KeyMapperType> +void BM_Put(benchmark::State& state) { + KeyMapperBenchmark benchmark; + benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str()); + DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir); + + // The overhead of state.PauseTiming is too large and affects the benchmark + // result a lot, so pre-generate enough kvps to avoid calling too many times + // state.PauseTiming for GenerateUniqueRandomKeyValuePair in the benchmark + // for-loop. 
+ int MAX_PREGEN_KVPS = 1 << 22; + for (int i = 0; i < MAX_PREGEN_KVPS; ++i) { + benchmark.GenerateUniqueRandomKeyValuePair(i); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<KeyMapper<int>> key_mapper, + benchmark.CreateKeyMapper<KeyMapperType>(/*max_num_entries=*/1 << 22)); + ASSERT_THAT(key_mapper->num_keys(), Eq(0)); + + int cnt = 0; + for (auto _ : state) { + if (cnt >= MAX_PREGEN_KVPS) { + state.PauseTiming(); + benchmark.GenerateUniqueRandomKeyValuePair(cnt); + state.ResumeTiming(); + } + + ICING_ASSERT_OK(key_mapper->Put(benchmark.random_kvps[cnt].first, + benchmark.random_kvps[cnt].second)); + ++cnt; + } +} +BENCHMARK(BM_Put<DynamicTrieKeyMapper<int>>); +BENCHMARK(BM_Put<PersistentHashMapKeyMapper<int>>); + +// Benchmark the average time of getting 1 existing key value pair from the key +// mapper with size num_keys (specified by Arg). +template <typename KeyMapperType> +void BM_Get(benchmark::State& state) { + int num_keys = state.range(0); + + KeyMapperBenchmark benchmark; + benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str()); + DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir); + + // Create a key mapper with num_keys entries. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<KeyMapper<int>> key_mapper, + benchmark.CreateKeyMapper<KeyMapperType>(num_keys)); + for (int i = 0; i < num_keys; ++i) { + ICING_ASSERT_OK( + key_mapper->Put(benchmark.GenerateUniqueRandomKeyValuePair(i), i)); + } + ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys)); + + std::uniform_int_distribution<> distrib(0, num_keys - 1); + std::default_random_engine e(/*seed=*/12345); + for (auto _ : state) { + int idx = distrib(e); + ICING_ASSERT_OK_AND_ASSIGN( + int val, key_mapper->Get(benchmark.random_kvps[idx].first)); + ASSERT_THAT(val, Eq(benchmark.random_kvps[idx].second)); + } +} +BENCHMARK(BM_Get<DynamicTrieKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); +BENCHMARK(BM_Get<PersistentHashMapKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); + +// Benchmark the total time of iterating through all key value pairs of the key +// mapper with size num_keys (specified by Arg). +template <typename KeyMapperType> +void BM_Iterator(benchmark::State& state) { + int num_keys = state.range(0); + + KeyMapperBenchmark benchmark; + benchmark.filesystem.DeleteDirectoryRecursively(benchmark.base_dir.c_str()); + DestructibleDirectory ddir(&benchmark.filesystem, benchmark.base_dir); + + // Create a key mapper with num_keys entries. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<KeyMapper<int>> key_mapper, + benchmark.CreateKeyMapper<KeyMapperType>(num_keys)); + for (int i = 0; i < num_keys; ++i) { + ICING_ASSERT_OK( + key_mapper->Put(benchmark.GenerateUniqueRandomKeyValuePair(i), i)); + } + ASSERT_THAT(key_mapper->num_keys(), Eq(num_keys)); + + for (auto _ : state) { + auto iter = key_mapper->GetIterator(); + int cnt = 0; + while (iter->Advance()) { + ++cnt; + std::string key(iter->GetKey()); + int value = iter->GetValue(); + auto it = benchmark.random_kvps_map.find(key); + ASSERT_THAT(it, Not(Eq(benchmark.random_kvps_map.end()))); + ASSERT_THAT(it->second, Eq(value)); + } + ASSERT_THAT(cnt, Eq(num_keys)); + } +} +BENCHMARK(BM_Iterator<DynamicTrieKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); +BENCHMARK(BM_Iterator<PersistentHashMapKeyMapper<int>>) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/store/key-mapper_test.cc b/icing/store/key-mapper_test.cc new file mode 100644 index 0000000..682888d --- /dev/null +++ b/icing/store/key-mapper_test.cc @@ -0,0 +1,215 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/store/key-mapper.h" + +#include <memory> +#include <string> +#include <type_traits> +#include <unordered_map> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/file/filesystem.h" +#include "icing/store/document-id.h" +#include "icing/store/dynamic-trie-key-mapper.h" +#include "icing/store/persistent-hash-map-key-mapper.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" + +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +namespace icing { +namespace lib { + +namespace { + +constexpr int kMaxDynamicTrieKeyMapperSize = 3 * 1024 * 1024; // 3 MiB + +template <typename T> +class KeyMapperTest : public ::testing::Test { + protected: + using KeyMapperType = T; + + void SetUp() override { base_dir_ = GetTestTempDir() + "/key_mapper"; } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + template <typename UnknownKeyMapperType> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>> + CreateKeyMapper() { + return absl_ports::InvalidArgumentError("Unknown type"); + } + + template <> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>> + CreateKeyMapper<DynamicTrieKeyMapper<DocumentId>>() { + return DynamicTrieKeyMapper<DocumentId>::Create( + filesystem_, base_dir_, kMaxDynamicTrieKeyMapperSize); + } + + template <> + libtextclassifier3::StatusOr<std::unique_ptr<KeyMapper<DocumentId>>> + CreateKeyMapper<PersistentHashMapKeyMapper<DocumentId>>() { + return PersistentHashMapKeyMapper<DocumentId>::Create(filesystem_, + base_dir_); + } + + std::string base_dir_; + Filesystem filesystem_; +}; + +using TestTypes = 
::testing::Types<DynamicTrieKeyMapper<DocumentId>, + PersistentHashMapKeyMapper<DocumentId>>; +TYPED_TEST_SUITE(KeyMapperTest, TestTypes); + +std::unordered_map<std::string, DocumentId> GetAllKeyValuePairs( + const KeyMapper<DocumentId>* key_mapper) { + std::unordered_map<std::string, DocumentId> ret; + + std::unique_ptr<typename KeyMapper<DocumentId>::Iterator> itr = + key_mapper->GetIterator(); + while (itr->Advance()) { + ret.emplace(itr->GetKey(), itr->GetValue()); + } + return ret; +} + +TYPED_TEST(KeyMapperTest, CreateNewKeyMapper) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + EXPECT_THAT(key_mapper->num_keys(), 0); +} + +TYPED_TEST(KeyMapperTest, CanUpdateSameKeyMultipleTimes) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); + ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 50)); + + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); + + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 200)); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(200)); + EXPECT_THAT(key_mapper->num_keys(), 2); + + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); + EXPECT_THAT(key_mapper->num_keys(), 2); +} + +TYPED_TEST(KeyMapperTest, GetOrPutOk) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + + EXPECT_THAT(key_mapper->Get("foo"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(key_mapper->GetOrPut("foo", 1), IsOkAndHolds(1)); + EXPECT_THAT(key_mapper->Get("foo"), IsOkAndHolds(1)); +} + +TYPED_TEST(KeyMapperTest, CanPersistToDiskRegularly) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> 
key_mapper, + this->template CreateKeyMapper<TypeParam>()); + + // Can persist an empty DynamicTrieKeyMapper. + ICING_EXPECT_OK(key_mapper->PersistToDisk()); + EXPECT_THAT(key_mapper->num_keys(), 0); + + // Can persist the smallest DynamicTrieKeyMapper. + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); + ICING_EXPECT_OK(key_mapper->PersistToDisk()); + EXPECT_THAT(key_mapper->num_keys(), 1); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); + + // Can continue to add keys after PersistToDisk(). + ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200)); + EXPECT_THAT(key_mapper->num_keys(), 2); + EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200)); + + // Can continue to update the same key after PersistToDisk(). + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); + EXPECT_THAT(key_mapper->num_keys(), 2); +} + +TYPED_TEST(KeyMapperTest, CanUseAcrossMultipleInstances) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); + ICING_EXPECT_OK(key_mapper->PersistToDisk()); + + key_mapper.reset(); + + ICING_ASSERT_OK_AND_ASSIGN(key_mapper, + this->template CreateKeyMapper<TypeParam>()); + EXPECT_THAT(key_mapper->num_keys(), 1); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(100)); + + // Can continue to read/write to the KeyMapper. 
+ ICING_EXPECT_OK(key_mapper->Put("default-youtube.com", 200)); + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 300)); + EXPECT_THAT(key_mapper->num_keys(), 2); + EXPECT_THAT(key_mapper->Get("default-youtube.com"), IsOkAndHolds(200)); + EXPECT_THAT(key_mapper->Get("default-google.com"), IsOkAndHolds(300)); +} + +TYPED_TEST(KeyMapperTest, CanDeleteAndRestartKeyMapping) { + // Can delete even if there's nothing there + ICING_EXPECT_OK( + TestFixture::KeyMapperType::Delete(this->filesystem_, this->base_dir_)); + + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); + ICING_EXPECT_OK(key_mapper->PersistToDisk()); + ICING_EXPECT_OK( + TestFixture::KeyMapperType::Delete(this->filesystem_, this->base_dir_)); + + key_mapper.reset(); + ICING_ASSERT_OK_AND_ASSIGN(key_mapper, + this->template CreateKeyMapper<TypeParam>()); + EXPECT_THAT(key_mapper->num_keys(), 0); + ICING_EXPECT_OK(key_mapper->Put("default-google.com", 100)); + EXPECT_THAT(key_mapper->num_keys(), 1); +} + +TYPED_TEST(KeyMapperTest, Iterator) { + ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<KeyMapper<DocumentId>> key_mapper, + this->template CreateKeyMapper<TypeParam>()); + EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), IsEmpty()); + + ICING_EXPECT_OK(key_mapper->Put("foo", /*value=*/1)); + ICING_EXPECT_OK(key_mapper->Put("bar", /*value=*/2)); + EXPECT_THAT(GetAllKeyValuePairs(key_mapper.get()), + UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2))); + + ICING_EXPECT_OK(key_mapper->Put("baz", /*value=*/3)); + EXPECT_THAT( + GetAllKeyValuePairs(key_mapper.get()), + UnorderedElementsAre(Pair("foo", 1), Pair("bar", 2), Pair("baz", 3))); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/store/persistent-hash-map-key-mapper.h b/icing/store/persistent-hash-map-key-mapper.h new file mode 100644 index 0000000..a13ec11 --- /dev/null +++ 
b/icing/store/persistent-hash-map-key-mapper.h @@ -0,0 +1,209 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_ +#define ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> +#include <type_traits> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/str_join.h" +#include "icing/file/filesystem.h" +#include "icing/file/persistent-hash-map.h" +#include "icing/store/key-mapper.h" +#include "icing/util/crc32.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +// File-backed mapping between the string key and a trivially copyable value +// type. +template <typename T, typename Formatter = absl_ports::DefaultFormatter> +class PersistentHashMapKeyMapper : public KeyMapper<T, Formatter> { + public: + // Returns an initialized instance of PersistentHashMapKeyMapper that can + // immediately handle read/write operations. + // Returns any encountered IO errors. + // + // filesystem: Object to make system level calls + // base_dir : Base directory used to save all the files required to persist + // PersistentHashMapKeyMapper. If this base_dir was previously used + // to create a PersistentHashMapKeyMapper, then this existing data + // would be loaded. 
Otherwise, an empty PersistentHashMapKeyMapper + // would be created. + // max_num_entries: max # of kvps. It will be used to compute 3 storages size. + // average_kv_byte_size: average byte size of a single key + serialized value. + // It will be used to compute kv_storage size. + // max_load_factor_percent: percentage of the max loading for the hash map. + // load_factor_percent = 100 * num_keys / num_buckets + // If load_factor_percent exceeds + // max_load_factor_percent, then rehash will be + // invoked (and # of buckets will be doubled). + // Note that load_factor_percent exceeding 100 is + // considered valid. + static libtextclassifier3::StatusOr< + std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>> + Create(const Filesystem& filesystem, std::string_view base_dir, + int32_t max_num_entries = PersistentHashMap::Entry::kMaxNumEntries, + int32_t average_kv_byte_size = + PersistentHashMap::Options::kDefaultAverageKVByteSize, + int32_t max_load_factor_percent = + PersistentHashMap::Options::kDefaultMaxLoadFactorPercent); + + // Deletes all the files associated with the PersistentHashMapKeyMapper. + // + // base_dir : Base directory used to save all the files required to persist + // PersistentHashMapKeyMapper. Should be the same as passed into + // Create(). 
+ // + // Returns: + // OK on success + // INTERNAL_ERROR on I/O error + static libtextclassifier3::Status Delete(const Filesystem& filesystem, + std::string_view base_dir); + + ~PersistentHashMapKeyMapper() override = default; + + libtextclassifier3::Status Put(std::string_view key, T value) override { + return persistent_hash_map_->Put(key, &value); + } + + libtextclassifier3::StatusOr<T> GetOrPut(std::string_view key, + T next_value) override { + ICING_RETURN_IF_ERROR(persistent_hash_map_->GetOrPut(key, &next_value)); + return next_value; + } + + libtextclassifier3::StatusOr<T> Get(std::string_view key) const override { + T value; + ICING_RETURN_IF_ERROR(persistent_hash_map_->Get(key, &value)); + return value; + } + + bool Delete(std::string_view key) override { + return persistent_hash_map_->Delete(key).ok(); + } + + std::unique_ptr<typename KeyMapper<T, Formatter>::Iterator> GetIterator() + const override { + return std::make_unique<PersistentHashMapKeyMapper<T, Formatter>::Iterator>( + persistent_hash_map_.get()); + } + + int32_t num_keys() const override { return persistent_hash_map_->size(); } + + libtextclassifier3::Status PersistToDisk() override { + return persistent_hash_map_->PersistToDisk(); + } + + libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const override { + return persistent_hash_map_->GetDiskUsage(); + } + + libtextclassifier3::StatusOr<int64_t> GetElementsSize() const override { + return persistent_hash_map_->GetElementsSize(); + } + + libtextclassifier3::StatusOr<Crc32> ComputeChecksum() override { + return persistent_hash_map_->ComputeChecksum(); + } + + private: + class Iterator : public KeyMapper<T, Formatter>::Iterator { + public: + explicit Iterator(const PersistentHashMap* persistent_hash_map) + : itr_(persistent_hash_map->GetIterator()) {} + + ~Iterator() override = default; + + bool Advance() override { return itr_.Advance(); } + + std::string_view GetKey() const override { return itr_.GetKey(); } + + T GetValue() const 
override { + T value; + memcpy(&value, itr_.GetValue(), sizeof(T)); + return value; + } + + private: + PersistentHashMap::Iterator itr_; + }; + + static constexpr std::string_view kKeyMapperDir = "key_mapper_dir"; + + // Use PersistentHashMapKeyMapper::Create() to instantiate. + explicit PersistentHashMapKeyMapper( + std::unique_ptr<PersistentHashMap> persistent_hash_map) + : persistent_hash_map_(std::move(persistent_hash_map)) {} + + std::unique_ptr<PersistentHashMap> persistent_hash_map_; + + static_assert(std::is_trivially_copyable<T>::value, + "T must be trivially copyable"); +}; + +template <typename T, typename Formatter> +/* static */ libtextclassifier3::StatusOr< + std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>> +PersistentHashMapKeyMapper<T, Formatter>::Create( + const Filesystem& filesystem, std::string_view base_dir, + int32_t max_num_entries, int32_t average_kv_byte_size, + int32_t max_load_factor_percent) { + const std::string key_mapper_dir = + absl_ports::StrCat(base_dir, "/", kKeyMapperDir); + if (!filesystem.CreateDirectoryRecursively(key_mapper_dir.c_str())) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to create PersistentHashMapKeyMapper directory: ", + key_mapper_dir)); + } + + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create( + filesystem, key_mapper_dir, + PersistentHashMap::Options( + /*value_type_size_in=*/sizeof(T), + /*max_num_entries_in=*/max_num_entries, + /*max_load_factor_percent_in=*/max_load_factor_percent, + /*average_kv_byte_size_in=*/average_kv_byte_size))); + return std::unique_ptr<PersistentHashMapKeyMapper<T, Formatter>>( + new PersistentHashMapKeyMapper<T, Formatter>( + std::move(persistent_hash_map))); +} + +template <typename T, typename Formatter> +/* static */ libtextclassifier3::Status +PersistentHashMapKeyMapper<T, Formatter>::Delete(const Filesystem& filesystem, + std::string_view base_dir) { + const std::string key_mapper_dir 
= + absl_ports::StrCat(base_dir, "/", kKeyMapperDir); + if (!filesystem.DeleteDirectoryRecursively(key_mapper_dir.c_str())) { + return absl_ports::InternalError(absl_ports::StrCat( + "Failed to delete PersistentHashMapKeyMapper directory: ", + key_mapper_dir)); + } + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing + +#endif // ICING_STORE_PERSISTENT_HASH_MAP_KEY_MAPPER_H_ diff --git a/icing/store/persistent-hash-map-key-mapper_test.cc b/icing/store/persistent-hash-map-key-mapper_test.cc new file mode 100644 index 0000000..c937c43 --- /dev/null +++ b/icing/store/persistent-hash-map-key-mapper_test.cc @@ -0,0 +1,52 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/store/persistent-hash-map-key-mapper.h" + +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/filesystem.h" +#include "icing/store/document-id.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +class PersistentHashMapKeyMapperTest : public testing::Test { + protected: + void SetUp() override { base_dir_ = GetTestTempDir() + "/key_mapper"; } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + std::string base_dir_; + Filesystem filesystem_; +}; + +TEST_F(PersistentHashMapKeyMapperTest, InvalidBaseDir) { + EXPECT_THAT( + PersistentHashMapKeyMapper<DocumentId>::Create(filesystem_, "/dev/null"), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index f2738e3..e090800 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -15,7 +15,10 @@ #ifndef ICING_TESTING_COMMON_MATCHERS_H_ #define ICING_TESTING_COMMON_MATCHERS_H_ +#include <algorithm> #include <cmath> +#include <string> +#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/status_macros.h" @@ -31,6 +34,7 @@ #include "icing/proto/status.pb.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/util/status-macros.h" namespace icing { @@ -104,19 +108,73 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, term_frequency_as_expected; } +class ScoredDocumentHitFormatter { + public: + std::string operator()(const ScoredDocumentHit& scored_document_hit) { + return IcingStringUtil::StringPrintf( + "(document_id=%d, hit_section_id_mask=%" PRId64 ", score=%.2f)", + 
scored_document_hit.document_id(), + scored_document_hit.hit_section_id_mask(), scored_document_hit.score()); + } +}; + +class ScoredDocumentHitEqualComparator { + public: + bool operator()(const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) const { + return lhs.document_id() == rhs.document_id() && + lhs.hit_section_id_mask() == rhs.hit_section_id_mask() && + std::fabs(lhs.score() - rhs.score()) < 1e-6; + } +}; + // Used to match a ScoredDocumentHit MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") { - if (arg.document_id() != expected_scored_document_hit.document_id() || - arg.hit_section_id_mask() != - expected_scored_document_hit.hit_section_id_mask() || - std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) { + ScoredDocumentHitEqualComparator equal_comparator; + if (!equal_comparator(arg, expected_scored_document_hit)) { + ScoredDocumentHitFormatter formatter; + *result_listener << "Expected: " << formatter(expected_scored_document_hit) + << ". Actual: " << formatter(arg); + return false; + } + return true; +} + +// Used to match a JoinedScoredDocumentHit +MATCHER_P(EqualsJoinedScoredDocumentHit, expected_joined_scored_document_hit, + "") { + ScoredDocumentHitEqualComparator equal_comparator; + if (std::fabs(arg.final_score() - + expected_joined_scored_document_hit.final_score()) > 1e-6 || + !equal_comparator( + arg.parent_scored_document_hit(), + expected_joined_scored_document_hit.parent_scored_document_hit()) || + arg.child_scored_document_hits().size() != + expected_joined_scored_document_hit.child_scored_document_hits() + .size() || + !std::equal( + arg.child_scored_document_hits().cbegin(), + arg.child_scored_document_hits().cend(), + expected_joined_scored_document_hit.child_scored_document_hits() + .cbegin(), + equal_comparator)) { + ScoredDocumentHitFormatter formatter; + *result_listener << IcingStringUtil::StringPrintf( - "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f. 
Actual: " - "document_id=%d, hit_section_id_mask=%d, score=%.2f", - expected_scored_document_hit.document_id(), - expected_scored_document_hit.hit_section_id_mask(), - expected_scored_document_hit.score(), arg.document_id(), - arg.hit_section_id_mask(), arg.score()); + "Expected: final_score=%.2f, parent_scored_document_hit=%s, " + "child_scored_document_hits=[%s]. Actual: final_score=%.2f, " + "parent_scored_document_hit=%s, child_scored_document_hits=[%s]", + expected_joined_scored_document_hit.final_score(), + formatter( + expected_joined_scored_document_hit.parent_scored_document_hit()) + .c_str(), + absl_ports::StrJoin( + expected_joined_scored_document_hit.child_scored_document_hits(), + ",", formatter) + .c_str(), + arg.final_score(), formatter(arg.parent_scored_document_hit()).c_str(), + absl_ports::StrJoin(arg.child_scored_document_hits(), ",", formatter) + .c_str()); return false; } return true; @@ -435,6 +493,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") { actual_copy.clear_debug_info(); for (SearchResultProto::ResultProto& result : *actual_copy.mutable_results()) { + // Joined results + for (SearchResultProto::ResultProto& joined_result : + *result.mutable_joined_results()) { + joined_result.clear_score(); + } result.clear_score(); } @@ -443,6 +506,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") { expected_copy.clear_debug_info(); for (SearchResultProto::ResultProto& result : *expected_copy.mutable_results()) { + // Joined results + for (SearchResultProto::ResultProto& joined_result : + *result.mutable_joined_results()) { + joined_result.clear_score(); + } result.clear_score(); } return ExplainMatchResult(portable_equals_proto::EqualsProto(expected_copy), diff --git a/icing/tokenization/combined-tokenizer_test.cc b/icing/tokenization/combined-tokenizer_test.cc index 42c7743..8314e91 100644 --- a/icing/tokenization/combined-tokenizer_test.cc +++ b/icing/tokenization/combined-tokenizer_test.cc @@ -142,6 +142,7 
@@ TEST_F(CombinedTokenizerTest, Negation) { EXPECT_THAT(query_terms, ElementsAre("foo", "bar", "baz")); } +// TODO(b/254874614): Handle colon word breaks in ICU 72+ TEST_F(CombinedTokenizerTest, Colons) { const std::string_view kText = ":foo: :bar baz:"; ICING_ASSERT_OK_AND_ASSIGN( @@ -165,6 +166,7 @@ TEST_F(CombinedTokenizerTest, Colons) { EXPECT_THAT(query_terms, ElementsAre("foo", "bar", "baz")); } +// TODO(b/254874614): Handle colon word breaks in ICU 72+ TEST_F(CombinedTokenizerTest, ColonsPropertyRestricts) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Tokenizer> indexing_tokenizer, @@ -176,33 +178,61 @@ TEST_F(CombinedTokenizerTest, ColonsPropertyRestricts) { CreateQueryTokenizer(tokenizer_factory::QueryTokenizerType::RAW_QUERY, lang_segmenter_.get())); - // This is a difference between the two tokenizers. "foo:bar" is a single - // token to the plain tokenizer because ':' is a word connector. But "foo:bar" - // is a property restrict to the query tokenizer - so "foo" is the property - // and "bar" is the only text term. - constexpr std::string_view kText = "foo:bar"; - ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens, - indexing_tokenizer->TokenizeAll(kText)); - std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens); - EXPECT_THAT(indexing_terms, ElementsAre("foo:bar")); - - ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens, - query_tokenizer->TokenizeAll(kText)); - std::vector<std::string> query_terms = GetTokenTerms(query_tokens); - EXPECT_THAT(query_terms, ElementsAre("bar")); - - // This difference, however, should only apply to the first ':'. A - // second ':' should be treated by both tokenizers as a word connector. 
- constexpr std::string_view kText2 = "foo:bar:baz"; - ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens, - indexing_tokenizer->TokenizeAll(kText2)); - indexing_terms = GetTokenTerms(indexing_tokens); - EXPECT_THAT(indexing_terms, ElementsAre("foo:bar:baz")); - - ICING_ASSERT_OK_AND_ASSIGN(query_tokens, - query_tokenizer->TokenizeAll(kText2)); - query_terms = GetTokenTerms(query_tokens); - EXPECT_THAT(query_terms, ElementsAre("bar:baz")); + if (IsIcu72PlusTokenization()) { + // In ICU 72+ and above, ':' are no longer considered word connectors. The + // query tokenizer should still consider them to be property restricts. + constexpr std::string_view kText = "foo:bar"; + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens, + indexing_tokenizer->TokenizeAll(kText)); + std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens); + EXPECT_THAT(indexing_terms, ElementsAre("foo", "bar")); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens, + query_tokenizer->TokenizeAll(kText)); + std::vector<std::string> query_terms = GetTokenTerms(query_tokens); + EXPECT_THAT(query_terms, ElementsAre("bar")); + + // This difference, however, should only apply to the first ':'. Both should + // consider a second ':' to be a word break. + constexpr std::string_view kText2 = "foo:bar:baz"; + ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens, + indexing_tokenizer->TokenizeAll(kText2)); + indexing_terms = GetTokenTerms(indexing_tokens); + EXPECT_THAT(indexing_terms, ElementsAre("foo", "bar", "baz")); + + ICING_ASSERT_OK_AND_ASSIGN(query_tokens, + query_tokenizer->TokenizeAll(kText2)); + query_terms = GetTokenTerms(query_tokens); + EXPECT_THAT(query_terms, ElementsAre("bar", "baz")); + } else { + // This is a difference between the two tokenizers. "foo:bar" is a single + // token to the plain tokenizer because ':' is a word connector. But + // "foo:bar" is a property restrict to the query tokenizer - so "foo" is the + // property and "bar" is the only text term. 
+ constexpr std::string_view kText = "foo:bar"; + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> indexing_tokens, + indexing_tokenizer->TokenizeAll(kText)); + std::vector<std::string> indexing_terms = GetTokenTerms(indexing_tokens); + EXPECT_THAT(indexing_terms, ElementsAre("foo:bar")); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<Token> query_tokens, + query_tokenizer->TokenizeAll(kText)); + std::vector<std::string> query_terms = GetTokenTerms(query_tokens); + EXPECT_THAT(query_terms, ElementsAre("bar")); + + // This difference, however, should only apply to the first ':'. A + // second ':' should be treated by both tokenizers as a word connector. + constexpr std::string_view kText2 = "foo:bar:baz"; + ICING_ASSERT_OK_AND_ASSIGN(indexing_tokens, + indexing_tokenizer->TokenizeAll(kText2)); + indexing_terms = GetTokenTerms(indexing_tokens); + EXPECT_THAT(indexing_terms, ElementsAre("foo:bar:baz")); + + ICING_ASSERT_OK_AND_ASSIGN(query_tokens, + query_tokenizer->TokenizeAll(kText2)); + query_terms = GetTokenTerms(query_tokens); + EXPECT_THAT(query_terms, ElementsAre("bar:baz")); + } } TEST_F(CombinedTokenizerTest, Punctuation) { diff --git a/icing/tokenization/icu/icu-language-segmenter_test.cc b/icing/tokenization/icu/icu-language-segmenter_test.cc index 71e04e2..6771050 100644 --- a/icing/tokenization/icu/icu-language-segmenter_test.cc +++ b/icing/tokenization/icu/icu-language-segmenter_test.cc @@ -21,6 +21,7 @@ #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" #include "icing/jni/jni-cache.h" +#include "icing/portable/platform.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/icu-i18n-test-utils.h" @@ -118,6 +119,9 @@ class IcuLanguageSegmenterAllLocalesTest : public testing::TestWithParam<const char*> { protected: void SetUp() override { + if (!IsIcuTokenization()) { + GTEST_SKIP() << "ICU tokenization not enabled!"; + } ICING_ASSERT_OK( // File generated via icu_data_file rule in 
//icing/BUILD. icu_data_file_helper::SetUpICUDataFile( @@ -223,46 +227,42 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) { // Word connecters EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android"), IsOkAndHolds(ElementsAre("com.google.android"))); - EXPECT_THAT(language_segmenter->GetAllTerms("com:google:android"), - IsOkAndHolds(ElementsAre("com:google:android"))); EXPECT_THAT(language_segmenter->GetAllTerms("com'google'android"), IsOkAndHolds(ElementsAre("com'google'android"))); EXPECT_THAT(language_segmenter->GetAllTerms("com_google_android"), IsOkAndHolds(ElementsAre("com_google_android"))); // Word connecters can be mixed - EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android:icing"), - IsOkAndHolds(ElementsAre("com.google.android:icing"))); + EXPECT_THAT(language_segmenter->GetAllTerms("com.google.android_icing"), + IsOkAndHolds(ElementsAre("com.google.android_icing"))); // Connectors that don't have valid terms on both sides of it are not // considered connectors. - EXPECT_THAT(language_segmenter->GetAllTerms(":bar:baz"), - IsOkAndHolds(ElementsAre(":", "bar:baz"))); + EXPECT_THAT(language_segmenter->GetAllTerms("'bar'baz"), + IsOkAndHolds(ElementsAre("'", "bar'baz"))); - EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz:"), - IsOkAndHolds(ElementsAre("bar:baz", ":"))); + EXPECT_THAT(language_segmenter->GetAllTerms("bar.baz."), + IsOkAndHolds(ElementsAre("bar.baz", "."))); // Connectors that don't have valid terms on both sides of it are not // considered connectors. 
- EXPECT_THAT(language_segmenter->GetAllTerms(" :bar:baz"), - IsOkAndHolds(ElementsAre(" ", ":", "bar:baz"))); + EXPECT_THAT(language_segmenter->GetAllTerms(" .bar.baz"), + IsOkAndHolds(ElementsAre(" ", ".", "bar.baz"))); - EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz: "), - IsOkAndHolds(ElementsAre("bar:baz", ":", " "))); + EXPECT_THAT(language_segmenter->GetAllTerms("bar'baz' "), + IsOkAndHolds(ElementsAre("bar'baz", "'", " "))); // Connectors don't connect if one side is an invalid term (?) - EXPECT_THAT(language_segmenter->GetAllTerms("bar:baz:?"), - IsOkAndHolds(ElementsAre("bar:baz", ":", "?"))); - EXPECT_THAT(language_segmenter->GetAllTerms("?:bar:baz"), - IsOkAndHolds(ElementsAre("?", ":", "bar:baz"))); - EXPECT_THAT(language_segmenter->GetAllTerms("3:14"), - IsOkAndHolds(ElementsAre("3", ":", "14"))); - EXPECT_THAT(language_segmenter->GetAllTerms("私:は"), - IsOkAndHolds(ElementsAre("私", ":", "は"))); - EXPECT_THAT(language_segmenter->GetAllTerms("我:每"), - IsOkAndHolds(ElementsAre("我", ":", "每"))); - EXPECT_THAT(language_segmenter->GetAllTerms("เดิน:ไป"), - IsOkAndHolds(ElementsAre("เดิน:ไป"))); + EXPECT_THAT(language_segmenter->GetAllTerms("bar.baz.?"), + IsOkAndHolds(ElementsAre("bar.baz", ".", "?"))); + EXPECT_THAT(language_segmenter->GetAllTerms("?'bar'baz"), + IsOkAndHolds(ElementsAre("?", "'", "bar'baz"))); + EXPECT_THAT(language_segmenter->GetAllTerms("私'は"), + IsOkAndHolds(ElementsAre("私", "'", "は"))); + EXPECT_THAT(language_segmenter->GetAllTerms("我.每"), + IsOkAndHolds(ElementsAre("我", ".", "每"))); + EXPECT_THAT(language_segmenter->GetAllTerms("เดิน'ไป"), + IsOkAndHolds(ElementsAre("เดิน'ไป"))); // Any heading and trailing characters are not connecters EXPECT_THAT(language_segmenter->GetAllTerms(".com.google.android."), @@ -277,8 +277,6 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) { IsOkAndHolds(ElementsAre("com", "+", "google", "+", "android"))); EXPECT_THAT(language_segmenter->GetAllTerms("com*google*android"), 
IsOkAndHolds(ElementsAre("com", "*", "google", "*", "android"))); - EXPECT_THAT(language_segmenter->GetAllTerms("com@google@android"), - IsOkAndHolds(ElementsAre("com", "@", "google", "@", "android"))); EXPECT_THAT(language_segmenter->GetAllTerms("com^google^android"), IsOkAndHolds(ElementsAre("com", "^", "google", "^", "android"))); EXPECT_THAT(language_segmenter->GetAllTerms("com&google&android"), @@ -292,6 +290,29 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, WordConnector) { EXPECT_THAT( language_segmenter->GetAllTerms("com\"google\"android"), IsOkAndHolds(ElementsAre("com", "\"", "google", "\"", "android"))); + + // In ICU 72, there were a few changes: + // 1. ':' stopped being a word connector + // 2. '@' became a word connector + // 3. <numeric><word-connector><numeric> such as "3'14" is now considered as + // a single token. + if (IsIcu72PlusTokenization()) { + EXPECT_THAT( + language_segmenter->GetAllTerms("com:google:android"), + IsOkAndHolds(ElementsAre("com", ":", "google", ":", "android"))); + EXPECT_THAT(language_segmenter->GetAllTerms("com@google@android"), + IsOkAndHolds(ElementsAre("com@google@android"))); + EXPECT_THAT(language_segmenter->GetAllTerms("3'14"), + IsOkAndHolds(ElementsAre("3'14"))); + } else { + EXPECT_THAT(language_segmenter->GetAllTerms("com:google:android"), + IsOkAndHolds(ElementsAre("com:google:android"))); + EXPECT_THAT( + language_segmenter->GetAllTerms("com@google@android"), + IsOkAndHolds(ElementsAre("com", "@", "google", "@", "android"))); + EXPECT_THAT(language_segmenter->GetAllTerms("3'14"), + IsOkAndHolds(ElementsAre("3", "'", "14"))); + } } TEST_P(IcuLanguageSegmenterAllLocalesTest, Apostrophes) { @@ -494,17 +515,17 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToStartUtf32WordConnector) { ICING_ASSERT_OK_AND_ASSIGN( auto segmenter, language_segmenter_factory::Create( GetSegmenterOptions(GetLocale(), jni_cache_.get()))); - constexpr std::string_view kText = "com:google:android is package"; + constexpr 
std::string_view kText = "com.google.android is package"; ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr, segmenter->Segment(kText)); - // String: "com:google:android is package" + // String: "com.google.android is package" // ^ ^^ ^^ // UTF-8 idx: 0 18 19 21 22 // UTF-32 idx: 0 18 19 21 22 auto position_or = itr->ResetToStartUtf32(); EXPECT_THAT(position_or, IsOk()); - ASSERT_THAT(itr->GetTerm(), Eq("com:google:android")); + ASSERT_THAT(itr->GetTerm(), Eq("com.google.android")); } TEST_P(IcuLanguageSegmenterAllLocalesTest, NewIteratorResetToStartUtf32) { @@ -585,11 +606,11 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32WordConnector) { ICING_ASSERT_OK_AND_ASSIGN( auto segmenter, language_segmenter_factory::Create( GetSegmenterOptions(GetLocale(), jni_cache_.get()))); - constexpr std::string_view kText = "package com:google:android name"; + constexpr std::string_view kText = "package com.google.android name"; ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr, segmenter->Segment(kText)); - // String: "package com:google:android name" + // String: "package com.google.android name" // ^ ^^ ^^ // UTF-8 idx: 0 7 8 26 27 // UTF-32 idx: 0 7 8 26 27 @@ -601,7 +622,7 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32WordConnector) { position_or = itr->ResetToTermStartingAfterUtf32(7); EXPECT_THAT(position_or, IsOk()); EXPECT_THAT(position_or.ValueOrDie(), Eq(8)); - ASSERT_THAT(itr->GetTerm(), Eq("com:google:android")); + ASSERT_THAT(itr->GetTerm(), Eq("com.google.android")); } TEST_P(IcuLanguageSegmenterAllLocalesTest, ResetToTermAfterUtf32OutOfBounds) { @@ -961,18 +982,18 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, ICING_ASSERT_OK_AND_ASSIGN( auto segmenter, language_segmenter_factory::Create( GetSegmenterOptions(GetLocale(), jni_cache_.get()))); - constexpr std::string_view kText = "package name com:google:android!"; + constexpr std::string_view kText = "package name 
com.google.android!"; ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<LanguageSegmenter::Iterator> itr, segmenter->Segment(kText)); - // String: "package name com:google:android!" + // String: "package name com.google.android!" // ^ ^^ ^^ ^ // UTF-8 idx: 0 7 8 12 13 31 // UTF-32 idx: 0 7 8 12 13 31 auto position_or = itr->ResetToTermEndingBeforeUtf32(31); EXPECT_THAT(position_or, IsOk()); EXPECT_THAT(position_or.ValueOrDie(), Eq(13)); - ASSERT_THAT(itr->GetTerm(), Eq("com:google:android")); + ASSERT_THAT(itr->GetTerm(), Eq("com.google.android")); position_or = itr->ResetToTermEndingBeforeUtf32(21); EXPECT_THAT(position_or, IsOk()); diff --git a/icing/tokenization/raw-query-tokenizer_test.cc b/icing/tokenization/raw-query-tokenizer_test.cc index 2af4f18..a00f2f7 100644 --- a/icing/tokenization/raw-query-tokenizer_test.cc +++ b/icing/tokenization/raw-query-tokenizer_test.cc @@ -225,7 +225,7 @@ TEST_F(RawQueryTokenizerTest, Parentheses) { HasSubstr("Too many right parentheses"))); } -TEST_F(RawQueryTokenizerTest, Exclustion) { +TEST_F(RawQueryTokenizerTest, Exclusion) { language_segmenter_factory::SegmenterOptions options(ULOC_US); ICING_ASSERT_OK_AND_ASSIGN( auto language_segmenter, @@ -344,13 +344,23 @@ TEST_F(RawQueryTokenizerTest, PropertyRestriction) { EqualsToken(Token::Type::QUERY_PROPERTY, "email.title"), EqualsToken(Token::Type::REGULAR, "hello")))); - // The first colon ":" triggers property restriction, the second colon is used - // as a word connector per ICU's rule - // (https://unicode.org/reports/tr29/#Word_Boundaries). - EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"), - IsOkAndHolds(ElementsAre( - EqualsToken(Token::Type::QUERY_PROPERTY, "property"), - EqualsToken(Token::Type::REGULAR, "foo:bar")))); + // The first colon ":" triggers property restriction. Pre ICU 72, ':' was + // considered a word connector, so the second ':' will be interepreted as a + // connector pre-ICU 72. 
For ICU 72 and above, it's no longer considered a + // connector. + // TODO(b/254874614): Handle colon word breaks in ICU 72+ + if (IsIcu72PlusTokenization()) { + EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::QUERY_PROPERTY, "property"), + EqualsToken(Token::Type::REGULAR, "foo"), + EqualsToken(Token::Type::REGULAR, "bar")))); + } else { + EXPECT_THAT(raw_query_tokenizer->TokenizeAll("property:foo:bar"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::QUERY_PROPERTY, "property"), + EqualsToken(Token::Type::REGULAR, "foo:bar")))); + } // Property restriction only applies to the term right after it. // Note: "term1:term2" is not a term but 2 terms because word connectors diff --git a/icing/testing/snippet-helpers.cc b/icing/util/snippet-helpers.cc index 3105073..6d6277f 100644 --- a/icing/testing/snippet-helpers.cc +++ b/icing/util/snippet-helpers.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "icing/testing/snippet-helpers.h" +#include "icing/util/snippet-helpers.h" #include <algorithm> #include <string_view> diff --git a/icing/testing/snippet-helpers.h b/icing/util/snippet-helpers.h index 73b2ce2..73b2ce2 100644 --- a/icing/testing/snippet-helpers.h +++ b/icing/util/snippet-helpers.h diff --git a/icing/util/tokenized-document.cc b/icing/util/tokenized-document.cc index e741987..facb267 100644 --- a/icing/util/tokenized-document.cc +++ b/icing/util/tokenized-document.cc @@ -31,29 +31,14 @@ namespace icing { namespace lib { -libtextclassifier3::StatusOr<TokenizedDocument> TokenizedDocument::Create( - const SchemaStore* schema_store, - const LanguageSegmenter* language_segmenter, DocumentProto document) { - TokenizedDocument tokenized_document(std::move(document)); - ICING_RETURN_IF_ERROR( - tokenized_document.Tokenize(schema_store, language_segmenter)); - return tokenized_document; -} +namespace { -TokenizedDocument::TokenizedDocument(DocumentProto document) - : document_(std::move(document)) {} - -libtextclassifier3::Status TokenizedDocument::Tokenize( +libtextclassifier3::StatusOr<std::vector<TokenizedSection>> Tokenize( const SchemaStore* schema_store, - const LanguageSegmenter* language_segmenter) { - DocumentValidator validator(schema_store); - ICING_RETURN_IF_ERROR(validator.Validate(document_)); - - ICING_ASSIGN_OR_RETURN(SectionGroup section_group, - schema_store->ExtractSections(document_)); - // string sections - for (const Section<std::string_view>& section : - section_group.string_sections) { + const LanguageSegmenter* language_segmenter, + const std::vector<Section<std::string_view>>& string_sections) { + std::vector<TokenizedSection> tokenized_string_sections; + for (const Section<std::string_view>& section : string_sections) { ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer> tokenizer, tokenizer_factory::CreateIndexingTokenizer( section.metadata.tokenizer, language_segmenter)); @@ -68,11 +53,34 @@ libtextclassifier3::Status 
TokenizedDocument::Tokenize( } } } - tokenized_sections_.emplace_back(SectionMetadata(section.metadata), - std::move(token_sequence)); + tokenized_string_sections.emplace_back(SectionMetadata(section.metadata), + std::move(token_sequence)); } - return libtextclassifier3::Status::OK; + return tokenized_string_sections; +} + +} // namespace + +/* static */ libtextclassifier3::StatusOr<TokenizedDocument> +TokenizedDocument::Create(const SchemaStore* schema_store, + const LanguageSegmenter* language_segmenter, + DocumentProto document) { + DocumentValidator validator(schema_store); + ICING_RETURN_IF_ERROR(validator.Validate(document)); + + ICING_ASSIGN_OR_RETURN(SectionGroup section_group, + schema_store->ExtractSections(document)); + + // Tokenize string sections + ICING_ASSIGN_OR_RETURN( + std::vector<TokenizedSection> tokenized_string_sections, + Tokenize(schema_store, language_segmenter, + section_group.string_sections)); + + return TokenizedDocument(std::move(document), + std::move(tokenized_string_sections), + std::move(section_group.integer_sections)); } } // namespace lib diff --git a/icing/util/tokenized-document.h b/icing/util/tokenized-document.h index 5d996d9..5729df2 100644 --- a/icing/util/tokenized-document.h +++ b/icing/util/tokenized-document.h @@ -46,28 +46,35 @@ class TokenizedDocument { const DocumentProto& document() const { return document_; } - int32_t num_tokens() const { - int32_t num_tokens = 0; - for (const TokenizedSection& section : tokenized_sections_) { - num_tokens += section.token_sequence.size(); + int32_t num_string_tokens() const { + int32_t num_string_tokens = 0; + for (const TokenizedSection& section : tokenized_string_sections_) { + num_string_tokens += section.token_sequence.size(); } - return num_tokens; + return num_string_tokens; } - const std::vector<TokenizedSection>& sections() const { - return tokenized_sections_; + const std::vector<TokenizedSection>& tokenized_string_sections() const { + return 
tokenized_string_sections_; + } + + const std::vector<Section<int64_t>>& integer_sections() const { + return integer_sections_; } private: // Use TokenizedDocument::Create() to instantiate. - explicit TokenizedDocument(DocumentProto document); + explicit TokenizedDocument( + DocumentProto&& document, + std::vector<TokenizedSection>&& tokenized_string_sections, + std::vector<Section<int64_t>>&& integer_sections) + : document_(std::move(document)), + tokenized_string_sections_(std::move(tokenized_string_sections)), + integer_sections_(std::move(integer_sections)) {} DocumentProto document_; - std::vector<TokenizedSection> tokenized_sections_; - - libtextclassifier3::Status Tokenize( - const SchemaStore* schema_store, - const LanguageSegmenter* language_segmenter); + std::vector<TokenizedSection> tokenized_string_sections_; + std::vector<Section<int64_t>> integer_sections_; }; } // namespace lib diff --git a/icing/util/tokenized-document_test.cc b/icing/util/tokenized-document_test.cc new file mode 100644 index 0000000..3497bef --- /dev/null +++ b/icing/util/tokenized-document_test.cc @@ -0,0 +1,335 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/util/tokenized-document.h" + +#include <memory> +#include <string> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/portable/platform.h" +#include "icing/proto/document.pb.h" +#include "icing/proto/schema.pb.h" +#include "icing/proto/term.pb.h" +#include "icing/schema-builder.h" +#include "icing/schema/schema-store.h" +#include "icing/schema/section.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/tokenization/language-segmenter.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::EqualsProto; +using ::testing::IsEmpty; +using ::testing::SizeIs; + +// schema types +constexpr std::string_view kFakeType = "FakeType"; + +// Indexable properties and section Id. Section Id is determined by the +// lexicographical order of indexable property path. 
+constexpr std::string_view kIndexableIntegerProperty1 = "indexableInteger1"; +constexpr std::string_view kIndexableIntegerProperty2 = "indexableInteger2"; +constexpr std::string_view kStringExactProperty = "stringExact"; +constexpr std::string_view kStringPrefixProperty = "stringPrefix"; + +constexpr SectionId kIndexableInteger1SectionId = 0; +constexpr SectionId kIndexableInteger2SectionId = 1; +constexpr SectionId kStringExactSectionId = 2; +constexpr SectionId kStringPrefixSectionId = 3; + +const SectionMetadata kIndexableInteger1SectionMetadata( + kIndexableInteger1SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty1)); + +const SectionMetadata kIndexableInteger2SectionMetadata( + kIndexableInteger2SectionId, TYPE_INT64, TOKENIZER_NONE, TERM_MATCH_UNKNOWN, + NUMERIC_MATCH_RANGE, std::string(kIndexableIntegerProperty2)); + +const SectionMetadata kStringExactSectionMetadata( + kStringExactSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_EXACT, + NUMERIC_MATCH_UNKNOWN, std::string(kStringExactProperty)); + +const SectionMetadata kStringPrefixSectionMetadata( + kStringPrefixSectionId, TYPE_STRING, TOKENIZER_PLAIN, TERM_MATCH_PREFIX, + NUMERIC_MATCH_UNKNOWN, std::string(kStringPrefixProperty)); + +// Other non-indexable properties. +constexpr std::string_view kUnindexedStringProperty = "unindexedString"; +constexpr std::string_view kUnindexedIntegerProperty = "unindexedInteger"; + +class TokenizedDocumentTest : public ::testing::Test { + protected: + void SetUp() override { + test_dir_ = GetTestTempDir() + "/icing"; + schema_store_dir_ = test_dir_ + "/schema_store"; + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); + + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + lang_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + + SchemaProto schema = + SchemaBuilder() + .AddType( + SchemaTypeConfigBuilder() + .SetType(kFakeType) + .AddProperty(PropertyConfigBuilder() + .SetName(kUnindexedStringProperty) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kUnindexedIntegerProperty) + .SetDataType(TYPE_INT64) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kIndexableIntegerProperty1) + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName(kIndexableIntegerProperty2) + .SetDataTypeInt64(NUMERIC_MATCH_RANGE) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName(kStringExactProperty) + .SetDataTypeString(TERM_MATCH_EXACT, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_REPEATED)) + .AddProperty(PropertyConfigBuilder() + .SetName(kStringPrefixProperty) + .SetDataTypeString(TERM_MATCH_PREFIX, + TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + ICING_ASSERT_OK(schema_store_->SetSchema(schema)); + } + + void TearDown() override { + schema_store_.reset(); + + // Check that the schema store directory is the *only* directory in the + // schema_store_dir_. IOW, ensure that all temporary directories have been + // properly cleaned up. + std::vector<std::string> sub_dirs; + ASSERT_TRUE(filesystem_.ListDirectory(test_dir_.c_str(), &sub_dirs)); + ASSERT_THAT(sub_dirs, ElementsAre("schema_store")); + + // Finally, clean everything up. 
+ ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(test_dir_.c_str())); + } + + Filesystem filesystem_; + FakeClock fake_clock_; + std::string test_dir_; + std::string schema_store_dir_; + std::unique_ptr<LanguageSegmenter> lang_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; +}; + +TEST_F(TokenizedDocumentTest, CreateAll) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kUnindexedStringProperty), + "hello world unindexed") + .AddStringProperty(std::string(kStringExactProperty), "test foo", + "test bar", "test baz") + .AddStringProperty(std::string(kStringPrefixProperty), "foo bar baz") + .AddInt64Property(std::string(kUnindexedIntegerProperty), 789) + .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3) + .AddInt64Property(std::string(kIndexableIntegerProperty2), 456) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(9)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.tokenized_string_sections().at(0).metadata, + Eq(kStringExactSectionMetadata)); + EXPECT_THAT( + tokenized_document.tokenized_string_sections().at(0).token_sequence, + ElementsAre("test", "foo", "test", "bar", "test", "baz")); + EXPECT_THAT(tokenized_document.tokenized_string_sections().at(1).metadata, + Eq(kStringPrefixSectionMetadata)); + EXPECT_THAT( + tokenized_document.tokenized_string_sections().at(1).token_sequence, + ElementsAre("foo", "bar", "baz")); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.integer_sections().at(0).metadata, + 
Eq(kIndexableInteger1SectionMetadata)); + EXPECT_THAT(tokenized_document.integer_sections().at(0).content, + ElementsAre(1, 2, 3)); + EXPECT_THAT(tokenized_document.integer_sections().at(1).metadata, + Eq(kIndexableInteger2SectionMetadata)); + EXPECT_THAT(tokenized_document.integer_sections().at(1).content, + ElementsAre(456)); +} + +TEST_F(TokenizedDocumentTest, CreateNoIndexableIntegerProperties) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddInt64Property(std::string(kUnindexedIntegerProperty), 789) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); +} + +TEST_F(TokenizedDocumentTest, CreateMultipleIndexableIntegerProperties) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddInt64Property(std::string(kUnindexedIntegerProperty), 789) + .AddInt64Property(std::string(kIndexableIntegerProperty1), 1, 2, 3) + .AddInt64Property(std::string(kIndexableIntegerProperty2), 456) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), SizeIs(2)); + 
EXPECT_THAT(tokenized_document.integer_sections().at(0).metadata, + Eq(kIndexableInteger1SectionMetadata)); + EXPECT_THAT(tokenized_document.integer_sections().at(0).content, + ElementsAre(1, 2, 3)); + EXPECT_THAT(tokenized_document.integer_sections().at(1).metadata, + Eq(kIndexableInteger2SectionMetadata)); + EXPECT_THAT(tokenized_document.integer_sections().at(1).content, + ElementsAre(456)); +} + +TEST_F(TokenizedDocumentTest, CreateNoIndexableStringProperties) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kUnindexedStringProperty), + "hello world unindexed") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(0)); + + // string sections + EXPECT_THAT(tokenized_document.tokenized_string_sections(), IsEmpty()); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); +} + +TEST_F(TokenizedDocumentTest, CreateMultipleIndexableStringProperties) { + DocumentProto document = + DocumentBuilder() + .SetKey("icing", "fake_type/1") + .SetSchema(std::string(kFakeType)) + .AddStringProperty(std::string(kUnindexedStringProperty), + "hello world unindexed") + .AddStringProperty(std::string(kStringExactProperty), "test foo", + "test bar", "test baz") + .AddStringProperty(std::string(kStringPrefixProperty), "foo bar baz") + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + TokenizedDocument tokenized_document, + TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), + document)); + + EXPECT_THAT(tokenized_document.document(), EqualsProto(document)); + EXPECT_THAT(tokenized_document.num_string_tokens(), Eq(9)); + + // string sections + 
EXPECT_THAT(tokenized_document.tokenized_string_sections(), SizeIs(2)); + EXPECT_THAT(tokenized_document.tokenized_string_sections().at(0).metadata, + Eq(kStringExactSectionMetadata)); + EXPECT_THAT( + tokenized_document.tokenized_string_sections().at(0).token_sequence, + ElementsAre("test", "foo", "test", "bar", "test", "baz")); + EXPECT_THAT(tokenized_document.tokenized_string_sections().at(1).metadata, + Eq(kStringPrefixSectionMetadata)); + EXPECT_THAT( + tokenized_document.tokenized_string_sections().at(1).token_sequence, + ElementsAre("foo", "bar", "baz")); + + // integer sections + EXPECT_THAT(tokenized_document.integer_sections(), IsEmpty()); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 81223f2..47b94a5 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -14,7 +14,6 @@ package com.google.android.icing; -import android.util.Log; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import com.google.android.icing.proto.DebugInfoResultProto; @@ -45,17 +44,15 @@ import com.google.android.icing.proto.ScoringSpecProto; import com.google.android.icing.proto.SearchResultProto; import com.google.android.icing.proto.SearchSpecProto; import com.google.android.icing.proto.SetSchemaResultProto; -import com.google.android.icing.proto.StatusProto; import com.google.android.icing.proto.StorageInfoResultProto; import com.google.android.icing.proto.SuggestionResponse; import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.UsageReport; -import com.google.protobuf.ExtensionRegistryLite; -import com.google.protobuf.InvalidProtocolBufferException; -import java.io.Closeable; /** - * Java wrapper to access native APIs in external/icing/icing/icing-search-engine.h + * Java wrapper to 
access {@link IcingSearchEngineImpl}. + * + * <p>It converts byte array from {@link IcingSearchEngineImpl} to corresponding protos. * * <p>If this instance has been closed, the instance is no longer usable. * @@ -63,574 +60,197 @@ import java.io.Closeable; * * <p>NOTE: This class is NOT thread-safe. */ -public class IcingSearchEngine implements Closeable { +public class IcingSearchEngine implements IcingSearchEngineInterface { private static final String TAG = "IcingSearchEngine"; - private static final ExtensionRegistryLite EXTENSION_REGISTRY_LITE = - ExtensionRegistryLite.getEmptyRegistry(); - - private long nativePointer; - - private boolean closed = false; - - static { - // NOTE: This can fail with an UnsatisfiedLinkError - System.loadLibrary("icing"); - } + private final IcingSearchEngineImpl icingSearchEngineImpl; /** * @throws IllegalStateException if IcingSearchEngine fails to be created */ public IcingSearchEngine(@NonNull IcingSearchEngineOptions options) { - nativePointer = nativeCreate(options.toByteArray()); - if (nativePointer == 0) { - Log.e(TAG, "Failed to create IcingSearchEngine."); - throw new IllegalStateException("Failed to create IcingSearchEngine."); - } - } - - private void throwIfClosed() { - if (closed) { - throw new IllegalStateException("Trying to use a closed IcingSearchEngine instance."); - } + icingSearchEngineImpl = new IcingSearchEngineImpl(options.toByteArray()); } @Override public void close() { - if (closed) { - return; - } - - if (nativePointer != 0) { - nativeDestroy(this); - } - nativePointer = 0; - closed = true; + icingSearchEngineImpl.close(); } @Override protected void finalize() throws Throwable { - close(); + icingSearchEngineImpl.close(); super.finalize(); } @NonNull + @Override public InitializeResultProto initialize() { - throwIfClosed(); - - byte[] initializeResultBytes = nativeInitialize(this); - if (initializeResultBytes == null) { - Log.e(TAG, "Received null InitializeResult from native."); - return 
InitializeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return InitializeResultProto.parseFrom(initializeResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing InitializeResultProto.", e); - return InitializeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToInitializeResultProto( + icingSearchEngineImpl.initialize()); } @NonNull + @Override public SetSchemaResultProto setSchema(@NonNull SchemaProto schema) { return setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false); } @NonNull + @Override public SetSchemaResultProto setSchema( @NonNull SchemaProto schema, boolean ignoreErrorsAndDeleteDocuments) { - throwIfClosed(); - - byte[] setSchemaResultBytes = - nativeSetSchema(this, schema.toByteArray(), ignoreErrorsAndDeleteDocuments); - if (setSchemaResultBytes == null) { - Log.e(TAG, "Received null SetSchemaResultProto from native."); - return SetSchemaResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return SetSchemaResultProto.parseFrom(setSchemaResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing SetSchemaResultProto.", e); - return SetSchemaResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToSetSchemaResultProto( + icingSearchEngineImpl.setSchema(schema.toByteArray(), ignoreErrorsAndDeleteDocuments)); } @NonNull + @Override public GetSchemaResultProto getSchema() { - throwIfClosed(); - - byte[] getSchemaResultBytes = nativeGetSchema(this); - if (getSchemaResultBytes == null) { - Log.e(TAG, "Received null GetSchemaResultProto from native."); - return 
GetSchemaResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return GetSchemaResultProto.parseFrom(getSchemaResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetSchemaResultProto.", e); - return GetSchemaResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToGetSchemaResultProto( + icingSearchEngineImpl.getSchema()); } @NonNull + @Override public GetSchemaTypeResultProto getSchemaType(@NonNull String schemaType) { - throwIfClosed(); - - byte[] getSchemaTypeResultBytes = nativeGetSchemaType(this, schemaType); - if (getSchemaTypeResultBytes == null) { - Log.e(TAG, "Received null GetSchemaTypeResultProto from native."); - return GetSchemaTypeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return GetSchemaTypeResultProto.parseFrom(getSchemaTypeResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetSchemaTypeResultProto.", e); - return GetSchemaTypeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToGetSchemaTypeResultProto( + icingSearchEngineImpl.getSchemaType(schemaType)); } @NonNull + @Override public PutResultProto put(@NonNull DocumentProto document) { - throwIfClosed(); - - byte[] putResultBytes = nativePut(this, document.toByteArray()); - if (putResultBytes == null) { - Log.e(TAG, "Received null PutResultProto from native."); - return PutResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return PutResultProto.parseFrom(putResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - 
Log.e(TAG, "Error parsing PutResultProto.", e); - return PutResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToPutResultProto( + icingSearchEngineImpl.put(document.toByteArray())); } @NonNull + @Override public GetResultProto get( @NonNull String namespace, @NonNull String uri, @NonNull GetResultSpecProto getResultSpec) { - throwIfClosed(); - - byte[] getResultBytes = nativeGet(this, namespace, uri, getResultSpec.toByteArray()); - if (getResultBytes == null) { - Log.e(TAG, "Received null GetResultProto from native."); - return GetResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return GetResultProto.parseFrom(getResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetResultProto.", e); - return GetResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToGetResultProto( + icingSearchEngineImpl.get(namespace, uri, getResultSpec.toByteArray())); } @NonNull + @Override public ReportUsageResultProto reportUsage(@NonNull UsageReport usageReport) { - throwIfClosed(); - - byte[] reportUsageResultBytes = nativeReportUsage(this, usageReport.toByteArray()); - if (reportUsageResultBytes == null) { - Log.e(TAG, "Received null ReportUsageResultProto from native."); - return ReportUsageResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return ReportUsageResultProto.parseFrom(reportUsageResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing ReportUsageResultProto.", e); - return ReportUsageResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + 
return IcingSearchEngineUtils.byteArrayToReportUsageResultProto( + icingSearchEngineImpl.reportUsage(usageReport.toByteArray())); } @NonNull + @Override public GetAllNamespacesResultProto getAllNamespaces() { - throwIfClosed(); - - byte[] getAllNamespacesResultBytes = nativeGetAllNamespaces(this); - if (getAllNamespacesResultBytes == null) { - Log.e(TAG, "Received null GetAllNamespacesResultProto from native."); - return GetAllNamespacesResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return GetAllNamespacesResultProto.parseFrom( - getAllNamespacesResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetAllNamespacesResultProto.", e); - return GetAllNamespacesResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToGetAllNamespacesResultProto( + icingSearchEngineImpl.getAllNamespaces()); } @NonNull + @Override public SearchResultProto search( @NonNull SearchSpecProto searchSpec, @NonNull ScoringSpecProto scoringSpec, @NonNull ResultSpecProto resultSpec) { - throwIfClosed(); - - // Note that on Android System.currentTimeMillis() is the standard "wall" clock and can be set - // by the user or the phone network so the time may jump backwards or forwards unpredictably. - // This could lead to inaccurate final JNI latency calculations or unexpected negative numbers - // in the case where the phone time is changed while sending data across JNI layers. - // However these occurrences should be very rare, so we will keep usage of - // System.currentTimeMillis() due to the lack of better time functions that can provide a - // consistent timestamp across all platforms. 
- long javaToNativeStartTimestampMs = System.currentTimeMillis(); - byte[] searchResultBytes = - nativeSearch( - this, - searchSpec.toByteArray(), - scoringSpec.toByteArray(), - resultSpec.toByteArray(), - javaToNativeStartTimestampMs); - if (searchResultBytes == null) { - Log.e(TAG, "Received null SearchResultProto from native."); - return SearchResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - SearchResultProto.Builder searchResultProtoBuilder = - SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); - setNativeToJavaJniLatency(searchResultProtoBuilder); - return searchResultProtoBuilder.build(); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing SearchResultProto.", e); - return SearchResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToSearchResultProto( + icingSearchEngineImpl.search( + searchSpec.toByteArray(), scoringSpec.toByteArray(), resultSpec.toByteArray())); } @NonNull + @Override public SearchResultProto getNextPage(long nextPageToken) { - throwIfClosed(); - - byte[] searchResultBytes = nativeGetNextPage(this, nextPageToken, System.currentTimeMillis()); - if (searchResultBytes == null) { - Log.e(TAG, "Received null SearchResultProto from native."); - return SearchResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - SearchResultProto.Builder searchResultProtoBuilder = - SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); - setNativeToJavaJniLatency(searchResultProtoBuilder); - return searchResultProtoBuilder.build(); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing SearchResultProto.", e); - return SearchResultProto.newBuilder() - 
.setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - } - - private void setNativeToJavaJniLatency(SearchResultProto.Builder searchResultProtoBuilder) { - int nativeToJavaLatencyMs = - (int) - (System.currentTimeMillis() - - searchResultProtoBuilder.getQueryStats().getNativeToJavaStartTimestampMs()); - searchResultProtoBuilder.setQueryStats( - searchResultProtoBuilder.getQueryStats().toBuilder() - .setNativeToJavaJniLatencyMs(nativeToJavaLatencyMs)); + return IcingSearchEngineUtils.byteArrayToSearchResultProto( + icingSearchEngineImpl.getNextPage(nextPageToken)); } @NonNull + @Override public void invalidateNextPageToken(long nextPageToken) { - throwIfClosed(); - - nativeInvalidateNextPageToken(this, nextPageToken); + icingSearchEngineImpl.invalidateNextPageToken(nextPageToken); } @NonNull + @Override public DeleteResultProto delete(@NonNull String namespace, @NonNull String uri) { - throwIfClosed(); - - byte[] deleteResultBytes = nativeDelete(this, namespace, uri); - if (deleteResultBytes == null) { - Log.e(TAG, "Received null DeleteResultProto from native."); - return DeleteResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return DeleteResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing DeleteResultProto.", e); - return DeleteResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToDeleteResultProto( + icingSearchEngineImpl.delete(namespace, uri)); } @NonNull + @Override public SuggestionResponse searchSuggestions(@NonNull SuggestionSpecProto suggestionSpec) { - byte[] suggestionResponseBytes = nativeSearchSuggestions(this, suggestionSpec.toByteArray()); - if (suggestionResponseBytes == null) { - Log.e(TAG, "Received null suggestionResponseBytes from native."); - 
return SuggestionResponse.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing suggestionResponseBytes.", e); - return SuggestionResponse.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToSuggestionResponse( + icingSearchEngineImpl.searchSuggestions(suggestionSpec.toByteArray())); } @NonNull + @Override public DeleteByNamespaceResultProto deleteByNamespace(@NonNull String namespace) { - throwIfClosed(); - - byte[] deleteByNamespaceResultBytes = nativeDeleteByNamespace(this, namespace); - if (deleteByNamespaceResultBytes == null) { - Log.e(TAG, "Received null DeleteByNamespaceResultProto from native."); - return DeleteByNamespaceResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return DeleteByNamespaceResultProto.parseFrom( - deleteByNamespaceResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing DeleteByNamespaceResultProto.", e); - return DeleteByNamespaceResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToDeleteByNamespaceResultProto( + icingSearchEngineImpl.deleteByNamespace(namespace)); } @NonNull + @Override public DeleteBySchemaTypeResultProto deleteBySchemaType(@NonNull String schemaType) { - throwIfClosed(); - - byte[] deleteBySchemaTypeResultBytes = nativeDeleteBySchemaType(this, schemaType); - if (deleteBySchemaTypeResultBytes == null) { - Log.e(TAG, "Received null DeleteBySchemaTypeResultProto from native."); - return DeleteBySchemaTypeResultProto.newBuilder() - 
.setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return DeleteBySchemaTypeResultProto.parseFrom( - deleteBySchemaTypeResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing DeleteBySchemaTypeResultProto.", e); - return DeleteBySchemaTypeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToDeleteBySchemaTypeResultProto( + icingSearchEngineImpl.deleteBySchemaType(schemaType)); } @NonNull + @Override public DeleteByQueryResultProto deleteByQuery(@NonNull SearchSpecProto searchSpec) { return deleteByQuery(searchSpec, /*returnDeletedDocumentInfo=*/ false); } @NonNull + @Override public DeleteByQueryResultProto deleteByQuery( @NonNull SearchSpecProto searchSpec, boolean returnDeletedDocumentInfo) { - throwIfClosed(); - - byte[] deleteResultBytes = - nativeDeleteByQuery(this, searchSpec.toByteArray(), returnDeletedDocumentInfo); - if (deleteResultBytes == null) { - Log.e(TAG, "Received null DeleteResultProto from native."); - return DeleteByQueryResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return DeleteByQueryResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing DeleteResultProto.", e); - return DeleteByQueryResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToDeleteByQueryResultProto( + icingSearchEngineImpl.deleteByQuery(searchSpec.toByteArray(), returnDeletedDocumentInfo)); } @NonNull + @Override public PersistToDiskResultProto persistToDisk(@NonNull PersistType.Code persistTypeCode) { - throwIfClosed(); - - byte[] persistToDiskResultBytes = nativePersistToDisk(this, 
persistTypeCode.getNumber()); - if (persistToDiskResultBytes == null) { - Log.e(TAG, "Received null PersistToDiskResultProto from native."); - return PersistToDiskResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return PersistToDiskResultProto.parseFrom(persistToDiskResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing PersistToDiskResultProto.", e); - return PersistToDiskResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToPersistToDiskResultProto( + icingSearchEngineImpl.persistToDisk(persistTypeCode.getNumber())); } @NonNull + @Override public OptimizeResultProto optimize() { - throwIfClosed(); - - byte[] optimizeResultBytes = nativeOptimize(this); - if (optimizeResultBytes == null) { - Log.e(TAG, "Received null OptimizeResultProto from native."); - return OptimizeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return OptimizeResultProto.parseFrom(optimizeResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing OptimizeResultProto.", e); - return OptimizeResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToOptimizeResultProto(icingSearchEngineImpl.optimize()); } @NonNull + @Override public GetOptimizeInfoResultProto getOptimizeInfo() { - throwIfClosed(); - - byte[] getOptimizeInfoResultBytes = nativeGetOptimizeInfo(this); - if (getOptimizeInfoResultBytes == null) { - Log.e(TAG, "Received null GetOptimizeInfoResultProto from native."); - return GetOptimizeInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - 
return GetOptimizeInfoResultProto.parseFrom( - getOptimizeInfoResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e); - return GetOptimizeInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToGetOptimizeInfoResultProto( + icingSearchEngineImpl.getOptimizeInfo()); } @NonNull + @Override public StorageInfoResultProto getStorageInfo() { - throwIfClosed(); - - byte[] storageInfoResultProtoBytes = nativeGetStorageInfo(this); - if (storageInfoResultProtoBytes == null) { - Log.e(TAG, "Received null StorageInfoResultProto from native."); - return StorageInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return StorageInfoResultProto.parseFrom(storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e); - return StorageInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToStorageInfoResultProto( + icingSearchEngineImpl.getStorageInfo()); } @NonNull + @Override public DebugInfoResultProto getDebugInfo(DebugInfoVerbosity.Code verbosity) { - throwIfClosed(); - - byte[] debugInfoResultProtoBytes = nativeGetDebugInfo(this, verbosity.getNumber()); - if (debugInfoResultProtoBytes == null) { - Log.e(TAG, "Received null DebugInfoResultProto from native."); - return DebugInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return DebugInfoResultProto.parseFrom(debugInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing DebugInfoResultProto.", e); - return 
DebugInfoResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToDebugInfoResultProto( + icingSearchEngineImpl.getDebugInfo(verbosity.getNumber())); } @NonNull + @Override public ResetResultProto reset() { - throwIfClosed(); - - byte[] resetResultBytes = nativeReset(this); - if (resetResultBytes == null) { - Log.e(TAG, "Received null ResetResultProto from native."); - return ResetResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } - - try { - return ResetResultProto.parseFrom(resetResultBytes, EXTENSION_REGISTRY_LITE); - } catch (InvalidProtocolBufferException e) { - Log.e(TAG, "Error parsing ResetResultProto.", e); - return ResetResultProto.newBuilder() - .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) - .build(); - } + return IcingSearchEngineUtils.byteArrayToResetResultProto(icingSearchEngineImpl.reset()); } public static boolean shouldLog(LogSeverity.Code severity) { @@ -638,7 +258,7 @@ public class IcingSearchEngine implements Closeable { } public static boolean shouldLog(LogSeverity.Code severity, short verbosity) { - return nativeShouldLog((short) severity.getNumber(), verbosity); + return IcingSearchEngineImpl.shouldLog((short) severity.getNumber(), verbosity); } public static boolean setLoggingLevel(LogSeverity.Code severity) { @@ -646,84 +266,11 @@ public class IcingSearchEngine implements Closeable { } public static boolean setLoggingLevel(LogSeverity.Code severity, short verbosity) { - return nativeSetLoggingLevel((short) severity.getNumber(), verbosity); + return IcingSearchEngineImpl.setLoggingLevel((short) severity.getNumber(), verbosity); } @Nullable public static String getLoggingTag() { - String tag = nativeGetLoggingTag(); - if (tag == null) { - Log.e(TAG, "Received null logging tag from native."); - } - return tag; + return 
IcingSearchEngineImpl.getLoggingTag(); } - - private static native long nativeCreate(byte[] icingSearchEngineOptionsBytes); - - private static native void nativeDestroy(IcingSearchEngine instance); - - private static native byte[] nativeInitialize(IcingSearchEngine instance); - - private static native byte[] nativeSetSchema( - IcingSearchEngine instance, byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments); - - private static native byte[] nativeGetSchema(IcingSearchEngine instance); - - private static native byte[] nativeGetSchemaType(IcingSearchEngine instance, String schemaType); - - private static native byte[] nativePut(IcingSearchEngine instance, byte[] documentBytes); - - private static native byte[] nativeGet( - IcingSearchEngine instance, String namespace, String uri, byte[] getResultSpecBytes); - - private static native byte[] nativeReportUsage( - IcingSearchEngine instance, byte[] usageReportBytes); - - private static native byte[] nativeGetAllNamespaces(IcingSearchEngine instance); - - private static native byte[] nativeSearch( - IcingSearchEngine instance, - byte[] searchSpecBytes, - byte[] scoringSpecBytes, - byte[] resultSpecBytes, - long javaToNativeStartTimestampMs); - - private static native byte[] nativeGetNextPage( - IcingSearchEngine instance, long nextPageToken, long javaToNativeStartTimestampMs); - - private static native void nativeInvalidateNextPageToken( - IcingSearchEngine instance, long nextPageToken); - - private static native byte[] nativeDelete( - IcingSearchEngine instance, String namespace, String uri); - - private static native byte[] nativeDeleteByNamespace( - IcingSearchEngine instance, String namespace); - - private static native byte[] nativeDeleteBySchemaType( - IcingSearchEngine instance, String schemaType); - - private static native byte[] nativeDeleteByQuery( - IcingSearchEngine instance, byte[] searchSpecBytes, boolean returnDeletedDocumentInfo); - - private static native byte[] nativePersistToDisk(IcingSearchEngine 
instance, int persistType); - - private static native byte[] nativeOptimize(IcingSearchEngine instance); - - private static native byte[] nativeGetOptimizeInfo(IcingSearchEngine instance); - - private static native byte[] nativeGetStorageInfo(IcingSearchEngine instance); - - private static native byte[] nativeReset(IcingSearchEngine instance); - - private static native byte[] nativeSearchSuggestions( - IcingSearchEngine instance, byte[] suggestionSpecBytes); - - private static native byte[] nativeGetDebugInfo(IcingSearchEngine instance, int verbosity); - - private static native boolean nativeShouldLog(short severity, short verbosity); - - private static native boolean nativeSetLoggingLevel(short severity, short verbosity); - - private static native String nativeGetLoggingTag(); } diff --git a/java/src/com/google/android/icing/IcingSearchEngineImpl.java b/java/src/com/google/android/icing/IcingSearchEngineImpl.java new file mode 100644 index 0000000..8e79a88 --- /dev/null +++ b/java/src/com/google/android/icing/IcingSearchEngineImpl.java @@ -0,0 +1,330 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package com.google.android.icing; + +import android.util.Log; +import androidx.annotation.NonNull; +import androidx.annotation.Nullable; +import java.io.Closeable; + +/** + * Java wrapper to access native APIs in external/icing/icing/icing-search-engine.h + * + * <p>If this instance has been closed, the instance is no longer usable. + * + * <p>Keep this class to be non-Final so that it can be mocked in AppSearch. + * + * <p>NOTE: This class is NOT thread-safe. + */ +public class IcingSearchEngineImpl implements Closeable { + + private static final String TAG = "IcingSearchEngineImpl"; + + private long nativePointer; + + private boolean closed = false; + + static { + // NOTE: This can fail with an UnsatisfiedLinkError + System.loadLibrary("icing"); + } + + /** + * @throws IllegalStateException if IcingSearchEngineImpl fails to be created + */ + public IcingSearchEngineImpl(@NonNull byte[] optionsBytes) { + nativePointer = nativeCreate(optionsBytes); + if (nativePointer == 0) { + Log.e(TAG, "Failed to create IcingSearchEngineImpl."); + throw new IllegalStateException("Failed to create IcingSearchEngineImpl."); + } + } + + private void throwIfClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed IcingSearchEngineImpl instance."); + } + } + + @Override + public void close() { + if (closed) { + return; + } + + if (nativePointer != 0) { + nativeDestroy(this); + } + nativePointer = 0; + closed = true; + } + + @Override + protected void finalize() throws Throwable { + close(); + super.finalize(); + } + + @Nullable + public byte[] initialize() { + throwIfClosed(); + return nativeInitialize(this); + } + + @Nullable + public byte[] setSchema(@NonNull byte[] schemaBytes) { + return setSchema(schemaBytes, /* ignoreErrorsAndDeleteDocuments= */ false); + } + + @Nullable + public byte[] setSchema(@NonNull byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments) { + throwIfClosed(); + return nativeSetSchema(this, schemaBytes, 
ignoreErrorsAndDeleteDocuments); + } + + @Nullable + public byte[] getSchema() { + throwIfClosed(); + return nativeGetSchema(this); + } + + @Nullable + public byte[] getSchemaType(@NonNull String schemaType) { + throwIfClosed(); + return nativeGetSchemaType(this, schemaType); + } + + @Nullable + public byte[] put(@NonNull byte[] documentBytes) { + throwIfClosed(); + return nativePut(this, documentBytes); + } + + @Nullable + public byte[] get( + @NonNull String namespace, @NonNull String uri, @NonNull byte[] getResultSpecBytes) { + throwIfClosed(); + return nativeGet(this, namespace, uri, getResultSpecBytes); + } + + @Nullable + public byte[] reportUsage(@NonNull byte[] usageReportBytes) { + throwIfClosed(); + return nativeReportUsage(this, usageReportBytes); + } + + @Nullable + public byte[] getAllNamespaces() { + throwIfClosed(); + return nativeGetAllNamespaces(this); + } + + @Nullable + public byte[] search( + @NonNull byte[] searchSpecBytes, + @NonNull byte[] scoringSpecBytes, + @NonNull byte[] resultSpecBytes) { + throwIfClosed(); + + // Note that on Android System.currentTimeMillis() is the standard "wall" clock and can be set + // by the user or the phone network so the time may jump backwards or forwards unpredictably. + // This could lead to inaccurate final JNI latency calculations or unexpected negative numbers + // in the case where the phone time is changed while sending data across JNI layers. + // However these occurrences should be very rare, so we will keep usage of + // System.currentTimeMillis() due to the lack of better time functions that can provide a + // consistent timestamp across all platforms. 
+ long javaToNativeStartTimestampMs = System.currentTimeMillis(); + return nativeSearch( + this, searchSpecBytes, scoringSpecBytes, resultSpecBytes, javaToNativeStartTimestampMs); + } + + @Nullable + public byte[] getNextPage(long nextPageToken) { + throwIfClosed(); + return nativeGetNextPage(this, nextPageToken, System.currentTimeMillis()); + } + + @NonNull + public void invalidateNextPageToken(long nextPageToken) { + throwIfClosed(); + nativeInvalidateNextPageToken(this, nextPageToken); + } + + @Nullable + public byte[] delete(@NonNull String namespace, @NonNull String uri) { + throwIfClosed(); + return nativeDelete(this, namespace, uri); + } + + @Nullable + public byte[] searchSuggestions(@NonNull byte[] suggestionSpecBytes) { + throwIfClosed(); + return nativeSearchSuggestions(this, suggestionSpecBytes); + } + + @Nullable + public byte[] deleteByNamespace(@NonNull String namespace) { + throwIfClosed(); + return nativeDeleteByNamespace(this, namespace); + } + + @Nullable + public byte[] deleteBySchemaType(@NonNull String schemaType) { + throwIfClosed(); + return nativeDeleteBySchemaType(this, schemaType); + } + + @Nullable + public byte[] deleteByQuery(@NonNull byte[] searchSpecBytes) { + return deleteByQuery(searchSpecBytes, /* returnDeletedDocumentInfo= */ false); + } + + @Nullable + public byte[] deleteByQuery(@NonNull byte[] searchSpecBytes, boolean returnDeletedDocumentInfo) { + throwIfClosed(); + return nativeDeleteByQuery(this, searchSpecBytes, returnDeletedDocumentInfo); + } + + @Nullable + public byte[] persistToDisk(int persistTypeCode) { + throwIfClosed(); + return nativePersistToDisk(this, persistTypeCode); + } + + @Nullable + public byte[] optimize() { + throwIfClosed(); + return nativeOptimize(this); + } + + @Nullable + public byte[] getOptimizeInfo() { + throwIfClosed(); + return nativeGetOptimizeInfo(this); + } + + @Nullable + public byte[] getStorageInfo() { + throwIfClosed(); + return nativeGetStorageInfo(this); + } + + @Nullable + public 
byte[] getDebugInfo(int verbosityCode) { + throwIfClosed(); + return nativeGetDebugInfo(this, verbosityCode); + } + + @Nullable + public byte[] reset() { + throwIfClosed(); + return nativeReset(this); + } + + public static boolean shouldLog(short severity) { + return shouldLog(severity, (short) 0); + } + + public static boolean shouldLog(short severity, short verbosity) { + return nativeShouldLog(severity, verbosity); + } + + public static boolean setLoggingLevel(short severity) { + return setLoggingLevel(severity, (short) 0); + } + + public static boolean setLoggingLevel(short severity, short verbosity) { + return nativeSetLoggingLevel(severity, verbosity); + } + + @Nullable + public static String getLoggingTag() { + String tag = nativeGetLoggingTag(); + if (tag == null) { + Log.e(TAG, "Received null logging tag from native."); + } + return tag; + } + + private static native long nativeCreate(byte[] icingSearchEngineOptionsBytes); + + private static native void nativeDestroy(IcingSearchEngineImpl instance); + + private static native byte[] nativeInitialize(IcingSearchEngineImpl instance); + + private static native byte[] nativeSetSchema( + IcingSearchEngineImpl instance, byte[] schemaBytes, boolean ignoreErrorsAndDeleteDocuments); + + private static native byte[] nativeGetSchema(IcingSearchEngineImpl instance); + + private static native byte[] nativeGetSchemaType( + IcingSearchEngineImpl instance, String schemaType); + + private static native byte[] nativePut(IcingSearchEngineImpl instance, byte[] documentBytes); + + private static native byte[] nativeGet( + IcingSearchEngineImpl instance, String namespace, String uri, byte[] getResultSpecBytes); + + private static native byte[] nativeReportUsage( + IcingSearchEngineImpl instance, byte[] usageReportBytes); + + private static native byte[] nativeGetAllNamespaces(IcingSearchEngineImpl instance); + + private static native byte[] nativeSearch( + IcingSearchEngineImpl instance, + byte[] searchSpecBytes, + byte[] 
scoringSpecBytes, + byte[] resultSpecBytes, + long javaToNativeStartTimestampMs); + + private static native byte[] nativeGetNextPage( + IcingSearchEngineImpl instance, long nextPageToken, long javaToNativeStartTimestampMs); + + private static native void nativeInvalidateNextPageToken( + IcingSearchEngineImpl instance, long nextPageToken); + + private static native byte[] nativeDelete( + IcingSearchEngineImpl instance, String namespace, String uri); + + private static native byte[] nativeDeleteByNamespace( + IcingSearchEngineImpl instance, String namespace); + + private static native byte[] nativeDeleteBySchemaType( + IcingSearchEngineImpl instance, String schemaType); + + private static native byte[] nativeDeleteByQuery( + IcingSearchEngineImpl instance, byte[] searchSpecBytes, boolean returnDeletedDocumentInfo); + + private static native byte[] nativePersistToDisk(IcingSearchEngineImpl instance, int persistType); + + private static native byte[] nativeOptimize(IcingSearchEngineImpl instance); + + private static native byte[] nativeGetOptimizeInfo(IcingSearchEngineImpl instance); + + private static native byte[] nativeGetStorageInfo(IcingSearchEngineImpl instance); + + private static native byte[] nativeReset(IcingSearchEngineImpl instance); + + private static native byte[] nativeSearchSuggestions( + IcingSearchEngineImpl instance, byte[] suggestionSpecBytes); + + private static native byte[] nativeGetDebugInfo(IcingSearchEngineImpl instance, int verbosity); + + private static native boolean nativeShouldLog(short severity, short verbosity); + + private static native boolean nativeSetLoggingLevel(short severity, short verbosity); + + private static native String nativeGetLoggingTag(); +} diff --git a/java/src/com/google/android/icing/IcingSearchEngineInterface.java b/java/src/com/google/android/icing/IcingSearchEngineInterface.java new file mode 100644 index 0000000..9d567f9 --- /dev/null +++ b/java/src/com/google/android/icing/IcingSearchEngineInterface.java @@ 
package com.google.android.icing;

import android.os.RemoteException;
import com.google.android.icing.proto.DebugInfoResultProto;
import com.google.android.icing.proto.DebugInfoVerbosity;
import com.google.android.icing.proto.DeleteByNamespaceResultProto;
import com.google.android.icing.proto.DeleteByQueryResultProto;
import com.google.android.icing.proto.DeleteBySchemaTypeResultProto;
import com.google.android.icing.proto.DeleteResultProto;
import com.google.android.icing.proto.DocumentProto;
import com.google.android.icing.proto.GetAllNamespacesResultProto;
import com.google.android.icing.proto.GetOptimizeInfoResultProto;
import com.google.android.icing.proto.GetResultProto;
import com.google.android.icing.proto.GetResultSpecProto;
import com.google.android.icing.proto.GetSchemaResultProto;
import com.google.android.icing.proto.GetSchemaTypeResultProto;
import com.google.android.icing.proto.InitializeResultProto;
import com.google.android.icing.proto.OptimizeResultProto;
import com.google.android.icing.proto.PersistToDiskResultProto;
import com.google.android.icing.proto.PersistType;
import com.google.android.icing.proto.PutResultProto;
import com.google.android.icing.proto.ReportUsageResultProto;
import com.google.android.icing.proto.ResetResultProto;
import com.google.android.icing.proto.ResultSpecProto;
import com.google.android.icing.proto.SchemaProto;
import com.google.android.icing.proto.ScoringSpecProto;
import com.google.android.icing.proto.SearchResultProto;
import com.google.android.icing.proto.SearchSpecProto;
import com.google.android.icing.proto.SetSchemaResultProto;
import com.google.android.icing.proto.StorageInfoResultProto;
import com.google.android.icing.proto.SuggestionResponse;
import com.google.android.icing.proto.SuggestionSpecProto;
import com.google.android.icing.proto.UsageReport;

/**
 * A common user-facing interface to expose the functionalities provided by Icing Library.
 *
 * <p>All the methods here throw {@link RemoteException} because the implementation for
 * gmscore-appsearch-dynamite will throw it.
 */
public interface IcingSearchEngineInterface extends AutoCloseable {
  /**
   * Initializes the current IcingSearchEngine implementation.
   *
   * <p>Internally the icing instance will be initialized.
   */
  InitializeResultProto initialize();

  /** Sets the schema for the icing instance. */
  SetSchemaResultProto setSchema(SchemaProto schema);

  /**
   * Sets the schema for the icing instance.
   *
   * @param ignoreErrorsAndDeleteDocuments force to set the schema and delete documents in case of
   *     incompatible schema change.
   */
  SetSchemaResultProto setSchema(SchemaProto schema, boolean ignoreErrorsAndDeleteDocuments);

  /** Gets the schema for the icing instance. */
  GetSchemaResultProto getSchema();

  /**
   * Gets the schema for the icing instance.
   *
   * @param schemaType type of the schema.
   */
  GetSchemaTypeResultProto getSchemaType(String schemaType);

  /** Puts the document. */
  PutResultProto put(DocumentProto document);

  /**
   * Gets the document.
   *
   * @param namespace namespace of the document.
   * @param uri uri of the document.
   * @param getResultSpec the spec for getting the document.
   */
  GetResultProto get(String namespace, String uri, GetResultSpecProto getResultSpec);

  /** Reports usage. */
  ReportUsageResultProto reportUsage(UsageReport usageReport);

  /** Gets all namespaces. */
  GetAllNamespacesResultProto getAllNamespaces();

  /**
   * Searches over the documents.
   *
   * <p>Documents need to be retrieved on the following {@link #getNextPage} calls on the returned
   * {@link SearchResultProto}.
   */
  SearchResultProto search(
      SearchSpecProto searchSpec, ScoringSpecProto scoringSpec, ResultSpecProto resultSpec);

  /** Gets the next page. */
  SearchResultProto getNextPage(long nextPageToken);

  /** Invalidates the next page token. */
  void invalidateNextPageToken(long nextPageToken);

  /**
   * Deletes the document.
   *
   * @param namespace the namespace the document to be deleted belongs to.
   * @param uri the uri for the document to be deleted.
   */
  DeleteResultProto delete(String namespace, String uri);

  /** Returns the suggestions for the search query. */
  SuggestionResponse searchSuggestions(SuggestionSpecProto suggestionSpec);

  /** Deletes documents by the namespace. */
  DeleteByNamespaceResultProto deleteByNamespace(String namespace);

  /** Deletes documents by the schema type. */
  DeleteBySchemaTypeResultProto deleteBySchemaType(String schemaType);

  /** Deletes documents by the search query. */
  DeleteByQueryResultProto deleteByQuery(SearchSpecProto searchSpec);

  /**
   * Deletes documents by the search query.
   *
   * @param returnDeletedDocumentInfo whether additional information about deleted documents will be
   *     included in {@link DeleteByQueryResultProto}.
   */
  DeleteByQueryResultProto deleteByQuery(
      SearchSpecProto searchSpec, boolean returnDeletedDocumentInfo);

  /** Makes sure every update/delete received till this point is flushed to disk. */
  PersistToDiskResultProto persistToDisk(PersistType.Code persistTypeCode);

  /** Makes the icing instance run tasks that are too expensive to be run in real-time. */
  OptimizeResultProto optimize();

  /** Gets information about the optimization. */
  GetOptimizeInfoResultProto getOptimizeInfo();

  /** Gets information about the storage. */
  StorageInfoResultProto getStorageInfo();

  /** Gets the debug information for the current icing instance. */
  DebugInfoResultProto getDebugInfo(DebugInfoVerbosity.Code verbosity);

  /** Clears all data from the current icing instance, and reinitializes it. */
  ResetResultProto reset();

  /** Closes the current icing instance. */
  @Override
  void close();
}
com.google.android.icing.proto.ReportUsageResultProto; +import com.google.android.icing.proto.ResetResultProto; +import com.google.android.icing.proto.SearchResultProto; +import com.google.android.icing.proto.SetSchemaResultProto; +import com.google.android.icing.proto.StatusProto; +import com.google.android.icing.proto.StorageInfoResultProto; +import com.google.android.icing.proto.SuggestionResponse; +import com.google.protobuf.ExtensionRegistryLite; +import com.google.protobuf.InvalidProtocolBufferException; + +/** + * Contains utility methods for IcingSearchEngine to convert byte arrays to the corresponding + * protos. + * + * <p>It is also being used by AppSearch dynamite 0p client APIs to convert byte arrays to the + * protos. + */ +public final class IcingSearchEngineUtils { + private static final String TAG = "IcingSearchEngineUtils"; + private static final ExtensionRegistryLite EXTENSION_REGISTRY_LITE = + ExtensionRegistryLite.getEmptyRegistry(); + + private IcingSearchEngineUtils() {} + + // TODO(b/240333360) Check to see if we can use one template function to replace those + @NonNull + public static InitializeResultProto byteArrayToInitializeResultProto( + @Nullable byte[] initializeResultBytes) { + if (initializeResultBytes == null) { + Log.e(TAG, "Received null InitializeResult from native."); + return InitializeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return InitializeResultProto.parseFrom(initializeResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing InitializeResultProto.", e); + return InitializeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static SetSchemaResultProto byteArrayToSetSchemaResultProto( + @Nullable byte[] setSchemaResultBytes) { + if (setSchemaResultBytes == null) { + Log.e(TAG, "Received null 
SetSchemaResultProto from native."); + return SetSchemaResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return SetSchemaResultProto.parseFrom(setSchemaResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing SetSchemaResultProto.", e); + return SetSchemaResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static GetSchemaResultProto byteArrayToGetSchemaResultProto( + @Nullable byte[] getSchemaResultBytes) { + if (getSchemaResultBytes == null) { + Log.e(TAG, "Received null GetSchemaResultProto from native."); + return GetSchemaResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return GetSchemaResultProto.parseFrom(getSchemaResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetSchemaResultProto.", e); + return GetSchemaResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static GetSchemaTypeResultProto byteArrayToGetSchemaTypeResultProto( + @Nullable byte[] getSchemaTypeResultBytes) { + if (getSchemaTypeResultBytes == null) { + Log.e(TAG, "Received null GetSchemaTypeResultProto from native."); + return GetSchemaTypeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return GetSchemaTypeResultProto.parseFrom(getSchemaTypeResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetSchemaTypeResultProto.", e); + return GetSchemaTypeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static PutResultProto 
byteArrayToPutResultProto(@Nullable byte[] putResultBytes) { + if (putResultBytes == null) { + Log.e(TAG, "Received null PutResultProto from native."); + return PutResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return PutResultProto.parseFrom(putResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing PutResultProto.", e); + return PutResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static GetResultProto byteArrayToGetResultProto(@Nullable byte[] getResultBytes) { + if (getResultBytes == null) { + Log.e(TAG, "Received null GetResultProto from native."); + return GetResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return GetResultProto.parseFrom(getResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetResultProto.", e); + return GetResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static ReportUsageResultProto byteArrayToReportUsageResultProto( + @Nullable byte[] reportUsageResultBytes) { + if (reportUsageResultBytes == null) { + Log.e(TAG, "Received null ReportUsageResultProto from native."); + return ReportUsageResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return ReportUsageResultProto.parseFrom(reportUsageResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing ReportUsageResultProto.", e); + return ReportUsageResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static 
GetAllNamespacesResultProto byteArrayToGetAllNamespacesResultProto( + @Nullable byte[] getAllNamespacesResultBytes) { + if (getAllNamespacesResultBytes == null) { + Log.e(TAG, "Received null GetAllNamespacesResultProto from native."); + return GetAllNamespacesResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return GetAllNamespacesResultProto.parseFrom( + getAllNamespacesResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetAllNamespacesResultProto.", e); + return GetAllNamespacesResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static SearchResultProto byteArrayToSearchResultProto(@Nullable byte[] searchResultBytes) { + if (searchResultBytes == null) { + Log.e(TAG, "Received null SearchResultProto from native."); + return SearchResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + SearchResultProto.Builder searchResultProtoBuilder = + SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); + setNativeToJavaJniLatency(searchResultProtoBuilder); + return searchResultProtoBuilder.build(); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing SearchResultProto.", e); + return SearchResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + private static void setNativeToJavaJniLatency( + SearchResultProto.Builder searchResultProtoBuilder) { + int nativeToJavaLatencyMs = + (int) + (System.currentTimeMillis() + - searchResultProtoBuilder.getQueryStats().getNativeToJavaStartTimestampMs()); + searchResultProtoBuilder.setQueryStats( + searchResultProtoBuilder.getQueryStats().toBuilder() + .setNativeToJavaJniLatencyMs(nativeToJavaLatencyMs)); + } + + @NonNull + 
public static DeleteResultProto byteArrayToDeleteResultProto(@Nullable byte[] deleteResultBytes) { + if (deleteResultBytes == null) { + Log.e(TAG, "Received null DeleteResultProto from native."); + return DeleteResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return DeleteResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DeleteResultProto.", e); + return DeleteResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static SuggestionResponse byteArrayToSuggestionResponse( + @Nullable byte[] suggestionResponseBytes) { + if (suggestionResponseBytes == null) { + Log.e(TAG, "Received null suggestionResponseBytes from native."); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing suggestionResponseBytes.", e); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static DeleteByNamespaceResultProto byteArrayToDeleteByNamespaceResultProto( + @Nullable byte[] deleteByNamespaceResultBytes) { + if (deleteByNamespaceResultBytes == null) { + Log.e(TAG, "Received null DeleteByNamespaceResultProto from native."); + return DeleteByNamespaceResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return DeleteByNamespaceResultProto.parseFrom( + deleteByNamespaceResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DeleteByNamespaceResultProto.", e); + 
return DeleteByNamespaceResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static DeleteBySchemaTypeResultProto byteArrayToDeleteBySchemaTypeResultProto( + @Nullable byte[] deleteBySchemaTypeResultBytes) { + if (deleteBySchemaTypeResultBytes == null) { + Log.e(TAG, "Received null DeleteBySchemaTypeResultProto from native."); + return DeleteBySchemaTypeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return DeleteBySchemaTypeResultProto.parseFrom( + deleteBySchemaTypeResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DeleteBySchemaTypeResultProto.", e); + return DeleteBySchemaTypeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static DeleteByQueryResultProto byteArrayToDeleteByQueryResultProto( + @Nullable byte[] deleteResultBytes) { + if (deleteResultBytes == null) { + Log.e(TAG, "Received null DeleteResultProto from native."); + return DeleteByQueryResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return DeleteByQueryResultProto.parseFrom(deleteResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DeleteResultProto.", e); + return DeleteByQueryResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static PersistToDiskResultProto byteArrayToPersistToDiskResultProto( + @Nullable byte[] persistToDiskResultBytes) { + if (persistToDiskResultBytes == null) { + Log.e(TAG, "Received null PersistToDiskResultProto from native."); + return PersistToDiskResultProto.newBuilder() + 
.setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return PersistToDiskResultProto.parseFrom(persistToDiskResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing PersistToDiskResultProto.", e); + return PersistToDiskResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static OptimizeResultProto byteArrayToOptimizeResultProto( + @Nullable byte[] optimizeResultBytes) { + if (optimizeResultBytes == null) { + Log.e(TAG, "Received null OptimizeResultProto from native."); + return OptimizeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return OptimizeResultProto.parseFrom(optimizeResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing OptimizeResultProto.", e); + return OptimizeResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static GetOptimizeInfoResultProto byteArrayToGetOptimizeInfoResultProto( + @Nullable byte[] getOptimizeInfoResultBytes) { + if (getOptimizeInfoResultBytes == null) { + Log.e(TAG, "Received null GetOptimizeInfoResultProto from native."); + return GetOptimizeInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return GetOptimizeInfoResultProto.parseFrom( + getOptimizeInfoResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e); + return GetOptimizeInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static StorageInfoResultProto byteArrayToStorageInfoResultProto( + @Nullable 
byte[] storageInfoResultProtoBytes) { + if (storageInfoResultProtoBytes == null) { + Log.e(TAG, "Received null StorageInfoResultProto from native."); + return StorageInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return StorageInfoResultProto.parseFrom(storageInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing GetOptimizeInfoResultProto.", e); + return StorageInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static DebugInfoResultProto byteArrayToDebugInfoResultProto( + @Nullable byte[] debugInfoResultProtoBytes) { + if (debugInfoResultProtoBytes == null) { + Log.e(TAG, "Received null DebugInfoResultProto from native."); + return DebugInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return DebugInfoResultProto.parseFrom(debugInfoResultProtoBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing DebugInfoResultProto.", e); + return DebugInfoResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull + public static ResetResultProto byteArrayToResetResultProto(@Nullable byte[] resetResultBytes) { + if (resetResultBytes == null) { + Log.e(TAG, "Received null ResetResultProto from native."); + return ResetResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return ResetResultProto.parseFrom(resetResultBytes, EXTENSION_REGISTRY_LITE); + } catch (InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing ResetResultProto.", e); + return ResetResultProto.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + 
.build(); + } + } +} diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto index 13861c9..a8040a1 100644 --- a/proto/icing/proto/scoring.proto +++ b/proto/icing/proto/scoring.proto @@ -25,7 +25,7 @@ option objc_class_prefix = "ICNG"; // Encapsulates the configurations on how Icing should score and rank the search // results. // TODO(b/170347684): Change all timestamps to seconds. -// Next tag: 4 +// Next tag: 12 message ScoringSpecProto { // OPTIONAL: Indicates how the search results will be ranked. message RankingStrategy { @@ -71,6 +71,9 @@ message ScoringSpecProto { // Ranked by the aggregated score of the joined documents. JOIN_AGGREGATE_SCORE = 10; + + // Ranked by the advanced scoring expression provided. + ADVANCED_SCORING_EXPRESSION = 11; } } optional RankingStrategy.Code rank_by = 1; @@ -99,6 +102,10 @@ message ScoringSpecProto { // all properties that are not specified are given a raw, pre-normalized // weight of 1.0 when scoring. repeated TypePropertyWeights type_property_weights = 3; + + // OPTIONAL: Specifies the scoring expression for ADVANCED_SCORING_EXPRESSION + // RankingStrategy. + optional string advanced_scoring_expression = 4; } // Next tag: 3 diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 181c63c..e7e0208 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -299,7 +299,7 @@ message SearchResultProto { // determined by ScoringSpecProto.rank_by. optional double score = 3; - // The documents that were joined to a parent document. + // The child documents that were joined to a parent document. repeated ResultProto joined_results = 4; } repeated ResultProto results = 2; @@ -430,37 +430,53 @@ message SuggestionResponse { // // Next tag: 7 message JoinSpecProto { - // A nested SearchSpec that will be used to retrieve joined documents. If you - // are only looking to join on Action type documents, you could set a schema - // filter in this SearchSpec. 
This includes the nested search query. See - // SearchSpecProto. - optional SearchSpecProto nested_search_spec = 1; + // Collection of several specs that will be used for searching and joining + // child documents. + // + // Next tag: 4 + message NestedSpecProto { + // A nested SearchSpec that will be used to retrieve child documents. If you + // are only looking to join on a specific type documents, you could set a + // schema filter in this SearchSpec. This includes the nested search query. + // See SearchSpecProto. + optional SearchSpecProto search_spec = 1; + + // A nested ScoringSpec that will be used to score child documents. + // See ScoringSpecProto. + optional ScoringSpecProto scoring_spec = 2; + + // A nested ResultSpec that will be used to format child documents in the + // result joined documents, e.g. snippeting, projection. + // See ResultSpecProto. + optional ResultSpecProto result_spec = 3; + } + optional NestedSpecProto nested_spec = 1; // The equivalent of a primary key in SQL. This is an expression that will be // used to match child documents from the nested search to this document. One - // such expression is qualifiedId(). When used, it means the - // child_property_expression in the joined documents must be equal to the - // qualified id. + // such expression is qualifiedId(). When used, it means the contents of + // child_property_expression property in the child documents must be equal to + // the qualified id. // TODO(b/256022027) allow for parent_property_expression to be any property // of the parent document. optional string parent_property_expression = 2; // The equivalent of a foreign key in SQL. This defines an equality constraint // between a property in a child document and a property in the parent - // document. For example, if you want to join Action documents which an + // document. 
For example, if you want to join child documents which an // entityId property containing a fully qualified document id, // child_property_expression can be set to "entityId". // TODO(b/256022027) figure out how to allow this to refer to documents // outside of same pkg+db+ns. optional string child_property_expression = 3; - // The max amount of joined documents to join to a parent document. - optional int32 max_joined_result_count = 4; + // The max number of child documents to join to a parent document. + optional int32 max_joined_child_count = 4; - // The strategy by which to score the aggregation of joined documents. For + // The strategy by which to score the aggregation of child documents. For // example, you might want to know which entity document has the most actions // taken on it. If JOIN_AGGREGATE_SCORE is used in the base SearchSpecProto, - // the COUNT value will rank entity documents based on the number of joined + // the COUNT value will rank entity documents based on the number of child // documents. enum AggregationScore { UNDEFINED = 0; diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 55403b4..654903b 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=487674301) +set(synced_AOSP_CL_number=494856295) |