author    | Tim Barron <tjbarron@google.com> | 2022-08-11 17:05:22 -0700
committer | Tim Barron <tjbarron@google.com> | 2022-08-11 17:05:22 -0700
commit    | 87267cbc5531600072a283ba0c9500c3fcac87af (patch)
tree      | 2ec5c19afb1f4d5ee229d0619c25b2f2819ccea0
parent    | 7c93c404e1fb4ed5e35326245ebc820ed774c6b2 (diff)
download  | icing-87267cbc5531600072a283ba0c9500c3fcac87af.tar.gz
Sync from upstream.
Descriptions:
======================================================================
Implement new version of ResultState and ResultStateManager to 1)
enforce a page byte size limit and 2) improve handling of pagination
when we encounter deleted documents.
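
As a rough illustration of the paging behavior described here (a sketch only,
not the actual ResultState/ResultStateManager code; every name below is made
up for the example):

  #include <cstddef>
  #include <cstdint>
  #include <string>
  #include <unordered_set>
  #include <utility>
  #include <vector>

  // Sketch: build one page starting at *cursor, skipping documents deleted
  // since scoring, and stop once the page holds page_byte_limit bytes.
  std::vector<std::string> NextPage(
      const std::vector<std::pair<int, std::string>>& ranked_results,
      const std::unordered_set<int>& deleted_ids, int64_t page_byte_limit,
      size_t* cursor) {
    std::vector<std::string> page;
    int64_t page_bytes = 0;
    while (*cursor < ranked_results.size() && page_bytes < page_byte_limit) {
      const auto& [id, payload] = ranked_results[(*cursor)++];
      if (deleted_ids.count(id) > 0) {
        continue;  // Deleted after scoring: don't spend a page slot on it.
      }
      page_bytes += static_cast<int64_t>(payload.size());
      page.push_back(payload);
    }
    return page;
  }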
======================================================================
Fix bugs in IcingDynamicTrie::Delete.
======================================================================
Implement IcingDynamicTrie::IsBranchingTerm.
======================================================================
Change Icing default logging level to INFO.
======================================================================
Refactor KeyMapper class to be an interface.
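
For readers unfamiliar with the shape of such a refactor, a minimal sketch
(includes omitted; the method set is an assumption for illustration, not
KeyMapper's actual signatures):

  // Sketch only: a pure-virtual interface that concrete mappers (e.g. a
  // dynamic-trie-backed implementation) override; callers hold KeyMapper
  // pointers so the backing store can be swapped out.
  class KeyMapper {
   public:
    virtual ~KeyMapper() = default;
    virtual libtextclassifier3::Status Put(std::string_view key,
                                           int32_t value) = 0;
    virtual libtextclassifier3::StatusOr<int32_t> Get(
        std::string_view key) const = 0;
  };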
======================================================================
Improve NamespaceChecker logic to reduce Suggest latency.
======================================================================
Change the Icing native log tag to "AppSearchIcing".
======================================================================
Implement index compaction rather than rebuilding the index during
compaction.
======================================================================
Implement reverse iterator for IcingDynamicTrie.
======================================================================
Avoid adding unnecessary branch points during index compaction.
======================================================================
Invalidate expired result states when adding to/retrieving from
ResultStateManager.
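
A rough sketch of that invalidation rule (the type and member names are
illustrative only, not the real ResultStateManager layout):

  #include <cstdint>
  #include <unordered_map>

  struct CachedResultState {  // Illustrative stand-in for ResultState.
    int64_t expiration_time_ms = 0;
    // ... scored hits, projection/snippet context, etc.
  };

  // Sketch: run on both the add and retrieve paths so that an expired
  // pagination token behaves exactly like an unknown one.
  void RemoveExpiredStates(
      std::unordered_map<uint64_t, CachedResultState>& states,
      int64_t current_time_ms) {
    for (auto it = states.begin(); it != states.end();) {
      if (current_time_ms >= it->second.expiration_time_ms) {
        it = states.erase(it);
      } else {
        ++it;
      }
    }
  }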
======================================================================
Add new methods (MutableView, MutableArrayView, Append, Allocate) to
FileBackedVector
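
A minimal usage sketch of the new mutation API, following the patterns in the
file-backed-vector tests further down in this diff (assumes a Filesystem
instance and file path are already set up; Status handling is elided via the
test macros):

  ICING_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<FileBackedVector<int>> vector,
      FileBackedVector<int>::Create(
          filesystem, file_path,
          MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));

  // Append adds a single element at the end, growing the file if needed.
  ICING_ASSERT_OK(vector->Append(1));

  // Allocate reserves a contiguous range at the end and returns a
  // MutableArrayView; SetArray marks the touched indices dirty so the
  // incremental checksum stays correct.
  ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<int>::MutableArrayView arr,
                             vector->Allocate(3));
  std::vector<int> values = {2, 3, 4};
  arr.SetArray(/*idx=*/0, values.data(), values.size());

  // GetMutable returns a MutableView; its non-const Get() sets the element
  // dirty before the value is modified in place.
  ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<int>::MutableView elt,
                             vector->GetMutable(0));
  elt.Get() = 5;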
======================================================================
Create and implement PersistentHashMap class.
======================================================================
Implement RFC822 Tokenizer.
======================================================================
Remove uses of StringPrintf in ICING_LOG statements
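
For example, the filesystem.cc hunk in the diff below rewrites

  ICING_LOG(ERROR) << IcingStringUtil::StringPrintf(
      "getrlimit() failed (errno=%d)", errno);

as the equivalent streaming form

  ICING_LOG(ERROR) << "getrlimit() failed (errno=" << errno << ")";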
======================================================================
Properly set query latency when an error is encountered or results are
empty.
======================================================================
Bug: 146903474
Bug: 152934343
Bug: 193919210
Bug: 193453081
Bug: 231368517
Bug: 235395538
Bug: 236412165
Change-Id: I8aa278cebb12b25b39deb0ef584c0f198952659d
73 files changed, 7362 insertions, 1872 deletions
diff --git a/icing/file/file-backed-bitmap.cc b/icing/file/file-backed-bitmap.cc index eec7668..a8231e3 100644 --- a/icing/file/file-backed-bitmap.cc +++ b/icing/file/file-backed-bitmap.cc @@ -269,8 +269,7 @@ libtextclassifier3::Status FileBackedBitmap::GrowTo(int new_num_bits) { return status; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Grew file %s to new size %zd", file_path_.c_str(), new_file_size); + ICING_VLOG(1) << "Grew file " << file_path_ << " to new size " << new_file_size; mutable_header()->state = Header::ChecksumState::kStale; return libtextclassifier3::Status::OK; } diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index 7e42e32..bcfbbdd 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -58,8 +58,12 @@ #include <sys/mman.h> +#include <algorithm> #include <cinttypes> #include <cstdint> +#include <cstring> +#include <functional> +#include <limits> #include <memory> #include <string> #include <utility> @@ -83,6 +87,9 @@ namespace lib { template <typename T> class FileBackedVector { public: + class MutableArrayView; + class MutableView; + // Header stored at the beginning of the file before the rest of the vector // elements. Stores metadata on the vector. struct Header { @@ -133,15 +140,24 @@ class FileBackedVector { kHeaderChecksumOffset, ""); - Crc32 crc; - std::string_view header_str( - reinterpret_cast<const char*>(this), - offsetof(FileBackedVector::Header, header_checksum)); - crc.Append(header_str); - return crc.Get(); + return Crc32(std::string_view( + reinterpret_cast<const char*>(this), + offsetof(FileBackedVector::Header, header_checksum))) + .Get(); } }; + // Absolute max file size for FileBackedVector. Note that Android has a + // (2^31-1)-byte single file size limit, so kMaxFileSize is 2^31-1. + static constexpr int32_t kMaxFileSize = + std::numeric_limits<int32_t>::max(); // 2^31-1 Bytes, ~2.1 GB; + + // Size of element type T. The value is same as sizeof(T), while we should + // avoid using sizeof(T) in our codebase to prevent unexpected unsigned + // integer casting. + static constexpr int32_t kElementTypeSize = static_cast<int32_t>(sizeof(T)); + static_assert(sizeof(T) <= (1 << 10)); + // Creates a new FileBackedVector to read/write content to. // // filesystem: Object to make system level calls @@ -149,15 +165,20 @@ class FileBackedVector { // within a directory that already exists. // mmap_strategy : Strategy/optimizations to access the content in the vector, // see MemoryMappedFile::Strategy for more details + // max_file_size: Maximum file size for FileBackedVector, default + // kMaxFileSize. See max_file_size_ and kMaxFileSize for more + // details. // // Return: // FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored // checksum. // INTERNAL_ERROR on I/O errors. + // INVALID_ARGUMENT_ERROR if max_file_size is incorrect. // UNIMPLEMENTED_ERROR if created with strategy READ_WRITE_MANUAL_SYNC. static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> Create(const Filesystem& filesystem, const std::string& file_path, - MemoryMappedFile::Strategy mmap_strategy); + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size = kMaxFileSize); // Deletes the FileBackedVector // @@ -184,13 +205,13 @@ class FileBackedVector { // referencing the now-invalidated region. 
// // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or > num_elements() + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() libtextclassifier3::StatusOr<T> GetCopy(int32_t idx) const; - // Gets a pointer to the element at idx. + // Gets an immutable pointer to the element at idx. // - // WARNING: Subsequent calls to Set may invalidate the pointer returned by - // Get. + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the pointer + // returned by Get. // // This is useful if you do not think the FileBackedVector will grow before // you need to reference this value, and you want to avoid a copy. When the @@ -198,27 +219,102 @@ class FileBackedVector { // which will invalidate this pointer to the previously mapped region. // // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or > num_elements() + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() libtextclassifier3::StatusOr<const T*> Get(int32_t idx) const; + // Gets a MutableView to the element at idx. + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference returned by MutableView::Get(). + // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to mutate the underlying + // data directly. When the FileBackedVector grows, the underlying mmap will be + // unmapped and remapped, which will invalidate this MutableView to the + // previously mapped region. + // + // Returns: + // OUT_OF_RANGE_ERROR if idx < 0 or idx >= num_elements() + libtextclassifier3::StatusOr<MutableView> GetMutable(int32_t idx); + + // Gets a MutableArrayView to the elements at range [idx, idx + len). + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference/pointer returned by MutableArrayView::operator[]/data(). + // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to mutate the underlying + // data directly. When the FileBackedVector grows, the underlying mmap will be + // unmapped and remapped, which will invalidate this MutableArrayView to the + // previously mapped region. + // + // Returns: + // OUT_OF_RANGE_ERROR if idx < 0 or idx + len > num_elements() + libtextclassifier3::StatusOr<MutableArrayView> GetMutable(int32_t idx, + int32_t len); + // Writes the value at idx. // // May grow the underlying file and mmapped region as needed to fit the new - // value. If it does grow, then any pointers to previous values returned - // from Get() may be invalidated. + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. // // Returns: - // OUT_OF_RANGE_ERROR if idx < 0 or file cannot be grown idx size + // OUT_OF_RANGE_ERROR if idx < 0 or idx > kMaxIndex or file cannot be grown + // idx size libtextclassifier3::Status Set(int32_t idx, const T& value); + // Appends the value to the end of the vector. + // + // May grow the underlying file and mmapped region as needed to fit the new + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. + // + // Returns: + // OUT_OF_RANGE_ERROR if file cannot be grown (i.e. reach max_file_size_) + libtextclassifier3::Status Append(const T& value) { + return Set(header_->num_elements, value); + } + + // Allocates spaces with given length in the end of the vector and returns a + // MutableArrayView to the space. 
+ // + // May grow the underlying file and mmapped region as needed to fit the new + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. + // + // WARNING: Subsequent calls to Set/Append/Allocate may invalidate the + // reference/pointer returned by MutableArrayView::operator[]/data(). + // + // This is useful if you do not think the FileBackedVector will grow before + // you need to reference this value, and you want to allocate adjacent spaces + // for multiple elements and mutate the underlying data directly. When the + // FileBackedVector grows, the underlying mmap will be unmapped and remapped, + // which will invalidate this MutableArrayView to the previously mapped + // region. + // + // Returns: + // OUT_OF_RANGE_ERROR if len <= 0 or file cannot be grown (i.e. reach + // max_file_size_) + libtextclassifier3::StatusOr<MutableArrayView> Allocate(int32_t len); + // Resizes to first len elements. The crc is cleared on truncation and will be // updated on destruction, or once the client calls ComputeChecksum() or // PersistToDisk(). // // Returns: - // OUT_OF_RANGE_ERROR if len < 0 or >= num_elements() + // OUT_OF_RANGE_ERROR if len < 0 or len >= num_elements() libtextclassifier3::Status TruncateTo(int32_t new_num_elements); + // Mark idx as changed iff idx < changes_end_, so later ComputeChecksum() can + // update checksum by the cached changes without going over [0, changes_end_). + // + // If the buffer size exceeds kPartialCrcLimitDiv, then clear all change + // buffers and set changes_end_ as 0, indicating that the checksum should be + // recomputed from idx 0 (starting from the beginning). Otherwise cache the + // change. + void SetDirty(int32_t idx); + // Flushes content to underlying file. // // Returns: @@ -248,10 +344,6 @@ class FileBackedVector { return reinterpret_cast<const T*>(mmapped_file_->region()); } - T* mutable_array() const { - return reinterpret_cast<T*>(mmapped_file_->mutable_region()); - } - int32_t num_elements() const { return header_->num_elements; } // Updates checksum of the vector contents and returns it. @@ -260,6 +352,66 @@ class FileBackedVector { // INTERNAL_ERROR if the vector's internal state is inconsistent libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); + public: + class MutableArrayView { + public: + const T& operator[](int32_t idx) const { return data_[idx]; } + T& operator[](int32_t idx) { + SetDirty(idx); + return data_[idx]; + } + + const T* data() const { return data_; } + + int32_t size() const { return len_; } + + // Set the mutable array slice (starting at idx) by the given element array. + // It handles SetDirty properly for the file-backed-vector when modifying + // elements. + // + // REQUIRES: arr is valid && arr_len >= 0 && idx + arr_len <= size(), + // otherwise the behavior is undefined. + void SetArray(int32_t idx, const T* arr, int32_t arr_len) { + for (int32_t i = 0; i < arr_len; ++i) { + SetDirty(idx + i); + data_[idx + i] = arr[i]; + } + } + + private: + MutableArrayView(FileBackedVector<T>* vector, T* data, int32_t len) + : vector_(vector), + data_(data), + original_idx_(data - vector->array()), + len_(len) {} + + void SetDirty(int32_t idx) { vector_->SetDirty(original_idx_ + idx); } + + // Does not own. For SetDirty only. 
+ FileBackedVector<T>* vector_; + + // data_ points at vector_->mutable_array()[original_idx_] + T* data_; + int32_t original_idx_; + int32_t len_; + + friend class FileBackedVector; + }; + + class MutableView { + public: + const T& Get() const { return mutable_array_view_[0]; } + T& Get() { return mutable_array_view_[0]; } + + private: + MutableView(FileBackedVector<T>* vector, T* data) + : mutable_array_view_(vector, data, 1) {} + + MutableArrayView mutable_array_view_; + + friend class FileBackedVector; + }; + private: // We track partial updates to the array for crc updating. This // requires extra memory to keep track of original buffers but @@ -271,24 +423,33 @@ class FileBackedVector { // Grow file by at least this many elements if array is growable. static constexpr int64_t kGrowElements = 1u << 14; // 16K - // Max number of elements that can be held by the vector. - static constexpr int64_t kMaxNumElements = 1u << 20; // 1M + // Absolute max # of elements allowed. Since we are using int32_t to store + // num_elements, max value is 2^31-1. Still the actual max # of elements are + // determined by max_file_size, kElementTypeSize, and Header::kHeaderSize. + static constexpr int32_t kMaxNumElements = + std::numeric_limits<int32_t>::max(); + + // Absolute max index allowed. + static constexpr int32_t kMaxIndex = kMaxNumElements - 1; // Can only be created through the factory ::Create function FileBackedVector(const Filesystem& filesystem, const std::string& file_path, std::unique_ptr<Header> header, - std::unique_ptr<MemoryMappedFile> mmapped_file); + std::unique_ptr<MemoryMappedFile> mmapped_file, + int32_t max_file_size); // Initialize a new FileBackedVector, and create the file. static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> InitializeNewFile(const Filesystem& filesystem, const std::string& file_path, - ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy); + ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size); // Initialize a FileBackedVector from an existing file. static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> InitializeExistingFile(const Filesystem& filesystem, const std::string& file_path, ScopedFd fd, - MemoryMappedFile::Strategy mmap_strategy); + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size); // Grows the underlying file to hold at least num_elements // @@ -296,6 +457,10 @@ class FileBackedVector { // OUT_OF_RANGE_ERROR if we can't grow to the specified size libtextclassifier3::Status GrowIfNecessary(int32_t num_elements); + T* mutable_array() const { + return reinterpret_cast<T*>(mmapped_file_->mutable_region()); + } + // Cached constructor params. const Filesystem* const filesystem_; const std::string file_path_; @@ -314,25 +479,42 @@ class FileBackedVector { // update. Will be cleared if the size grows too big. std::string saved_original_buffer_; - // Keep track of all pages we touched so we can write them back to - // disk. - std::vector<bool> dirty_pages_; + // Max file size for FileBackedVector, default kMaxFileSize. Note that this + // value won't be written into the header, so maximum file size will always be + // specified in runtime and the caller should make sure its value is correct + // and reasonable. Note that file size includes size of header + elements. + // + // The range should be in + // [Header::kHeaderSize + kElementTypeSize, kMaxFileSize], and + // (max_file_size_ - Header::kHeaderSize) / kElementTypeSize is max # of + // elements that can be stored. 
+ int32_t max_file_size_; }; template <typename T> +constexpr int32_t FileBackedVector<T>::kMaxFileSize; + +template <typename T> +constexpr int32_t FileBackedVector<T>::kElementTypeSize; + +template <typename T> constexpr int32_t FileBackedVector<T>::kPartialCrcLimitDiv; template <typename T> constexpr int64_t FileBackedVector<T>::kGrowElements; template <typename T> -constexpr int64_t FileBackedVector<T>::kMaxNumElements; +constexpr int32_t FileBackedVector<T>::kMaxNumElements; + +template <typename T> +constexpr int32_t FileBackedVector<T>::kMaxIndex; template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> FileBackedVector<T>::Create(const Filesystem& filesystem, const std::string& file_path, - MemoryMappedFile::Strategy mmap_strategy) { + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { if (mmap_strategy == MemoryMappedFile::Strategy::READ_WRITE_MANUAL_SYNC) { // FileBackedVector's behavior of growing the file underneath the mmap is // inherently broken with MAP_PRIVATE. Growing the vector requires extending @@ -345,6 +527,14 @@ FileBackedVector<T>::Create(const Filesystem& filesystem, "mmap strategy."); } + if (max_file_size < Header::kHeaderSize + kElementTypeSize || + max_file_size > kMaxFileSize) { + // FileBackedVector should be able to store at least 1 element, so + // max_file_size should be at least Header::kHeaderSize + kElementTypeSize. + return absl_ports::InvalidArgumentError( + "Invalid max file size for FileBackedVector"); + } + ScopedFd fd(filesystem.OpenForWrite(file_path.c_str())); if (!fd.is_valid()) { return absl_ports::InternalError( @@ -357,31 +547,38 @@ FileBackedVector<T>::Create(const Filesystem& filesystem, absl_ports::StrCat("Bad file size for file ", file_path)); } + if (max_file_size < file_size) { + return absl_ports::InvalidArgumentError( + "Max file size should not be smaller than the existing file size"); + } + const bool new_file = file_size == 0; if (new_file) { return InitializeNewFile(filesystem, file_path, std::move(fd), - mmap_strategy); + mmap_strategy, max_file_size); } return InitializeExistingFile(filesystem, file_path, std::move(fd), - mmap_strategy); + mmap_strategy, max_file_size); } template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> -FileBackedVector<T>::InitializeNewFile( - const Filesystem& filesystem, const std::string& file_path, ScopedFd fd, - MemoryMappedFile::Strategy mmap_strategy) { +FileBackedVector<T>::InitializeNewFile(const Filesystem& filesystem, + const std::string& file_path, + ScopedFd fd, + MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { // Create header. auto header = std::make_unique<Header>(); header->magic = FileBackedVector<T>::Header::kMagic; - header->element_size = sizeof(T); + header->element_size = kElementTypeSize; header->header_checksum = header->CalculateHeaderChecksum(); // We use Write() here, instead of writing through the mmapped region // created below, so we can gracefully handle errors that occur when the // disk is full. See b/77309668 for details. 
if (!filesystem.PWrite(fd.get(), /*offset=*/0, header.get(), - sizeof(Header))) { + Header::kHeaderSize)) { return absl_ports::InternalError("Failed to write header"); } @@ -393,23 +590,30 @@ FileBackedVector<T>::InitializeNewFile( auto mmapped_file = std::make_unique<MemoryMappedFile>(filesystem, file_path, mmap_strategy); - return std::unique_ptr<FileBackedVector<T>>(new FileBackedVector<T>( - filesystem, file_path, std::move(header), std::move(mmapped_file))); + return std::unique_ptr<FileBackedVector<T>>( + new FileBackedVector<T>(filesystem, file_path, std::move(header), + std::move(mmapped_file), max_file_size)); } template <typename T> libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> FileBackedVector<T>::InitializeExistingFile( const Filesystem& filesystem, const std::string& file_path, - const ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy) { + const ScopedFd fd, MemoryMappedFile::Strategy mmap_strategy, + int32_t max_file_size) { int64_t file_size = filesystem.GetFileSize(file_path.c_str()); - if (file_size < sizeof(FileBackedVector<T>::Header)) { + if (file_size == Filesystem::kBadFileSize) { + return absl_ports::InternalError( + absl_ports::StrCat("Bad file size for file ", file_path)); + } + + if (file_size < Header::kHeaderSize) { return absl_ports::InternalError( absl_ports::StrCat("File header too short for ", file_path)); } auto header = std::make_unique<Header>(); - if (!filesystem.PRead(fd.get(), header.get(), sizeof(Header), + if (!filesystem.PRead(fd.get(), header.get(), Header::kHeaderSize, /*offset=*/0)) { return absl_ports::InternalError( absl_ports::StrCat("Failed to read header of ", file_path)); @@ -429,13 +633,15 @@ FileBackedVector<T>::InitializeExistingFile( absl_ports::StrCat("Invalid header crc for ", file_path)); } - if (header->element_size != sizeof(T)) { + if (header->element_size != kElementTypeSize) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( - "Inconsistent element size, expected %zd, actual %d", sizeof(T), + "Inconsistent element size, expected %d, actual %d", kElementTypeSize, header->element_size)); } - int64_t min_file_size = header->num_elements * sizeof(T) + sizeof(Header); + int64_t min_file_size = + static_cast<int64_t>(header->num_elements) * kElementTypeSize + + Header::kHeaderSize; if (min_file_size > file_size) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Inconsistent file size, expected %" PRId64 ", actual %" PRId64, @@ -446,23 +652,22 @@ FileBackedVector<T>::InitializeExistingFile( // access elements from the mmapped region auto mmapped_file = std::make_unique<MemoryMappedFile>(filesystem, file_path, mmap_strategy); - ICING_RETURN_IF_ERROR( - mmapped_file->Remap(sizeof(Header), file_size - sizeof(Header))); + ICING_RETURN_IF_ERROR(mmapped_file->Remap(Header::kHeaderSize, + file_size - Header::kHeaderSize)); // Check vector contents - Crc32 vector_checksum; - std::string_view vector_contents( - reinterpret_cast<const char*>(mmapped_file->region()), - header->num_elements * sizeof(T)); - vector_checksum.Append(vector_contents); + Crc32 vector_checksum( + std::string_view(reinterpret_cast<const char*>(mmapped_file->region()), + header->num_elements * kElementTypeSize)); if (vector_checksum.Get() != header->vector_checksum) { return absl_ports::FailedPreconditionError( absl_ports::StrCat("Invalid vector contents for ", file_path)); } - return std::unique_ptr<FileBackedVector<T>>(new FileBackedVector<T>( - filesystem, file_path, std::move(header), std::move(mmapped_file))); + 
return std::unique_ptr<FileBackedVector<T>>( + new FileBackedVector<T>(filesystem, file_path, std::move(header), + std::move(mmapped_file), max_file_size)); } template <typename T> @@ -479,12 +684,13 @@ template <typename T> FileBackedVector<T>::FileBackedVector( const Filesystem& filesystem, const std::string& file_path, std::unique_ptr<Header> header, - std::unique_ptr<MemoryMappedFile> mmapped_file) + std::unique_ptr<MemoryMappedFile> mmapped_file, int32_t max_file_size) : filesystem_(&filesystem), file_path_(file_path), header_(std::move(header)), mmapped_file_(std::move(mmapped_file)), - changes_end_(header_->num_elements) {} + changes_end_(header_->num_elements), + max_file_size_(max_file_size) {} template <typename T> FileBackedVector<T>::~FileBackedVector() { @@ -523,6 +729,40 @@ libtextclassifier3::StatusOr<const T*> FileBackedVector<T>::Get( } template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableView> +FileBackedVector<T>::GetMutable(int32_t idx) { + if (idx < 0) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); + } + + if (idx >= header_->num_elements) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index, %d, was greater than vector size, %d", idx, + header_->num_elements)); + } + + return MutableView(this, &mutable_array()[idx]); +} + +template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableArrayView> +FileBackedVector<T>::GetMutable(int32_t idx, int32_t len) { + if (idx < 0) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); + } + + if (idx > header_->num_elements - len) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index with len, %d %d, was greater than vector size, %d", idx, len, + header_->num_elements)); + } + + return MutableArrayView(this, &mutable_array()[idx], len); +} + +template <typename T> libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, const T& value) { if (idx < 0) { @@ -530,6 +770,11 @@ libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); } + if (idx > kMaxIndex) { + return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( + "Index, %d, was greater than max index allowed, %d", idx, kMaxIndex)); + } + ICING_RETURN_IF_ERROR(GrowIfNecessary(idx + 1)); if (idx + 1 > header_->num_elements) { @@ -541,36 +786,39 @@ libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, return libtextclassifier3::Status::OK; } - // Cache original value to update crcs. - if (idx < changes_end_) { - // If we exceed kPartialCrcLimitDiv, clear changes_end_ to - // revert to full CRC. 
- if ((saved_original_buffer_.size() + sizeof(T)) * - FileBackedVector<T>::kPartialCrcLimitDiv > - changes_end_ * sizeof(T)) { - ICING_VLOG(2) << "FileBackedVector change tracking limit exceeded"; - changes_.clear(); - saved_original_buffer_.clear(); - changes_end_ = 0; - header_->vector_checksum = 0; - } else { - int32_t start_byte = idx * sizeof(T); - - changes_.push_back(idx); - saved_original_buffer_.append( - reinterpret_cast<char*>(const_cast<T*>(array())) + start_byte, - sizeof(T)); - } - } + SetDirty(idx); mutable_array()[idx] = value; return libtextclassifier3::Status::OK; } template <typename T> +libtextclassifier3::StatusOr<typename FileBackedVector<T>::MutableArrayView> +FileBackedVector<T>::Allocate(int32_t len) { + if (len <= 0) { + return absl_ports::OutOfRangeError("Invalid allocate length"); + } + + if (len > kMaxNumElements - header_->num_elements) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Cannot allocate %d elements", len)); + } + + // Although header_->num_elements + len doesn't exceed kMaxNumElements, the + // actual max # of elements are determined by max_file_size, kElementTypeSize, + // and kHeaderSize. Thus, it is still possible to fail to grow the file. + ICING_RETURN_IF_ERROR(GrowIfNecessary(header_->num_elements + len)); + + int32_t start_idx = header_->num_elements; + header_->num_elements += len; + + return MutableArrayView(this, &mutable_array()[start_idx], len); +} + +template <typename T> libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( int32_t num_elements) { - if (sizeof(T) == 0) { + if (kElementTypeSize == 0) { // Growing is a no-op return libtextclassifier3::Status::OK; } @@ -579,10 +827,12 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( return libtextclassifier3::Status::OK; } - if (num_elements > FileBackedVector<T>::kMaxNumElements) { + if (num_elements > + (max_file_size_ - Header::kHeaderSize) / kElementTypeSize) { return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( - "%d exceeds maximum number of elements allowed, %lld", num_elements, - static_cast<long long>(FileBackedVector<T>::kMaxNumElements))); + "%d elements total size exceed maximum bytes of elements allowed, " + "%d bytes", + num_elements, max_file_size_ - Header::kHeaderSize)); } int64_t current_file_size = filesystem_->GetFileSize(file_path_.c_str()); @@ -590,7 +840,8 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( return absl_ports::InternalError("Unable to retrieve file size."); } - int64_t least_file_size_needed = sizeof(Header) + num_elements * sizeof(T); + int32_t least_file_size_needed = + Header::kHeaderSize + num_elements * kElementTypeSize; // Won't overflow if (least_file_size_needed <= current_file_size) { // Our underlying file can hold the target num_elements cause we've grown // before @@ -598,9 +849,13 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( } // Otherwise, we need to grow. Grow to kGrowElements boundary. - least_file_size_needed = math_util::RoundUpTo( - least_file_size_needed, - int64_t{FileBackedVector<T>::kGrowElements * sizeof(T)}); + // Note that we need to use int64_t here, since int32_t might overflow after + // round up. 
+ int64_t round_up_file_size_needed = math_util::RoundUpTo( + int64_t{least_file_size_needed}, + int64_t{FileBackedVector<T>::kGrowElements} * kElementTypeSize); + least_file_size_needed = + std::min(round_up_file_size_needed, int64_t{max_file_size_}); // We use PWrite here rather than Grow because Grow doesn't actually allocate // an underlying disk block. This can lead to problems with mmap because mmap @@ -609,20 +864,22 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( // these blocks, which will ensure that any failure to grow will surface here. int64_t page_size = getpagesize(); auto buf = std::make_unique<uint8_t[]>(page_size); - int64_t size_to_write = page_size - (current_file_size % page_size); + int64_t size_to_write = std::min(page_size - (current_file_size % page_size), + max_file_size_ - current_file_size); ScopedFd sfd(filesystem_->OpenForWrite(file_path_.c_str())); - while (current_file_size < least_file_size_needed) { + while (size_to_write > 0 && current_file_size < least_file_size_needed) { if (!filesystem_->PWrite(sfd.get(), current_file_size, buf.get(), size_to_write)) { return absl_ports::InternalError( absl_ports::StrCat("Couldn't grow file ", file_path_)); } current_file_size += size_to_write; - size_to_write = page_size - (current_file_size % page_size); + size_to_write = std::min(page_size - (current_file_size % page_size), + max_file_size_ - current_file_size); } ICING_RETURN_IF_ERROR(mmapped_file_->Remap( - sizeof(Header), least_file_size_needed - sizeof(Header))); + Header::kHeaderSize, least_file_size_needed - Header::kHeaderSize)); return libtextclassifier3::Status::OK; } @@ -653,6 +910,31 @@ libtextclassifier3::Status FileBackedVector<T>::TruncateTo( } template <typename T> +void FileBackedVector<T>::SetDirty(int32_t idx) { + // Cache original value to update crcs. + if (idx >= 0 && idx < changes_end_) { + // If we exceed kPartialCrcLimitDiv, clear changes_end_ to + // revert to full CRC. + if ((saved_original_buffer_.size() + kElementTypeSize) * + FileBackedVector<T>::kPartialCrcLimitDiv > + changes_end_ * kElementTypeSize) { + ICING_VLOG(2) << "FileBackedVector change tracking limit exceeded"; + changes_.clear(); + saved_original_buffer_.clear(); + changes_end_ = 0; + header_->vector_checksum = 0; + } else { + int32_t start_byte = idx * kElementTypeSize; + + changes_.push_back(idx); + saved_original_buffer_.append( + reinterpret_cast<char*>(const_cast<T*>(array())) + start_byte, + kElementTypeSize); + } + } +} + +template <typename T> libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { // First apply the modified area. Keep a bitmap of already updated // regions so we don't double-update. @@ -663,8 +945,7 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { int num_truncated = 0; int num_overlapped = 0; int num_duplicate = 0; - for (size_t i = 0; i < changes_.size(); i++) { - const int32_t change_offset = changes_[i]; + for (const int32_t change_offset : changes_) { if (change_offset > changes_end_) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to update crc, change offset %d, changes_end_ %d", @@ -678,9 +959,10 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { } // Turn change buffer into change^original. 
- const char* buffer_end = &saved_original_buffer_[cur_offset + sizeof(T)]; - const char* cur_array = - reinterpret_cast<const char*>(array()) + change_offset * sizeof(T); + const char* buffer_end = + &saved_original_buffer_[cur_offset + kElementTypeSize]; + const char* cur_array = reinterpret_cast<const char*>(array()) + + change_offset * kElementTypeSize; // Now xor in. SSE acceleration please? for (char* cur = &saved_original_buffer_[cur_offset]; cur < buffer_end; cur++, cur_array++) { @@ -692,9 +974,9 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { bool overlap = false; uint32_t cur_element = change_offset; for (char* cur = &saved_original_buffer_[cur_offset]; cur < buffer_end; - cur_element++, cur += sizeof(T)) { + cur_element++, cur += kElementTypeSize) { if (updated[cur_element]) { - memset(cur, 0, sizeof(T)); + memset(cur, 0, kElementTypeSize); overlap = true; } else { updated[cur_element] = true; @@ -705,10 +987,11 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { // Apply update to crc. if (new_update) { // Explicitly create the string_view with length - std::string_view xored_str(buffer_end - sizeof(T), sizeof(T)); + std::string_view xored_str(buffer_end - kElementTypeSize, + kElementTypeSize); if (!cur_crc - .UpdateWithXor(xored_str, changes_end_ * sizeof(T), - change_offset * sizeof(T)) + .UpdateWithXor(xored_str, changes_end_ * kElementTypeSize, + change_offset * kElementTypeSize) .ok()) { return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to update crc, change offset %d, change " @@ -722,7 +1005,7 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { } else { num_duplicate++; } - cur_offset += sizeof(T); + cur_offset += kElementTypeSize; } if (!changes_.empty()) { @@ -735,8 +1018,9 @@ libtextclassifier3::StatusOr<Crc32> FileBackedVector<T>::ComputeChecksum() { if (changes_end_ < header_->num_elements) { // Explicitly create the string_view with length std::string_view update_str( - reinterpret_cast<const char*>(array()) + changes_end_ * sizeof(T), - (header_->num_elements - changes_end_) * sizeof(T)); + reinterpret_cast<const char*>(array()) + + changes_end_ * kElementTypeSize, + (header_->num_elements - changes_end_) * kElementTypeSize); cur_crc.Append(update_str); ICING_VLOG(2) << IcingStringUtil::StringPrintf( "Array update tail crc offset %d -> %d", changes_end_, @@ -761,7 +1045,7 @@ libtextclassifier3::Status FileBackedVector<T>::PersistToDisk() { header_->header_checksum = header_->CalculateHeaderChecksum(); if (!filesystem_->PWrite(file_path_.c_str(), /*offset=*/0, header_.get(), - sizeof(Header))) { + Header::kHeaderSize)) { return absl_ports::InternalError("Failed to sync header"); } @@ -795,7 +1079,11 @@ libtextclassifier3::StatusOr<int64_t> FileBackedVector<T>::GetElementsFileSize() return absl_ports::InternalError( "Failed to get file size of elements in the file-backed vector"); } - return total_file_size - sizeof(Header); + if (total_file_size < Header::kHeaderSize) { + return absl_ports::InternalError( + "File size should not be smaller than header size"); + } + return total_file_size - Header::kHeaderSize; } } // namespace lib diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc index 2f60c6b..60ed887 100644 --- a/icing/file/file-backed-vector_test.cc +++ b/icing/file/file-backed-vector_test.cc @@ -19,7 +19,9 @@ #include <algorithm> #include <cerrno> #include <cstdint> +#include <limits> #include <memory> +#include 
<string> #include <string_view> #include <vector> @@ -34,10 +36,14 @@ #include "icing/util/crc32.h" #include "icing/util/logging.h" +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsTrue; +using ::testing::Lt; +using ::testing::Not; using ::testing::Pointee; using ::testing::Return; +using ::testing::SizeIs; namespace icing { namespace lib { @@ -60,20 +66,30 @@ class FileBackedVectorTest : public testing::Test { // Helper method to loop over some data and insert into the vector at some idx template <typename T> - void Insert(FileBackedVector<T>* vector, int32_t idx, std::string data) { - for (int i = 0; i < data.length(); ++i) { + void Insert(FileBackedVector<T>* vector, int32_t idx, + const std::vector<T>& data) { + for (int i = 0; i < data.size(); ++i) { ICING_ASSERT_OK(vector->Set(idx + i, data.at(i))); } } + void Insert(FileBackedVector<char>* vector, int32_t idx, std::string data) { + Insert(vector, idx, std::vector<char>(data.begin(), data.end())); + } + // Helper method to retrieve data from the beginning of the vector template <typename T> - std::string_view Get(FileBackedVector<T>* vector, int32_t expected_len) { + std::vector<T> Get(FileBackedVector<T>* vector, int32_t idx, + int32_t expected_len) { + return std::vector<T>(vector->array() + idx, + vector->array() + idx + expected_len); + } + + std::string_view Get(FileBackedVector<char>* vector, int32_t expected_len) { return Get(vector, 0, expected_len); } - template <typename T> - std::string_view Get(FileBackedVector<T>* vector, int32_t idx, + std::string_view Get(FileBackedVector<char>* vector, int32_t idx, int32_t expected_len) { return std::string_view(vector->array() + idx, expected_len); } @@ -103,6 +119,79 @@ TEST_F(FileBackedVectorTest, Create) { } } +TEST_F(FileBackedVectorTest, CreateWithInvalidStrategy) { + // Create a vector with unimplemented strategy + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_MANUAL_SYNC), + StatusIs(libtextclassifier3::StatusCode::UNIMPLEMENTED)); +} + +TEST_F(FileBackedVectorTest, CreateWithCustomMaxFileSize) { + int32_t header_size = FileBackedVector<char>::Header::kHeaderSize; + + // Create a vector with invalid max_file_size + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/-1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + { + // Create a vector with max_file_size that allows only 1 element. + ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) * 1)); + ICING_ASSERT_OK(vector->Set(0, 'a')); + } + + { + // We can create it again with larger max_file_size, as long as it is not + // greater than kMaxFileSize. 
+ ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/header_size + sizeof(char) * 2)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + ICING_ASSERT_OK(vector->Set(1, 'b')); + } + + // We cannot create it again with max_file_size < current_file_size, even if + // it is a valid value. + int64_t current_file_size = filesystem_.GetFileSize(file_path_.c_str()); + ASSERT_THAT(current_file_size, Eq(header_size + sizeof(char) * 2)); + ASSERT_THAT(current_file_size - 1, Not(Lt(header_size + sizeof(char)))); + EXPECT_THAT(FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/current_file_size - 1), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + + { + // We can create it again with max_file_size == current_file_size. + ICING_ASSERT_OK_AND_ASSIGN( + auto vector, FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, + /*max_file_size=*/current_file_size)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + EXPECT_THAT(vector->Get(1), IsOkAndHolds(Pointee(Eq('b')))); + } +} + TEST_F(FileBackedVectorTest, SimpleShared) { // Create a vector and add some data. ICING_ASSERT_OK_AND_ASSIGN( @@ -195,6 +284,373 @@ TEST_F(FileBackedVectorTest, Get) { StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); } +TEST_F(FileBackedVectorTest, MutableView) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'a')); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2620640643U))); + + ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<char>::MutableView mutable_elt, + vector->GetMutable(3)); + + mutable_elt.Get() = 'b'; + EXPECT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('b')))); + + mutable_elt.Get() = 'c'; + EXPECT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('c')))); +} + +TEST_F(FileBackedVectorTest, MutableViewShouldSetDirty) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'a')); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2620640643U))); + + std::string_view reconstructed_view = + std::string_view(vector->array(), vector->num_elements()); + + ICING_ASSERT_OK_AND_ASSIGN(FileBackedVector<char>::MutableView mutable_elt, + vector->GetMutable(3)); + + // Mutate the element via MutateView + // If non-const Get() is called, MutateView should set the element index dirty + // so that ComputeChecksum() can pick up the change and compute the checksum + // correctly. Validate by mapping another array on top. + mutable_elt.Get() = 'b'; + ASSERT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('b')))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + Crc32 full_crc1; + full_crc1.Append(reconstructed_view); + EXPECT_THAT(crc1, Eq(full_crc1)); + + // Mutate and test again. 
+ mutable_elt.Get() = 'c'; + ASSERT_THAT(vector->Get(3), IsOkAndHolds(Pointee(Eq('c')))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + Crc32 full_crc2; + full_crc2.Append(reconstructed_view); + EXPECT_THAT(crc2, Eq(full_crc2)); +} + +TEST_F(FileBackedVectorTest, MutableArrayView) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, /*len=*/3)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + + mutable_arr[0] = 2; + mutable_arr[1] = 3; + mutable_arr[2] = 4; + + EXPECT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(2)))); + EXPECT_THAT(mutable_arr.data()[0], Eq(2)); + + EXPECT_THAT(vector->Get(kArrayViewOffset + 1), IsOkAndHolds(Pointee(Eq(3)))); + EXPECT_THAT(mutable_arr.data()[1], Eq(3)); + + EXPECT_THAT(vector->Get(kArrayViewOffset + 2), IsOkAndHolds(Pointee(Eq(4)))); + EXPECT_THAT(mutable_arr.data()[2], Eq(4)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArray) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + std::vector<int> change1{2, 3, 4}; + mutable_arr.SetArray(/*idx=*/0, change1.data(), change1.size()); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 4, 1, 1)); + + std::vector<int> change2{5, 6}; + mutable_arr.SetArray(/*idx=*/2, change2.data(), change2.size()); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 5, 6, 1)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArrayWithZeroLength) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + // Zero arr_len should work and change nothing + std::vector<int> change{2, 3}; + mutable_arr.SetArray(/*idx=*/0, change.data(), /*arr_len=*/0); + EXPECT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(1, 1, 1, 1, 1)); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewIndexOperatorShouldSetDirty) { + // Create an array with some data. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + std::string_view reconstructed_view( + reinterpret_cast<const char*>(vector->array()), + vector->num_elements() * sizeof(int)); + + constexpr int kArrayViewOffset = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, /*len=*/3)); + + // Use operator[] to mutate elements + // If non-const operator[] is called, MutateView should set the element index + // dirty so that ComputeChecksum() can pick up the change and compute the + // checksum correctly. Validate by mapping another array on top. + mutable_arr[0] = 2; + ASSERT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(2)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + EXPECT_THAT(crc1, Eq(Crc32(reconstructed_view))); + + mutable_arr[1] = 3; + ASSERT_THAT(vector->Get(kArrayViewOffset + 1), IsOkAndHolds(Pointee(Eq(3)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + EXPECT_THAT(crc2, Eq(Crc32(reconstructed_view))); + + mutable_arr[2] = 4; + ASSERT_THAT(vector->Get(kArrayViewOffset + 2), IsOkAndHolds(Pointee(Eq(4)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc3, vector->ComputeChecksum()); + EXPECT_THAT(crc3, Eq(Crc32(reconstructed_view))); + + // Change the same position. It should set dirty again. + mutable_arr[0] = 5; + ASSERT_THAT(vector->Get(kArrayViewOffset + 0), IsOkAndHolds(Pointee(Eq(5)))); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc4, vector->ComputeChecksum()); + EXPECT_THAT(crc4, Eq(Crc32(reconstructed_view))); +} + +TEST_F(FileBackedVectorTest, MutableArrayViewSetArrayShouldSetDirty) { + // Create an array with some data. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::vector<int>(/*count=*/100, /*value=*/1)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(2494890115U))); + + std::string_view reconstructed_view( + reinterpret_cast<const char*>(vector->array()), + vector->num_elements() * sizeof(int)); + + constexpr int kArrayViewOffset = 3; + constexpr int kArrayViewLen = 5; + ICING_ASSERT_OK_AND_ASSIGN( + FileBackedVector<int>::MutableArrayView mutable_arr, + vector->GetMutable(kArrayViewOffset, kArrayViewLen)); + + std::vector<int> change{2, 3, 4}; + mutable_arr.SetArray(/*idx=*/0, change.data(), change.size()); + ASSERT_THAT(Get(vector.get(), kArrayViewOffset, kArrayViewLen), + ElementsAre(2, 3, 4, 1, 1)); + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc, vector->ComputeChecksum()); + EXPECT_THAT(crc, Eq(Crc32(reconstructed_view))); +} + +TEST_F(FileBackedVectorTest, Append) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(1)); + EXPECT_THAT(vector->Get(0), IsOkAndHolds(Pointee(Eq('a')))); + + ICING_EXPECT_OK(vector->Append('b')); + EXPECT_THAT(vector->num_elements(), Eq(2)); + EXPECT_THAT(vector->Get(1), IsOkAndHolds(Pointee(Eq('b')))); +} + +TEST_F(FileBackedVectorTest, AppendAfterSet) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK(vector->Set(9, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(10)); + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(11)); + EXPECT_THAT(vector->Get(10), IsOkAndHolds(Pointee(Eq('a')))); +} + +TEST_F(FileBackedVectorTest, AppendAfterTruncate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(1000)); + + ICING_ASSERT_OK(vector->TruncateTo(5)); + ICING_EXPECT_OK(vector->Append('a')); + EXPECT_THAT(vector->num_elements(), Eq(6)); + EXPECT_THAT(vector->Get(5), IsOkAndHolds(Pointee(Eq('a')))); +} + +TEST_F(FileBackedVectorTest, AppendShouldFailIfExceedingMaxFileSize) { + int32_t max_file_size = (1 << 10) - 1; + int32_t max_num_elements = + (max_file_size - FileBackedVector<char>::Header::kHeaderSize) / + sizeof(char); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); + ICING_ASSERT_OK(vector->Set(max_num_elements - 1, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(max_num_elements)); + + EXPECT_THAT(vector->Append('a'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + +TEST_F(FileBackedVectorTest, Allocate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + 
MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(3)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateAfterSet) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + ICING_ASSERT_OK(vector->Set(9, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(10)); + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(13)); + EXPECT_THAT(mutable_arr, SizeIs(3)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/10, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateAfterTruncate) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), /*idx=*/0, std::string(1000, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(1000)); + + ICING_ASSERT_OK(vector->TruncateTo(5)); + ICING_ASSERT_OK_AND_ASSIGN( + typename FileBackedVector<char>::MutableArrayView mutable_arr, + vector->Allocate(3)); + EXPECT_THAT(vector->num_elements(), Eq(8)); + std::string change = "abc"; + mutable_arr.SetArray(/*idx=*/0, /*arr=*/change.data(), /*arr_len=*/3); + EXPECT_THAT(Get(vector.get(), /*idx=*/5, /*expected_len=*/3), Eq(change)); +} + +TEST_F(FileBackedVectorTest, AllocateInvalidLengthShouldFail) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ASSERT_THAT(vector->num_elements(), Eq(0)); + + EXPECT_THAT(vector->Allocate(-1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->num_elements(), Eq(0)); + + EXPECT_THAT(vector->Allocate(0), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->num_elements(), Eq(0)); +} + +TEST_F(FileBackedVectorTest, AllocateShouldFailIfExceedingMaxFileSize) { + int32_t max_file_size = (1 << 10) - 1; + int32_t max_num_elements = + (max_file_size - FileBackedVector<char>::Header::kHeaderSize) / + sizeof(char); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); + ICING_ASSERT_OK(vector->Set(max_num_elements - 3, 'z')); + ASSERT_THAT(vector->num_elements(), Eq(max_num_elements - 2)); + + EXPECT_THAT(vector->Allocate(3), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Allocate(2), IsOk()); +} + TEST_F(FileBackedVectorTest, IncrementalCrc_NonOverlappingChanges) { int num_elements = 1000; int incremental_size = 3; @@ -272,29 +728,58 @@ TEST_F(FileBackedVectorTest, IncrementalCrc_OverlappingChanges) { } } 
+TEST_F(FileBackedVectorTest, SetIntMaxShouldReturnOutOfRangeError) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int32_t>> vector, + FileBackedVector<int32_t>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); + + // It is an edge case. Since Set() calls GrowIfNecessary(idx + 1), we have to + // make sure that when idx is INT32_MAX, Set() should handle it correctly. + EXPECT_THAT(vector->Set(std::numeric_limits<int32_t>::max(), 1), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + TEST_F(FileBackedVectorTest, Grow) { - // This is the same value as FileBackedVector::kMaxNumElts - constexpr int32_t kMaxNumElts = 1U << 20; + int32_t max_file_size = (1 << 20) - 1; + int32_t header_size = FileBackedVector<int32_t>::Header::kHeaderSize; + int32_t element_type_size = static_cast<int32_t>(sizeof(int32_t)); + + // Max file size includes size of the header and elements, so max # of + // elements will be (max_file_size - header_size) / element_type_size. + // + // Also ensure that (max_file_size - header_size) is not a multiple of + // element_type_size, in order to test if the desired # of elements is + // computed by (math) floor instead of ceil. + ASSERT_THAT((max_file_size - header_size) % element_type_size, Not(Eq(0))); + int32_t max_num_elements = (max_file_size - header_size) / element_type_size; ASSERT_TRUE(filesystem_.Truncate(fd_, 0)); - // Create an array and add some data. + // Create a vector and add some data. ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<FileBackedVector<char>> vector, - FileBackedVector<char>::Create( + std::unique_ptr<FileBackedVector<int32_t>> vector, + FileBackedVector<int32_t>::Create( filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); - EXPECT_THAT(vector->Set(kMaxNumElts + 11, 'a'), + // max_num_elements is the allowed max # of elements, so the valid index + // should be 0 to max_num_elements-1. + EXPECT_THAT(vector->Set(max_num_elements, 1), StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); - EXPECT_THAT(vector->Set(-1, 'a'), + EXPECT_THAT(vector->Set(-1, 1), StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Set(max_num_elements - 1, 1), IsOk()); - uint32_t start = kMaxNumElts - 13; - Insert(vector.get(), start, "abcde"); + int32_t start = max_num_elements - 5; + std::vector<int32_t> data{1, 2, 3, 4, 5}; + Insert(vector.get(), start, data); // Crc works? 
- const Crc32 good_crc(1134899064U); + const Crc32 good_crc(650981917U); EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(good_crc)); // PersistToDisk does nothing bad, and ensures the content is still there @@ -306,12 +791,12 @@ TEST_F(FileBackedVectorTest, Grow) { vector.reset(); ICING_ASSERT_OK_AND_ASSIGN( - vector, FileBackedVector<char>::Create( - filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + vector, + FileBackedVector<int32_t>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, max_file_size)); - std::string expected = "abcde"; - EXPECT_EQ(expected, Get(vector.get(), start, expected.length())); + EXPECT_THAT(Get(vector.get(), start, data.size()), Eq(data)); } TEST_F(FileBackedVectorTest, GrowsInChunks) { @@ -334,20 +819,20 @@ TEST_F(FileBackedVectorTest, GrowsInChunks) { // Once we add something though, we'll grow to be kGrowElements big. From this // point on, file size and disk usage should be the same because Growing will // explicitly allocate the number of blocks needed to accommodate the file. - Insert(vector.get(), 0, "a"); - int file_size = kGrowElements * sizeof(int); + Insert(vector.get(), 0, {1}); + int file_size = 1 * kGrowElements * sizeof(int); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); // Should still be the same size, don't need to grow underlying file - Insert(vector.get(), 1, "b"); + Insert(vector.get(), 1, {2}); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); // Now we grow by a kGrowElements chunk, so the underlying file is 2 // kGrowElements big - file_size *= 2; - Insert(vector.get(), 2, std::string(kGrowElements, 'c')); + file_size = 2 * kGrowElements * sizeof(int); + Insert(vector.get(), 2, std::vector<int>(kGrowElements, 3)); EXPECT_THAT(filesystem_.GetFileSize(fd_), Eq(file_size)); EXPECT_THAT(filesystem_.GetDiskUsage(fd_), Eq(file_size)); @@ -476,6 +961,48 @@ TEST_F(FileBackedVectorTest, TruncateAndReReadFile) { } } +TEST_F(FileBackedVectorTest, SetDirty) { + // 1. Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + Insert(vector.get(), 0, "abcd"); + + std::string_view reconstructed_view = + std::string_view(vector->array(), vector->num_elements()); + + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc1, vector->ComputeChecksum()); + Crc32 full_crc_before_overwrite; + full_crc_before_overwrite.Append(reconstructed_view); + EXPECT_THAT(crc1, Eq(full_crc_before_overwrite)); + + // 2. Manually overwrite the values of the first two elements. + std::string corrupted_content = "ef"; + ASSERT_THAT( + filesystem_.PWrite(fd_, /*offset=*/sizeof(FileBackedVector<char>::Header), + corrupted_content.c_str(), corrupted_content.length()), + IsTrue()); + ASSERT_THAT(Get(vector.get(), 0, 4), Eq("efcd")); + Crc32 full_crc_after_overwrite; + full_crc_after_overwrite.Append(reconstructed_view); + ASSERT_THAT(full_crc_before_overwrite, Not(Eq(full_crc_after_overwrite))); + + // 3. Without calling SetDirty(), the checksum will be recomputed incorrectly. + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc2, vector->ComputeChecksum()); + EXPECT_THAT(crc2, Not(Eq(full_crc_after_overwrite))); + + // 4. Call SetDirty() + vector->SetDirty(0); + vector->SetDirty(1); + + // 5.
The checksum should be computed correctly after calling SetDirty() with + // correct index. + ICING_ASSERT_OK_AND_ASSIGN(Crc32 crc3, vector->ComputeChecksum()); + EXPECT_THAT(crc3, Eq(full_crc_after_overwrite)); +} + TEST_F(FileBackedVectorTest, InitFileTooSmallForHeaderFails) { { // 1. Create a vector with a few elements. diff --git a/icing/file/filesystem.cc b/icing/file/filesystem.cc index 82b8d98..10b77db 100644 --- a/icing/file/filesystem.cc +++ b/icing/file/filesystem.cc @@ -63,18 +63,16 @@ void LogOpenFileDescriptors() { constexpr int kMaxFileDescriptorsToStat = 4096; struct rlimit rlim = {0, 0}; if (getrlimit(RLIMIT_NOFILE, &rlim) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "getrlimit() failed (errno=%d)", errno); + ICING_LOG(ERROR) << "getrlimit() failed (errno=" << errno << ")"; return; } int fd_lim = rlim.rlim_cur; if (fd_lim > kMaxFileDescriptorsToStat) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Maximum number of file descriptors (%d) too large.", fd_lim); + ICING_LOG(ERROR) << "Maximum number of file descriptors (" << fd_lim + << ") too large."; fd_lim = kMaxFileDescriptorsToStat; } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Listing up to %d file descriptors.", fd_lim); + ICING_LOG(ERROR) << "Listing up to " << fd_lim << " file descriptors."; // Verify that /proc/self/fd is a directory. If not, procfs is not mounted or // inaccessible for some other reason. In that case, there's no point trying @@ -96,15 +94,12 @@ void LogOpenFileDescriptors() { if (len >= 0) { // Zero-terminate the buffer, because readlink() won't. target[len < target_size ? len : target_size - 1] = '\0'; - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> \"%s\"", fd, - target); + ICING_LOG(ERROR) << "fd " << fd << " -> \"" << target << "\""; } else if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> ? (errno=%d)", - fd, errno); + ICING_LOG(ERROR) << "fd " << fd << " -> ? (errno=" << errno << ")"; } } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "File descriptor list complete."); + ICING_LOG(ERROR) << "File descriptor list complete."; } // Logs an error formatted as: desc1 + file_name + desc2 + strerror(errnum). @@ -113,8 +108,7 @@ void LogOpenFileDescriptors() { // file descriptors (see LogOpenFileDescriptors() above). 
void LogOpenError(const char* desc1, const char* file_name, const char* desc2, int errnum) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "%s%s%s%s", desc1, file_name, desc2, strerror(errnum)); + ICING_LOG(ERROR) << desc1 << file_name << desc2 << strerror(errnum); if (errnum == EMFILE) { LogOpenFileDescriptors(); } @@ -155,8 +149,7 @@ bool ListDirectoryInternal(const char* dir_name, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Error closing %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << dir_name << " " << strerror(errno); } return true; } @@ -179,11 +172,10 @@ void ScopedFd::reset(int fd) { const int64_t Filesystem::kBadFileSize; bool Filesystem::DeleteFile(const char* file_name) const { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Deleting file %s", file_name); + ICING_VLOG(1) << "Deleting file " << file_name; int ret = unlink(file_name); if (ret != 0 && errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting file %s failed: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting file " << file_name << " failed: " << strerror(errno); return false; } return true; @@ -192,8 +184,7 @@ bool Filesystem::DeleteFile(const char* file_name) const { bool Filesystem::DeleteDirectory(const char* dir_name) const { int ret = rmdir(dir_name); if (ret != 0 && errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting directory " << dir_name << " failed: " << strerror(errno); return false; } return true; @@ -206,8 +197,7 @@ bool Filesystem::DeleteDirectoryRecursively(const char* dir_name) const { if (errno == ENOENT) { return true; // If directory didn't exist, this was successful. } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Stat " << dir_name << " failed: " << strerror(errno); return false; } vector<std::string> entries; @@ -220,8 +210,7 @@ bool Filesystem::DeleteDirectoryRecursively(const char* dir_name) const { ++i) { std::string filename = std::string(dir_name) + '/' + *i; if (stat(filename.c_str(), &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", filename.c_str(), strerror(errno)); + ICING_LOG(ERROR) << "Stat " << filename << " failed: " << strerror(errno); success = false; } else if (S_ISDIR(st.st_mode)) { success = DeleteDirectoryRecursively(filename.c_str()) && success; @@ -244,8 +233,7 @@ bool Filesystem::FileExists(const char* file_name) const { exists = S_ISREG(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << file_name << ": " << strerror(errno); } exists = false; } @@ -259,8 +247,7 @@ bool Filesystem::DirectoryExists(const char* dir_name) const { exists = S_ISDIR(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat directory %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat directory " << dir_name << ": " << strerror(errno); } exists = false; } @@ -316,8 +303,7 @@ bool Filesystem::GetMatchingFiles(const char* glob, int basename_idx = GetBasenameIndex(glob); if (basename_idx == 0) { // We need a directory. 
- ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Expected directory, no matching files for: %s", glob); + ICING_VLOG(1) << "Expected directory, no matching files for: " << glob; return true; } const char* basename_glob = glob + basename_idx; @@ -372,8 +358,7 @@ int Filesystem::OpenForRead(const char* file_name) const { int64_t Filesystem::GetFileSize(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_size; @@ -383,11 +368,9 @@ int64_t Filesystem::GetFileSize(const char* filename) const { struct stat st; if (stat(filename, &st) < 0) { if (errno == ENOENT) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_VLOG(1) << "Unable to stat file " << filename << ": " << strerror(errno); } else { - ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_LOG(WARNING) << "Unable to stat file " << filename << ": " << strerror(errno); } return kBadFileSize; } @@ -396,8 +379,7 @@ int64_t Filesystem::GetFileSize(const char* filename) const { bool Filesystem::Truncate(int fd, int64_t new_size) const { if (ftruncate(fd, new_size) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to truncate file: %s", strerror(errno)); + ICING_LOG(ERROR) << "Unable to truncate file: " << strerror(errno); return false; } lseek(fd, new_size, SEEK_SET); @@ -416,8 +398,7 @@ bool Filesystem::Truncate(const char* filename, int64_t new_size) const { bool Filesystem::Grow(int fd, int64_t new_size) const { if (ftruncate(fd, new_size) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to grow file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to grow file: " << strerror(errno); return false; } @@ -442,8 +423,7 @@ bool Filesystem::Write(int fd, const void* data, size_t data_size) const { size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = write(fd, data, chunk_size); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t*>(data) + wrote; @@ -521,8 +501,7 @@ bool Filesystem::CopyDirectory(const char* src_dir, const char* dst_dir, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Error closing %s: %s", - src_dir, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << src_dir << ": " << strerror(errno); } return true; } @@ -535,8 +514,7 @@ bool Filesystem::PWrite(int fd, off_t offset, const void* data, size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = pwrite(fd, data, chunk_size, offset); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t*>(data) + wrote; @@ -561,8 +539,7 @@ bool Filesystem::PWrite(const char* filename, off_t offset, const void* data, bool Filesystem::Read(int fd, void* buf, size_t buf_size) const { ssize_t read_status = read(fd, buf, buf_size); if (read_status < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad read: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad read: " << strerror(errno); return false; } return true; 
@@ -582,8 +559,7 @@ bool Filesystem::Read(const char* filename, void* buf, size_t buf_size) const { bool Filesystem::PRead(int fd, void* buf, size_t buf_size, off_t offset) const { ssize_t read_status = pread(fd, buf, buf_size, offset); if (read_status < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad read: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad read: " << strerror(errno); return false; } return true; @@ -609,8 +585,7 @@ bool Filesystem::DataSync(int fd) const { #endif if (result < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to sync data: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to sync data: " << strerror(errno); return false; } return true; @@ -618,9 +593,7 @@ bool Filesystem::DataSync(int fd) const { bool Filesystem::RenameFile(const char* old_name, const char* new_name) const { if (rename(old_name, new_name) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to rename file %s to %s: %s", old_name, new_name, - strerror(errno)); + ICING_LOG(ERROR) << "Unable to rename file " << old_name << " to " << new_name << ": " << strerror(errno); return false; } return true; @@ -658,8 +631,7 @@ bool Filesystem::CreateDirectory(const char* dir_name) const { if (mkdir(dir_name, S_IRUSR | S_IWUSR | S_IXUSR) == 0) { success = true; } else { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Creating directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Creating directory " << dir_name << " failed: " << strerror(errno); } } return success; @@ -679,8 +651,7 @@ bool Filesystem::CreateDirectoryRecursively(const char* dir_name) const { int64_t Filesystem::GetDiskUsage(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -689,8 +660,7 @@ int64_t Filesystem::GetDiskUsage(int fd) const { int64_t Filesystem::GetFileDiskUsage(const char* path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -699,8 +669,7 @@ int64_t Filesystem::GetFileDiskUsage(const char* path) const { int64_t Filesystem::GetDiskUsage(const char* path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } int64_t result = st.st_blocks * kStatBlockSize; diff --git a/icing/file/persistent-hash-map.cc b/icing/file/persistent-hash-map.cc new file mode 100644 index 0000000..d20285a --- /dev/null +++ b/icing/file/persistent-hash-map.cc @@ -0,0 +1,534 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/persistent-hash-map.h" + +#include <cstring> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/absl_ports/canonical_errors.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/util/crc32.h" +#include "icing/util/status-macros.h" + +namespace icing { +namespace lib { + +namespace { + +// Helper function to check if there is no termination character '\0' in the +// key. +libtextclassifier3::Status ValidateKey(std::string_view key) { + if (key.find('\0') != std::string_view::npos) { // NOLINT + return absl_ports::InvalidArgumentError( + "Key cannot contain termination character '\\0'"); + } + return libtextclassifier3::Status::OK; +} + +// Helper function to convert the key to bucket index by hash. +// +// Returns: +// int32_t: A valid bucket index with range [0, num_buckets - 1]. +// INTERNAL_ERROR if num_buckets == 0 +libtextclassifier3::StatusOr<int32_t> HashKeyToBucketIndex( + std::string_view key, int32_t num_buckets) { + if (num_buckets == 0) { + return absl_ports::InternalError("Should not have empty bucket"); + } + return static_cast<int32_t>(std::hash<std::string_view>()(key) % num_buckets); +} + +// Helper function to PWrite crcs and info to metadata_file_path. Note that +// metadata_file_path will be the normal or temporary (for branching use when +// rehashing) metadata file path. +libtextclassifier3::Status WriteMetadata(const Filesystem& filesystem, + const char* metadata_file_path, + const PersistentHashMap::Crcs* crcs, + const PersistentHashMap::Info* info) { + ScopedFd sfd(filesystem.OpenForWrite(metadata_file_path)); + if (!sfd.is_valid()) { + return absl_ports::InternalError("Failed to create metadata file"); + } + + // Write crcs and info. File layout: <Crcs><Info> + if (!filesystem.PWrite(sfd.get(), PersistentHashMap::Crcs::kFileOffset, crcs, + sizeof(PersistentHashMap::Crcs))) { + return absl_ports::InternalError("Failed to write crcs into metadata file"); + } + // Note that PWrite won't change the file offset, so we need to specify + // the correct offset when writing Info. + if (!filesystem.PWrite(sfd.get(), PersistentHashMap::Info::kFileOffset, info, + sizeof(PersistentHashMap::Info))) { + return absl_ports::InternalError("Failed to write info into metadata file"); + } + + return libtextclassifier3::Status::OK; +} + +// Helper function to update checksums from info and storages to a Crcs +// instance. Note that storages will be the normal instances used by +// PersistentHashMap, or the temporary instances (for branching use when +// rehashing). 
+libtextclassifier3::Status UpdateChecksums( + PersistentHashMap::Crcs* crcs, PersistentHashMap::Info* info, + FileBackedVector<PersistentHashMap::Bucket>* bucket_storage, + FileBackedVector<PersistentHashMap::Entry>* entry_storage, + FileBackedVector<char>* kv_storage) { + // Compute crcs + ICING_ASSIGN_OR_RETURN(Crc32 bucket_storage_crc, + bucket_storage->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 entry_storage_crc, + entry_storage->ComputeChecksum()); + ICING_ASSIGN_OR_RETURN(Crc32 kv_storage_crc, kv_storage->ComputeChecksum()); + + crcs->component_crcs.info_crc = info->ComputeChecksum().Get(); + crcs->component_crcs.bucket_storage_crc = bucket_storage_crc.Get(); + crcs->component_crcs.entry_storage_crc = entry_storage_crc.Get(); + crcs->component_crcs.kv_storage_crc = kv_storage_crc.Get(); + crcs->all_crc = crcs->component_crcs.ComputeChecksum().Get(); + + return libtextclassifier3::Status::OK; +} + +// Helper function to validate checksums. +libtextclassifier3::Status ValidateChecksums( + const PersistentHashMap::Crcs* crcs, const PersistentHashMap::Info* info, + FileBackedVector<PersistentHashMap::Bucket>* bucket_storage, + FileBackedVector<PersistentHashMap::Entry>* entry_storage, + FileBackedVector<char>* kv_storage) { + if (crcs->all_crc != crcs->component_crcs.ComputeChecksum().Get()) { + return absl_ports::FailedPreconditionError( + "Invalid all crc for PersistentHashMap"); + } + + if (crcs->component_crcs.info_crc != info->ComputeChecksum().Get()) { + return absl_ports::FailedPreconditionError( + "Invalid info crc for PersistentHashMap"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 bucket_storage_crc, + bucket_storage->ComputeChecksum()); + if (crcs->component_crcs.bucket_storage_crc != bucket_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap bucket storage"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 entry_storage_crc, + entry_storage->ComputeChecksum()); + if (crcs->component_crcs.entry_storage_crc != entry_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap entry storage"); + } + + ICING_ASSIGN_OR_RETURN(Crc32 kv_storage_crc, kv_storage->ComputeChecksum()); + if (crcs->component_crcs.kv_storage_crc != kv_storage_crc.Get()) { + return absl_ports::FailedPreconditionError( + "Mismatch crc with PersistentHashMap key value storage"); + } + + return libtextclassifier3::Status::OK; +} + +// Since metadata/bucket/entry storages should be branched when rehashing, we +// have to store them together under the same sub directory +// ("<base_dir>/<sub_dir>"). On the other hand, key-value storage won't be +// branched and it will be stored under <base_dir>. +// +// The following 4 methods are helper functions to get the correct path of +// metadata/bucket/entry/key-value storages, according to the given base +// directory and sub directory. 
+std::string GetMetadataFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".m"); +} + +std::string GetBucketStorageFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".b"); +} + +std::string GetEntryStorageFilePath(std::string_view base_dir, + std::string_view sub_dir) { + return absl_ports::StrCat(base_dir, "/", sub_dir, "/", + PersistentHashMap::kFilePrefix, ".e"); +} + +std::string GetKeyValueStorageFilePath(std::string_view base_dir) { + return absl_ports::StrCat(base_dir, "/", PersistentHashMap::kFilePrefix, + ".k"); +} + +} // namespace + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::Create(const Filesystem& filesystem, + std::string_view base_dir, int32_t value_type_size, + int32_t max_load_factor_percent) { + if (!filesystem.FileExists( + GetMetadataFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists( + GetBucketStorageFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists( + GetEntryStorageFilePath(base_dir, kSubDirectory).c_str()) || + !filesystem.FileExists(GetKeyValueStorageFilePath(base_dir).c_str())) { + // TODO: erase all files if missing any. + return InitializeNewFiles(filesystem, base_dir, value_type_size, + max_load_factor_percent); + } + return InitializeExistingFiles(filesystem, base_dir, value_type_size, + max_load_factor_percent); +} + +PersistentHashMap::~PersistentHashMap() { + if (!PersistToDisk().ok()) { + ICING_LOG(WARNING) + << "Failed to persist hash map to disk while destructing " << base_dir_; + } +} + +libtextclassifier3::Status PersistentHashMap::Put(std::string_view key, + const void* value) { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + // If not found, then insert new key value pair. + return Insert(bucket_idx, key, value); + } + + // Otherwise, overwrite the value. + ICING_ASSIGN_OR_RETURN(const Entry* entry, + entry_storage_->Get(target_entry_idx)); + + int32_t kv_len = key.length() + 1 + info()->value_type_size; + int32_t value_offset = key.length() + 1; + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<char>::MutableArrayView mutable_kv_arr, + kv_storage_->GetMutable(entry->key_value_index(), kv_len)); + // It is the same key and value_size is fixed, so we can directly overwrite + // serialized value. + mutable_kv_arr.SetArray(value_offset, reinterpret_cast<const char*>(value), + info()->value_type_size); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PersistentHashMap::GetOrPut(std::string_view key, + void* next_value) { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + // If not found, then insert new key value pair. + return Insert(bucket_idx, key, next_value); + } + + // Otherwise, copy the hash map value into next_value. 
+ return CopyEntryValue(target_entry_idx, next_value); +} + +libtextclassifier3::Status PersistentHashMap::Get(std::string_view key, + void* value) const { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + FindEntryIndexByKey(bucket_idx, key)); + if (target_entry_idx == Entry::kInvalidIndex) { + return absl_ports::NotFoundError( + absl_ports::StrCat("Key not found in PersistentHashMap ", base_dir_)); + } + + return CopyEntryValue(target_entry_idx, value); +} + +libtextclassifier3::Status PersistentHashMap::PersistToDisk() { + ICING_RETURN_IF_ERROR(bucket_storage_->PersistToDisk()); + ICING_RETURN_IF_ERROR(entry_storage_->PersistToDisk()); + ICING_RETURN_IF_ERROR(kv_storage_->PersistToDisk()); + + ICING_RETURN_IF_ERROR(UpdateChecksums(crcs(), info(), bucket_storage_.get(), + entry_storage_.get(), + kv_storage_.get())); + // Changes should have been applied to the underlying file when using + // MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC, but call msync() as an + // extra safety step to ensure they are written out. + ICING_RETURN_IF_ERROR(metadata_mmapped_file_->PersistToDisk()); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::StatusOr<int64_t> PersistentHashMap::GetDiskUsage() const { + ICING_ASSIGN_OR_RETURN(int64_t bucket_storage_disk_usage, + bucket_storage_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(int64_t entry_storage_disk_usage, + entry_storage_->GetDiskUsage()); + ICING_ASSIGN_OR_RETURN(int64_t kv_storage_disk_usage, + kv_storage_->GetDiskUsage()); + + int64_t total = bucket_storage_disk_usage + entry_storage_disk_usage + + kv_storage_disk_usage; + Filesystem::IncrementByOrSetInvalid( + filesystem_->GetDiskUsage( + GetMetadataFilePath(base_dir_, kSubDirectory).c_str()), + &total); + + if (total < 0 || total == Filesystem::kBadFileSize) { + return absl_ports::InternalError( + "Failed to get disk usage of PersistentHashMap"); + } + return total; +} + +libtextclassifier3::StatusOr<int64_t> PersistentHashMap::GetElementsSize() + const { + ICING_ASSIGN_OR_RETURN(int64_t bucket_storage_elements_size, + bucket_storage_->GetElementsFileSize()); + ICING_ASSIGN_OR_RETURN(int64_t entry_storage_elements_size, + entry_storage_->GetElementsFileSize()); + ICING_ASSIGN_OR_RETURN(int64_t kv_storage_elements_size, + kv_storage_->GetElementsFileSize()); + return bucket_storage_elements_size + entry_storage_elements_size + + kv_storage_elements_size; +} + +libtextclassifier3::StatusOr<Crc32> PersistentHashMap::ComputeChecksum() { + Crcs* crcs_ptr = crcs(); + ICING_RETURN_IF_ERROR(UpdateChecksums(crcs_ptr, info(), bucket_storage_.get(), + entry_storage_.get(), + kv_storage_.get())); + return Crc32(crcs_ptr->all_crc); +} + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::InitializeNewFiles(const Filesystem& filesystem, + std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent) { + // Create directory. 
+ const std::string dir_path = absl_ports::StrCat(base_dir, "/", kSubDirectory); + if (!filesystem.CreateDirectoryRecursively(dir_path.c_str())) { + return absl_ports::InternalError( + absl_ports::StrCat("Failed to create directory: ", dir_path)); + } + + // Initialize 3 storages + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + FileBackedVector<Bucket>::Create( + filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + FileBackedVector<Entry>::Create( + filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, + FileBackedVector<char>::Create( + filesystem, GetKeyValueStorageFilePath(base_dir), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + // Initialize one bucket. + ICING_RETURN_IF_ERROR(bucket_storage->Append(Bucket())); + ICING_RETURN_IF_ERROR(bucket_storage->PersistToDisk()); + + // Create and initialize new info + Info new_info; + new_info.version = kVersion; + new_info.value_type_size = value_type_size; + new_info.max_load_factor_percent = max_load_factor_percent; + new_info.num_deleted_entries = 0; + new_info.num_deleted_key_value_bytes = 0; + + // Compute checksums + Crcs new_crcs; + ICING_RETURN_IF_ERROR(UpdateChecksums(&new_crcs, &new_info, + bucket_storage.get(), + entry_storage.get(), kv_storage.get())); + + const std::string metadata_file_path = + GetMetadataFilePath(base_dir, kSubDirectory); + // Write new metadata file + ICING_RETURN_IF_ERROR(WriteMetadata(filesystem, metadata_file_path.c_str(), + &new_crcs, &new_info)); + + // Mmap the content of the crcs and info. + auto metadata_mmapped_file = std::make_unique<MemoryMappedFile>( + filesystem, metadata_file_path, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->Remap( + /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); + + return std::unique_ptr<PersistentHashMap>(new PersistentHashMap( + filesystem, base_dir, std::move(metadata_mmapped_file), + std::move(bucket_storage), std::move(entry_storage), + std::move(kv_storage))); +} + +/* static */ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> +PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, + std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent) { + // Mmap the content of the crcs and info. 
+ auto metadata_mmapped_file = std::make_unique<MemoryMappedFile>( + filesystem, GetMetadataFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->Remap( + /*file_offset=*/0, /*mmap_size=*/sizeof(Crcs) + sizeof(Info))); + + // Initialize 3 storages + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + FileBackedVector<Bucket>::Create( + filesystem, GetBucketStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + FileBackedVector<Entry>::Create( + filesystem, GetEntryStorageFilePath(base_dir, kSubDirectory), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<FileBackedVector<char>> kv_storage, + FileBackedVector<char>::Create( + filesystem, GetKeyValueStorageFilePath(base_dir), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + Crcs* crcs_ptr = reinterpret_cast<Crcs*>( + metadata_mmapped_file->mutable_region() + Crcs::kFileOffset); + Info* info_ptr = reinterpret_cast<Info*>( + metadata_mmapped_file->mutable_region() + Info::kFileOffset); + + // Value type size should be consistent. + if (value_type_size != info_ptr->value_type_size) { + return absl_ports::FailedPreconditionError("Incorrect value type size"); + } + + // Validate checksums of info and 3 storages. + ICING_RETURN_IF_ERROR( + ValidateChecksums(crcs_ptr, info_ptr, bucket_storage.get(), + entry_storage.get(), kv_storage.get())); + + // Allow max_load_factor_percent_ change. + if (max_load_factor_percent != info_ptr->max_load_factor_percent) { + ICING_VLOG(2) << "Changing max_load_factor_percent from " << info_ptr->max_load_factor_percent << " to " << max_load_factor_percent; + + info_ptr->max_load_factor_percent = max_load_factor_percent; + crcs_ptr->component_crcs.info_crc = info_ptr->ComputeChecksum().Get(); + crcs_ptr->all_crc = crcs_ptr->component_crcs.ComputeChecksum().Get(); + ICING_RETURN_IF_ERROR(metadata_mmapped_file->PersistToDisk()); + // TODO(b/193919210): rehash if needed + } + + return std::unique_ptr<PersistentHashMap>(new PersistentHashMap( + filesystem, base_dir, std::move(metadata_mmapped_file), + std::move(bucket_storage), std::move(entry_storage), + std::move(kv_storage))); +} + +libtextclassifier3::StatusOr<int32_t> PersistentHashMap::FindEntryIndexByKey( + int32_t bucket_idx, std::string_view key) const { + // Iterate all entries in the bucket, compare with key, and return the entry + // index if exists. + ICING_ASSIGN_OR_RETURN(const Bucket* bucket, + bucket_storage_->Get(bucket_idx)); + int32_t curr_entry_idx = bucket->head_entry_index(); + while (curr_entry_idx != Entry::kInvalidIndex) { + ICING_ASSIGN_OR_RETURN(const Entry* entry, + entry_storage_->Get(curr_entry_idx)); + if (entry->key_value_index() == kInvalidKVIndex) { + ICING_LOG(ERROR) << "Got an invalid key value index in the persistent " + "hash map bucket. 
This shouldn't happen"; + return absl_ports::InternalError("Unexpected invalid key value index"); + } + ICING_ASSIGN_OR_RETURN(const char* kv_arr, + kv_storage_->Get(entry->key_value_index())); + if (key.compare(kv_arr) == 0) { + return curr_entry_idx; + } + + curr_entry_idx = entry->next_entry_index(); + } + + return curr_entry_idx; +} + +libtextclassifier3::Status PersistentHashMap::CopyEntryValue( + int32_t entry_idx, void* value) const { + ICING_ASSIGN_OR_RETURN(const Entry* entry, entry_storage_->Get(entry_idx)); + + ICING_ASSIGN_OR_RETURN(const char* kv_arr, + kv_storage_->Get(entry->key_value_index())); + int32_t value_offset = strlen(kv_arr) + 1; + memcpy(value, kv_arr + value_offset, info()->value_type_size); + + return libtextclassifier3::Status::OK; +} + +libtextclassifier3::Status PersistentHashMap::Insert(int32_t bucket_idx, + std::string_view key, + const void* value) { + // If size() + 1 exceeds Entry::kMaxNumEntries, then return error. + if (size() > Entry::kMaxNumEntries - 1) { + return absl_ports::ResourceExhaustedError("Cannot insert new entry"); + } + + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<Bucket>::MutableView mutable_bucket, + bucket_storage_->GetMutable(bucket_idx)); + + // Append new key value. + int32_t new_kv_idx = kv_storage_->num_elements(); + int32_t kv_len = key.size() + 1 + info()->value_type_size; + int32_t value_offset = key.size() + 1; + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<char>::MutableArrayView mutable_new_kv_arr, + kv_storage_->Allocate(kv_len)); + mutable_new_kv_arr.SetArray(/*idx=*/0, key.data(), key.size()); + mutable_new_kv_arr.SetArray(/*idx=*/key.size(), "\0", 1); + mutable_new_kv_arr.SetArray(/*idx=*/value_offset, + reinterpret_cast<const char*>(value), + info()->value_type_size); + + // Append new entry. + int32_t new_entry_idx = entry_storage_->num_elements(); + ICING_RETURN_IF_ERROR(entry_storage_->Append( + Entry(new_kv_idx, mutable_bucket.Get().head_entry_index()))); + mutable_bucket.Get().set_head_entry_index(new_entry_idx); + + // TODO: rehash if needed + + return libtextclassifier3::Status::OK; +} + +} // namespace lib +} // namespace icing diff --git a/icing/file/persistent-hash-map.h b/icing/file/persistent-hash-map.h new file mode 100644 index 0000000..24a47ea --- /dev/null +++ b/icing/file/persistent-hash-map.h @@ -0,0 +1,383 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_FILE_PERSISTENT_HASH_MAP_H_ +#define ICING_FILE_PERSISTENT_HASH_MAP_H_ + +#include <cstdint> +#include <memory> +#include <string> +#include <string_view> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +// Low level persistent hash map. +// It supports variant length serialized key + fixed length serialized value. 
+// Key and value can be any type, but callers should serialize key/value by +// themselves and pass raw bytes into the hash map, and the serialized key +// should not contain termination character '\0'. +class PersistentHashMap { + public: + // Crcs and Info will be written into the metadata file. + // File layout: <Crcs><Info> + // Crcs + struct Crcs { + static constexpr int32_t kFileOffset = 0; + + struct ComponentCrcs { + uint32_t info_crc; + uint32_t bucket_storage_crc; + uint32_t entry_storage_crc; + uint32_t kv_storage_crc; + + bool operator==(const ComponentCrcs& other) const { + return info_crc == other.info_crc && + bucket_storage_crc == other.bucket_storage_crc && + entry_storage_crc == other.entry_storage_crc && + kv_storage_crc == other.kv_storage_crc; + } + + Crc32 ComputeChecksum() const { + return Crc32(std::string_view(reinterpret_cast<const char*>(this), + sizeof(ComponentCrcs))); + } + } __attribute__((packed)); + + bool operator==(const Crcs& other) const { + return all_crc == other.all_crc && component_crcs == other.component_crcs; + } + + uint32_t all_crc; + ComponentCrcs component_crcs; + } __attribute__((packed)); + static_assert(sizeof(Crcs) == 20, ""); + + // Info + struct Info { + static constexpr int32_t kFileOffset = static_cast<int32_t>(sizeof(Crcs)); + + int32_t version; + int32_t value_type_size; + int32_t max_load_factor_percent; + int32_t num_deleted_entries; + int32_t num_deleted_key_value_bytes; + + Crc32 ComputeChecksum() const { + return Crc32( + std::string_view(reinterpret_cast<const char*>(this), sizeof(Info))); + } + } __attribute__((packed)); + static_assert(sizeof(Info) == 20, ""); + + // Bucket + class Bucket { + public: + // Absolute max # of buckets allowed. Since max file size on Android is + // 2^31-1, we can at most have ~2^29 buckets. To make it power of 2, round + // it down to 2^28. Also since we're using FileBackedVector to store + // buckets, add some static_asserts to ensure numbers here are compatible + // with FileBackedVector. + static constexpr int32_t kMaxNumBuckets = 1 << 28; + + explicit Bucket(int32_t head_entry_index = Entry::kInvalidIndex) + : head_entry_index_(head_entry_index) {} + + // For FileBackedVector + bool operator==(const Bucket& other) const { + return head_entry_index_ == other.head_entry_index_; + } + + int32_t head_entry_index() const { return head_entry_index_; } + void set_head_entry_index(int32_t head_entry_index) { + head_entry_index_ = head_entry_index; + } + + private: + int32_t head_entry_index_; + } __attribute__((packed)); + static_assert(sizeof(Bucket) == 4, ""); + static_assert(sizeof(Bucket) == FileBackedVector<Bucket>::kElementTypeSize, + "Bucket type size is inconsistent with FileBackedVector " + "element type size"); + static_assert(Bucket::kMaxNumBuckets <= + (FileBackedVector<Bucket>::kMaxFileSize - + FileBackedVector<Bucket>::Header::kHeaderSize) / + FileBackedVector<Bucket>::kElementTypeSize, + "Max # of buckets cannot fit into FileBackedVector"); + + // Entry + class Entry { + public: + // Absolute max # of entries allowed. Since max file size on Android is + // 2^31-1, we can at most have ~2^28 entries. To make it power of 2, round + // it down to 2^27. Also since we're using FileBackedVector to store + // entries, add some static_asserts to ensure numbers here are compatible + // with FileBackedVector. + // + // Still the actual max # of entries are determined by key-value storage, + // since length of the key varies and affects # of actual key-value pairs + // that can be stored. 
+ static constexpr int32_t kMaxNumEntries = 1 << 27; + static constexpr int32_t kMaxIndex = kMaxNumEntries - 1; + static constexpr int32_t kInvalidIndex = -1; + + explicit Entry(int32_t key_value_index, int32_t next_entry_index) + : key_value_index_(key_value_index), + next_entry_index_(next_entry_index) {} + + bool operator==(const Entry& other) const { + return key_value_index_ == other.key_value_index_ && + next_entry_index_ == other.next_entry_index_; + } + + int32_t key_value_index() const { return key_value_index_; } + void set_key_value_index(int32_t key_value_index) { + key_value_index_ = key_value_index; + } + + int32_t next_entry_index() const { return next_entry_index_; } + void set_next_entry_index(int32_t next_entry_index) { + next_entry_index_ = next_entry_index; + } + + private: + int32_t key_value_index_; + int32_t next_entry_index_; + } __attribute__((packed)); + static_assert(sizeof(Entry) == 8, ""); + static_assert(sizeof(Entry) == FileBackedVector<Entry>::kElementTypeSize, + "Entry type size is inconsistent with FileBackedVector " + "element type size"); + static_assert(Entry::kMaxNumEntries <= + (FileBackedVector<Entry>::kMaxFileSize - + FileBackedVector<Entry>::Header::kHeaderSize) / + FileBackedVector<Entry>::kElementTypeSize, + "Max # of entries cannot fit into FileBackedVector"); + + // Key-value serialized type + static constexpr int32_t kMaxKVTotalByteSize = + (FileBackedVector<char>::kMaxFileSize - + FileBackedVector<char>::Header::kHeaderSize) / + FileBackedVector<char>::kElementTypeSize; + static constexpr int32_t kMaxKVIndex = kMaxKVTotalByteSize - 1; + static constexpr int32_t kInvalidKVIndex = -1; + static_assert(sizeof(char) == FileBackedVector<char>::kElementTypeSize, + "Char type size is inconsistent with FileBackedVector element " + "type size"); + + static constexpr int32_t kVersion = 1; + static constexpr int32_t kDefaultMaxLoadFactorPercent = 75; + + static constexpr std::string_view kFilePrefix = "persistent_hash_map"; + // Only metadata, bucket, entry files are stored under this sub-directory, for + // rehashing branching use. + static constexpr std::string_view kSubDirectory = "dynamic"; + + // Creates a new PersistentHashMap to read/write/delete key value pairs. + // + // filesystem: Object to make system level calls + // base_dir: Specifies the directory for all persistent hash map related + // sub-directory and files to be stored. If base_dir doesn't exist, + // then PersistentHashMap will automatically create it. If files + // exist, then it will initialize the hash map from existing files. + // value_type_size: (fixed) size of the serialized value type for hash map. + // max_load_factor_percent: percentage of the max loading for the hash map. + // load_factor_percent = 100 * num_keys / num_buckets + // If load_factor_percent exceeds + // max_load_factor_percent, then rehash will be + // invoked (and # of buckets will be doubled). + // Note that load_factor_percent exceeding 100 is + // considered valid. + // + // Returns: + // FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored + // checksum. + // INTERNAL_ERROR on I/O errors. + // Any FileBackedVector errors. + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + Create(const Filesystem& filesystem, std::string_view base_dir, + int32_t value_type_size, + int32_t max_load_factor_percent = kDefaultMaxLoadFactorPercent); + + ~PersistentHashMap(); + + // Update a key value pair. If key does not exist, then insert (key, value) + // into the storage. 
Otherwise overwrite the value into the storage. + // + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // RESOURCE_EXHAUSTED_ERROR if # of entries reach kMaxNumEntries + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status Put(std::string_view key, const void* value); + + // If key does not exist, then insert (key, next_value) into the storage. + // Otherwise, copy the hash map value into next_value. + // + // REQUIRES: the buffer pointed to by next_value must be of value_size() + // + // Returns: + // OK on success + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status GetOrPut(std::string_view key, void* next_value); + + // Get the value by key from the storage. If key exists, then copy the hash + // map value into the value buffer. Otherwise, return NOT_FOUND_ERROR. + // + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // NOT_FOUND_ERROR if the key doesn't exist + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status Get(std::string_view key, void* value) const; + + // Flushes content to underlying files. + // + // Returns: + // OK on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::Status PersistToDisk(); + + // Calculates and returns the disk usage (metadata + 3 storages total file + // size) in bytes. + // + // Returns: + // Disk usage on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const; + + // Returns the total file size of all the elements held in the persistent + // hash map. File size is in bytes. This excludes the size of any internal + // metadata, i.e. crcs/info of persistent hash map, file backed vector's + // header. + // + // Returns: + // File size on success + // INTERNAL_ERROR on I/O error + libtextclassifier3::StatusOr<int64_t> GetElementsSize() const; + + // Updates all checksums of the persistent hash map components and returns + // all_crc.
+ // + // Returns: + // Crc of all components (all_crc) on success + // INTERNAL_ERROR if any data inconsistency + libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); + + int32_t size() const { + return entry_storage_->num_elements() - info()->num_deleted_entries; + } + + bool empty() const { return size() == 0; } + + private: + explicit PersistentHashMap( + const Filesystem& filesystem, std::string_view base_dir, + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file, + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage, + std::unique_ptr<FileBackedVector<Entry>> entry_storage, + std::unique_ptr<FileBackedVector<char>> kv_storage) + : filesystem_(&filesystem), + base_dir_(base_dir), + metadata_mmapped_file_(std::move(metadata_mmapped_file)), + bucket_storage_(std::move(bucket_storage)), + entry_storage_(std::move(entry_storage)), + kv_storage_(std::move(kv_storage)) {} + + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + InitializeNewFiles(const Filesystem& filesystem, std::string_view base_dir, + int32_t value_type_size, int32_t max_load_factor_percent); + + static libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + InitializeExistingFiles(const Filesystem& filesystem, + std::string_view base_dir, int32_t value_type_size, + int32_t max_load_factor_percent); + + // Find the index of the key entry from a bucket (specified by bucket index). + // The caller should specify the desired bucket index. + // + // Returns: + // int32_t: on success, the index of the entry, or Entry::kInvalidIndex if + // not found + // INTERNAL_ERROR if any content inconsistency + // Any FileBackedVector errors + libtextclassifier3::StatusOr<int32_t> FindEntryIndexByKey( + int32_t bucket_idx, std::string_view key) const; + + // Copy the hash map value of the entry into value buffer. + // + // REQUIRES: entry_idx should be valid. + // REQUIRES: the buffer pointed to by value must be of value_size() + // + // Returns: + // OK on success + // Any FileBackedVector errors + libtextclassifier3::Status CopyEntryValue(int32_t entry_idx, + void* value) const; + + // Insert a new key value pair into a bucket (specified by the bucket index). + // The caller should specify the desired bucket index and make sure that the + // key is not present in the hash map before calling. 
+ // + // Returns: + // OK on success + // Any FileBackedVector errors + libtextclassifier3::Status Insert(int32_t bucket_idx, std::string_view key, + const void* value); + + Crcs* crcs() { + return reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() + + Crcs::kFileOffset); + } + + Info* info() { + return reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() + + Info::kFileOffset); + } + + const Info* info() const { + return reinterpret_cast<const Info*>(metadata_mmapped_file_->region() + + Info::kFileOffset); + } + + const Filesystem* filesystem_; + std::string base_dir_; + + std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_; + + // Storages + std::unique_ptr<FileBackedVector<Bucket>> bucket_storage_; + std::unique_ptr<FileBackedVector<Entry>> entry_storage_; + std::unique_ptr<FileBackedVector<char>> kv_storage_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_FILE_PERSISTENT_HASH_MAP_H_ diff --git a/icing/file/persistent-hash-map_test.cc b/icing/file/persistent-hash-map_test.cc new file mode 100644 index 0000000..fb15175 --- /dev/null +++ b/icing/file/persistent-hash-map_test.cc @@ -0,0 +1,662 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/file/persistent-hash-map.h" + +#include <cstring> +#include <vector> + +#include "icing/text_classifier/lib3/utils/base/status.h" +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" +#include "icing/util/crc32.h" + +namespace icing { +namespace lib { + +namespace { + +static constexpr int32_t kCorruptedValueOffset = 3; + +using ::testing::Eq; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::Pointee; +using ::testing::SizeIs; + +using Bucket = PersistentHashMap::Bucket; +using Crcs = PersistentHashMap::Crcs; +using Entry = PersistentHashMap::Entry; +using Info = PersistentHashMap::Info; + +class PersistentHashMapTest : public ::testing::Test { + protected: + void SetUp() override { + base_dir_ = GetTestTempDir() + "/persistent_hash_map_test"; + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(base_dir_.c_str()); + } + + std::vector<char> Serialize(int val) { + std::vector<char> ret(sizeof(val)); + memcpy(ret.data(), &val, sizeof(val)); + return ret; + } + + libtextclassifier3::StatusOr<int> GetValueByKey( + PersistentHashMap* persistent_hash_map, std::string_view key) { + int val; + ICING_RETURN_IF_ERROR(persistent_hash_map->Get(key, &val)); + return val; + } + + Filesystem filesystem_; + std::string base_dir_; +}; + +TEST_F(PersistentHashMapTest, InvalidBaseDir) { + EXPECT_THAT(PersistentHashMap::Create(filesystem_, "/dev/null", + /*value_type_size=*/sizeof(int)), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); +} + 
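For orientation before the remaining tests, here is a compact round-trip sketch of the PersistentHashMap API declared above (Create, Put, Get, PersistToDisk). The wrapper function and base directory are illustrative and not part of the commit; values are passed as raw bytes of a fixed-size type, here int:

// Sketch: store and read back one int value keyed by a string.
libtextclassifier3::Status PutAndGetExample(const Filesystem& filesystem) {
  ICING_ASSIGN_OR_RETURN(
      std::unique_ptr<PersistentHashMap> persistent_hash_map,
      PersistentHashMap::Create(filesystem, "/tmp/phm_example",  // illustrative
                                /*value_type_size=*/sizeof(int)));

  int value = 42;
  // Inserts the pair, or overwrites the serialized value if the key exists.
  ICING_RETURN_IF_ERROR(persistent_hash_map->Put("answer", &value));

  int read_back = 0;
  // Copies value_type_size bytes of the stored value into the caller's buffer.
  ICING_RETURN_IF_ERROR(persistent_hash_map->Get("answer", &read_back));

  // Recomputes checksums and syncs the metadata file and the three storages.
  return persistent_hash_map->PersistToDisk();
}

Keys must not contain '\0', and the buffers handed to Put()/Get() must match the value_type_size given at Create() time.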
+TEST_F(PersistentHashMapTest, InitializeNewFiles) { + { + ASSERT_FALSE(filesystem_.DirectoryExists(base_dir_.c_str())); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + // Metadata file should be initialized correctly for both info and crcs + // sections. + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + // Check info section + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + EXPECT_THAT(info.version, Eq(PersistentHashMap::kVersion)); + EXPECT_THAT(info.value_type_size, Eq(sizeof(int))); + EXPECT_THAT(info.max_load_factor_percent, + Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent)); + EXPECT_THAT(info.num_deleted_entries, Eq(0)); + EXPECT_THAT(info.num_deleted_key_value_bytes, Eq(0)); + + // Check crcs section + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + // # of elements in bucket_storage should be 1, so it should have non-zero + // crc value. + EXPECT_THAT(crcs.component_crcs.bucket_storage_crc, Not(Eq(0))); + // Other empty file backed vectors should have 0 crc value. + EXPECT_THAT(crcs.component_crcs.entry_storage_crc, Eq(0)); + EXPECT_THAT(crcs.component_crcs.kv_storage_crc, Eq(0)); + EXPECT_THAT(crcs.component_crcs.info_crc, + Eq(Crc32(std::string_view(reinterpret_cast<const char*>(&info), + sizeof(Info))) + .Get())); + EXPECT_THAT(crcs.all_crc, + Eq(Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get())); +} + +TEST_F(PersistentHashMapTest, + TestInitializationFailsWithoutPersistToDiskOrDestruction) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + // Put some key value pairs. + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + // Without calling PersistToDisk, checksums will not be recomputed or synced + // to disk, so initializing another instance on the same files should fail. 
+ EXPECT_THAT(PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map1, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + // Put some key value pairs. + ICING_ASSERT_OK(persistent_hash_map1->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map1->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map1, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map1.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map1.get(), "b"), IsOkAndHolds(2)); + + // After calling PersistToDisk, all checksums should be recomputed and synced + // correctly to disk, so initializing another instance on the same files + // should succeed, and we should be able to get the same contents. + ICING_EXPECT_OK(persistent_hash_map1->PersistToDisk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map2, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + EXPECT_THAT(persistent_hash_map2, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map2.get(), "b"), IsOkAndHolds(2)); +} + +TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) { + { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing, to + // avoid PersistToDisk being called implicitly by rehashing. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + // TODO(b/193919210): call Delete() to change PersistentHashMap header + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + } + + { + // The previous instance went out of scope and was destructed. Although we + // didn't call PersistToDisk explicitly, the destructor should invoke it and + // thus initializing another instance on the same files should succeed, and + // we should be able to get the same contents. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithDifferentValueTypeSizeShouldFail) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + { + // Attempt to create the persistent hash map with different value type size. + // This should fail. + ASSERT_THAT(sizeof(char), Not(Eq(sizeof(int)))); + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(char)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Incorrect value type size")); + } +} + +TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongAllCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt all_crc + crcs.all_crc += kCorruptedValueOffset; + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + metadata_sfd.reset(); + + { + // Attempt to create the persistent hash map with metadata containing + // corrupted all_crc. This should fail. 
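One fixture detail worth spelling out: the Serialize(...).data() calls used throughout these tests hand PersistentHashMap the raw bytes of an int, matching the declared value_type_size of sizeof(int). The helper itself is defined outside this hunk, so the following is only a plausible sketch of it:

// Hypothetical reconstruction of the test helper: pack the value into a byte
// string whose .data() pointer can be passed to Put()/GetOrPut().
std::string Serialize(int value) {
  return std::string(reinterpret_cast<const char*>(&value), sizeof(value));
}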
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Invalid all crc for PersistentHashMap")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithCorruptedInfoShouldFail) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + + // Modify info, but don't update the checksum. This would be similar to + // corruption of info. + info.num_deleted_entries += kCorruptedValueOffset; + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Info::kFileOffset, &info, + sizeof(Info))); + { + // Attempt to create the persistent hash map with info that doesn't match + // its checksum and confirm that it fails. + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Invalid info crc for PersistentHashMap")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithWrongBucketStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt bucket_storage_crc + crcs.component_crcs.bucket_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted bucket_storage_crc. This should fail. 
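This and the remaining corruption tests depend on checksum validation happening in two layers, as the distinct error messages suggest: all_crc over the packed component crcs is checked first, and only then is each component crc compared against its storage. The checks themselves are not in this hunk; the outer one amounts to the same Crc32-over-raw-bytes pattern used above:

// Illustrative helper (not the actual implementation): the outer check.
bool AllCrcMatches(const Crcs& crcs) {
  return crcs.all_crc ==
         Crc32(std::string_view(
                   reinterpret_cast<const char*>(&crcs.component_crcs),
                   sizeof(Crcs::ComponentCrcs)))
             .Get();
}
// Because the tests recompute all_crc after corrupting bucket_storage_crc,
// entry_storage_crc, or kv_storage_crc, the outer check passes and Create()
// fails on the per-storage comparison, producing the "Mismatch crc ..." errors.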
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap bucket storage")); + } +} + +TEST_F(PersistentHashMapTest, InitializeExistingFilesWithWrongEntryStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt entry_storage_crc + crcs.component_crcs.entry_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted entry_storage_crc. This should fail. + libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT(persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap entry storage")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesWithWrongKeyValueStorageCrc) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Crcs crcs; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &crcs, sizeof(Crcs), + Crcs::kFileOffset)); + + // Manually corrupt kv_storage_crc + crcs.component_crcs.kv_storage_crc += kCorruptedValueOffset; + crcs.all_crc = Crc32(std::string_view( + reinterpret_cast<const char*>(&crcs.component_crcs), + sizeof(Crcs::ComponentCrcs))) + .Get(); + ASSERT_TRUE(filesystem_.PWrite(metadata_sfd.get(), Crcs::kFileOffset, &crcs, + sizeof(Crcs))); + { + // Attempt to create the persistent hash map with metadata containing + // corrupted kv_storage_crc. This should fail. 
+ libtextclassifier3::StatusOr<std::unique_ptr<PersistentHashMap>> + persistent_hash_map_or = PersistentHashMap::Create( + filesystem_, base_dir_, /*value_type_size=*/sizeof(int)); + EXPECT_THAT(persistent_hash_map_or, + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); + EXPECT_THAT( + persistent_hash_map_or.status().error_message(), + HasSubstr("Mismatch crc with PersistentHashMap key value storage")); + } +} + +TEST_F(PersistentHashMapTest, + InitializeExistingFilesAllowDifferentMaxLoadFactorPercent) { + { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); + ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); + + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + int32_t new_max_load_factor_percent = 100; + { + ASSERT_THAT(new_max_load_factor_percent, + Not(Eq(PersistentHashMap::kDefaultMaxLoadFactorPercent))); + // Attempt to create the persistent hash map with different max load factor + // percent. This should succeed and metadata should be modified correctly. + // Also verify all entries should remain unchanged. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + new_max_load_factor_percent)); + + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "b"), IsOkAndHolds(2)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } + + const std::string metadata_file_path = + absl_ports::StrCat(base_dir_, "/", PersistentHashMap::kSubDirectory, "/", + PersistentHashMap::kFilePrefix, ".m"); + ScopedFd metadata_sfd(filesystem_.OpenForWrite(metadata_file_path.c_str())); + ASSERT_TRUE(metadata_sfd.is_valid()); + + Info info; + ASSERT_TRUE(filesystem_.PRead(metadata_sfd.get(), &info, sizeof(Info), + Info::kFileOffset)); + EXPECT_THAT(info.max_load_factor_percent, Eq(new_max_load_factor_percent)); + + // Also should update crcs correctly. We test it by creating instance again + // and make sure it won't get corrupted crcs/info errors. 
+ { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + new_max_load_factor_percent)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); + } +} + +TEST_F(PersistentHashMapTest, PutAndGet) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + EXPECT_THAT(persistent_hash_map, Pointee(IsEmpty())); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ICING_EXPECT_OK( + persistent_hash_map->Put("default-youtube.com", Serialize(50).data())); + + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(100)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + IsOkAndHolds(50)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "key-not-exist"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + ICING_ASSERT_OK(persistent_hash_map->PersistToDisk()); +} + +TEST_F(PersistentHashMapTest, PutShouldOverwriteValueIfKeyExists) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(100)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(200).data())); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(200)); + + ICING_EXPECT_OK( + persistent_hash_map->Put("default-google.com", Serialize(300).data())); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(300)); +} + +TEST_F(PersistentHashMapTest, GetOrPutShouldPutIfKeyDoesNotExist) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + int val = 1; + EXPECT_THAT(persistent_hash_map->GetOrPut("default-google.com", &val), + IsOk()); + EXPECT_THAT(val, Eq(1)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); +} + +TEST_F(PersistentHashMapTest, GetOrPutShouldGetIfKeyExists) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ASSERT_THAT( + 
persistent_hash_map->Put("default-google.com", Serialize(1).data()), + IsOk()); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); + + int val = 2; + EXPECT_THAT(persistent_hash_map->GetOrPut("default-google.com", &val), + IsOk()); + EXPECT_THAT(val, Eq(1)); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(1)); +} + +TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + const char invalid_key[] = "a\0bc"; + std::string_view invalid_key_view(invalid_key, 4); + + int val = 1; + EXPECT_THAT(persistent_hash_map->Put(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(persistent_hash_map->GetOrPut(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(persistent_hash_map->Get(invalid_key_view, &val), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index e390f0f..4089ec9 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -49,14 +49,16 @@ #include "icing/proto/status.pb.h" #include "icing/query/query-processor.h" #include "icing/query/suggestion-processor.h" +#include "icing/result/page-result.h" #include "icing/result/projection-tree.h" #include "icing/result/projector.h" -#include "icing/result/result-retriever.h" +#include "icing/result/result-retriever-v2.h" #include "icing/schema/schema-store.h" #include "icing/schema/schema-util.h" #include "icing/schema/section.h" -#include "icing/scoring/ranker.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" #include "icing/scoring/scored-document-hit.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/scoring/scoring-processor.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" @@ -112,6 +114,11 @@ libtextclassifier3::Status ValidateResultSpec( return absl_ports::InvalidArgumentError( "ResultSpecProto.num_per_page cannot be negative."); } + if (result_spec.num_total_bytes_per_page_threshold() <= 0) { + return absl_ports::InvalidArgumentError( + "ResultSpecProto.num_total_bytes_per_page_threshold cannot be " + "non-positive."); + } std::unordered_set<std::string> unique_namespaces; for (const ResultSpecProto::ResultGrouping& result_grouping : result_spec.result_groupings()) { @@ -263,9 +270,9 @@ void TransformStatus(const libtextclassifier3::Status& internal_status, case libtextclassifier3::StatusCode::UNAUTHENTICATED: // Other internal status codes aren't supported externally yet. If it // should be supported, add another switch-case above. 
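Circling back to the ValidateResultSpec addition earlier in this file: num_total_bytes_per_page_threshold must now be strictly positive, so a caller bounding page size in bytes would set it alongside num_per_page roughly like this (values are arbitrary examples):

// Arbitrary illustrative values; a threshold <= 0 is rejected with
// INVALID_ARGUMENT by ValidateResultSpec.
ResultSpecProto result_spec;
result_spec.set_num_per_page(10);
result_spec.set_num_total_bytes_per_page_threshold(64 * 1024);  // ~64 KiB per page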
- ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Internal status code %d not supported in the external API", - internal_status.error_code()); + ICING_LOG(ERROR) << "Internal status code " + << internal_status.error_code() + << " not supported in the external API"; code = StatusProto::UNKNOWN; break; } @@ -295,6 +302,17 @@ libtextclassifier3::Status RetrieveAndAddDocumentInfo( return libtextclassifier3::Status::OK; } +bool ShouldRebuildIndex(const OptimizeStatsProto& optimize_stats) { + int num_invalid_documents = optimize_stats.num_deleted_documents() + + optimize_stats.num_expired_documents(); + // Rebuilding the index could be faster than optimizing the index if we have + // removed most of the documents. + // Based on benchmarks, 85%~95% seems to be a good threshold for most cases. + // TODO(b/238236206): Try using the number of remaining hits in this + // condition, and allow clients to configure the threshold. + return num_invalid_documents >= optimize_stats.num_original_documents() * 0.9; +} + } // namespace IcingSearchEngine::IcingSearchEngine(const IcingSearchEngineOptions& options, @@ -634,18 +652,18 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( StatusProto* result_status = result_proto.mutable_status(); absl_ports::unique_lock l(&mutex_); - std::unique_ptr<Timer> timer = clock_->GetNewTimer(); + ScopedTimer timer(clock_->GetNewTimer(), [&result_proto](int64_t t) { + result_proto.set_latency_ms(t); + }); if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } auto lost_previous_schema_or = LostPreviousSchema(); if (!lost_previous_schema_or.ok()) { TransformStatus(lost_previous_schema_or.status(), result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } bool lost_previous_schema = lost_previous_schema_or.ValueOrDie(); @@ -663,7 +681,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( std::move(new_schema), ignore_errors_and_delete_documents); if (!set_schema_result_or.ok()) { TransformStatus(set_schema_result_or.status(), result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } SchemaStore::SetSchemaResult set_schema_result = @@ -706,7 +723,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( status = document_store_->UpdateSchemaStore(schema_store_.get()); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } else if (!set_schema_result.old_schema_type_ids_changed.empty() || @@ -716,7 +732,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( set_schema_result); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } @@ -726,7 +741,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( status = index_->Reset(); if (!status.ok()) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } @@ -737,7 +751,6 @@ SetSchemaResultProto IcingSearchEngine::SetSchema( if (!restore_result.status.ok() && !absl_ports::IsDataLoss(restore_result.status)) { TransformStatus(status, result_status); - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } } @@ -748,7 +761,6 @@ SetSchemaResultProto 
IcingSearchEngine::SetSchema( result_status->set_message("Schema is incompatible."); } - result_proto.set_latency_ms(timer->GetElapsedMilliseconds()); return result_proto; } @@ -804,12 +816,13 @@ PutResultProto IcingSearchEngine::Put(const DocumentProto& document) { PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { ICING_VLOG(1) << "Writing document to document store"; - std::unique_ptr<Timer> put_timer = clock_->GetNewTimer(); - PutResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); PutDocumentStatsProto* put_document_stats = result_proto.mutable_put_document_stats(); + ScopedTimer put_timer(clock_->GetNewTimer(), [put_document_stats](int64_t t) { + put_document_stats->set_latency_ms(t); + }); // Lock must be acquired before validation because the DocumentStore uses // the schema file to validate, and the schema could be changed in @@ -818,7 +831,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } @@ -826,7 +838,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { schema_store_.get(), language_segmenter_.get(), std::move(document)); if (!tokenized_document_or.ok()) { TransformStatus(tokenized_document_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } TokenizedDocument tokenized_document( @@ -837,7 +848,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { tokenized_document.num_tokens(), put_document_stats); if (!document_id_or.ok()) { TransformStatus(document_id_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } DocumentId document_id = document_id_or.ValueOrDie(); @@ -846,7 +856,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { TransformStatus(index_processor_or.status(), result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<IndexProcessor> index_processor = @@ -867,7 +876,6 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { } TransformStatus(status, result_status); - put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); return result_proto; } @@ -1081,7 +1089,9 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( delete_stats->set_num_schema_types_filtered( search_spec.schema_type_filters_size()); - std::unique_ptr<Timer> delete_timer = clock_->GetNewTimer(); + ScopedTimer delete_timer(clock_->GetNewTimer(), [delete_stats](int64_t t) { + delete_stats->set_latency_ms(t); + }); libtextclassifier3::Status status = ValidateSearchSpec(search_spec, performance_configuration_); if (!status.ok()) { @@ -1096,6 +1106,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( document_store_.get(), schema_store_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); + delete_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<QueryProcessor> query_processor = @@ -1104,6 +1116,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( 
auto query_results_or = query_processor->ParseSearch(search_spec); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); + delete_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } QueryProcessor::QueryResults query_results = @@ -1131,6 +1145,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( query_results.root_iterator->doc_hit_info().document_id()); if (!status.ok()) { TransformStatus(status, result_status); + delete_stats->set_document_removal_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } } @@ -1138,6 +1154,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( query_results.root_iterator->doc_hit_info().document_id()); if (!status.ok()) { TransformStatus(status, result_status); + delete_stats->set_document_removal_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } } @@ -1156,7 +1174,6 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( result_proto.mutable_status()->set_message( "No documents matched the query to delete by!"); } - delete_stats->set_latency_ms(delete_timer->GetElapsedMilliseconds()); delete_stats->set_num_documents_deleted(num_deleted); return result_proto; } @@ -1199,11 +1216,10 @@ OptimizeResultProto IcingSearchEngine::Optimize() { return result_proto; } - std::unique_ptr<Timer> optimize_timer = clock_->GetNewTimer(); OptimizeStatsProto* optimize_stats = result_proto.mutable_optimize_stats(); - int64_t before_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); - optimize_stats->set_storage_size_before( - Filesystem::SanitizeFileSize(before_size)); + ScopedTimer optimize_timer( + clock_->GetNewTimer(), + [optimize_stats](int64_t t) { optimize_stats->set_latency_ms(t); }); // Flushes data to disk before doing optimization auto status = InternalPersistToDisk(PersistType::FULL); @@ -1212,52 +1228,85 @@ OptimizeResultProto IcingSearchEngine::Optimize() { return result_proto; } + int64_t before_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); + optimize_stats->set_storage_size_before( + Filesystem::SanitizeFileSize(before_size)); + // TODO(b/143646633): figure out if we need to optimize index and doc store // at the same time. std::unique_ptr<Timer> optimize_doc_store_timer = clock_->GetNewTimer(); - libtextclassifier3::Status optimization_status = - OptimizeDocumentStore(optimize_stats); + libtextclassifier3::StatusOr<std::vector<DocumentId>> + document_id_old_to_new_or = OptimizeDocumentStore(optimize_stats); optimize_stats->set_document_store_optimize_latency_ms( optimize_doc_store_timer->GetElapsedMilliseconds()); - if (!optimization_status.ok() && - !absl_ports::IsDataLoss(optimization_status)) { + if (!document_id_old_to_new_or.ok() && + !absl_ports::IsDataLoss(document_id_old_to_new_or.status())) { // The status now is either ABORTED_ERROR or INTERNAL_ERROR. // If ABORTED_ERROR, Icing should still be working. // If INTERNAL_ERROR, we're having IO errors or other errors that we can't // recover from. - TransformStatus(optimization_status, result_status); + TransformStatus(document_id_old_to_new_or.status(), result_status); return result_proto; } // The status is either OK or DATA_LOSS. The optimized document store is // guaranteed to work, so we update index according to the new document store. 
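The branch that follows chooses between translating the existing index and rebuilding it from scratch, driven by the ShouldRebuildIndex helper added earlier in this file (deleted plus expired documents reaching roughly 90% of the originals). A worked instance of that heuristic:

// Using the heuristic from ShouldRebuildIndex above:
OptimizeStatsProto stats;
stats.set_num_original_documents(100);
stats.set_num_deleted_documents(85);
stats.set_num_expired_documents(7);
// 85 + 7 = 92 >= 100 * 0.9, so this would take the FULL_INDEX_REBUILD path
// below; with, say, only 60 invalid documents the index would instead be kept
// and its hits translated (INDEX_TRANSLATION).
bool rebuild = ShouldRebuildIndex(stats);  // true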
std::unique_ptr<Timer> optimize_index_timer = clock_->GetNewTimer(); - libtextclassifier3::Status index_reset_status = index_->Reset(); - if (!index_reset_status.ok()) { - status = absl_ports::Annotate( - absl_ports::InternalError("Failed to reset index after optimization."), - index_reset_status.error_message()); - TransformStatus(status, result_status); - return result_proto; + bool should_rebuild_index = + !document_id_old_to_new_or.ok() || ShouldRebuildIndex(*optimize_stats); + if (!should_rebuild_index) { + optimize_stats->set_index_restoration_mode( + OptimizeStatsProto::INDEX_TRANSLATION); + libtextclassifier3::Status index_optimize_status = + index_->Optimize(document_id_old_to_new_or.ValueOrDie()); + if (!index_optimize_status.ok()) { + ICING_LOG(WARNING) << "Failed to optimize index. Error: " + << index_optimize_status.error_message(); + should_rebuild_index = true; + } } + // If we received a DATA_LOSS error from OptimizeDocumentStore, we have a + // valid document store, but it might be the old one or the new one. So throw + // out the index and rebuild from scratch. + // Likewise, if Index::Optimize failed, then attempt to recover the index by + // rebuilding from scratch. + // If ShouldRebuildIndex() returns true, we will also rebuild the index for + // better performance. + if (should_rebuild_index) { + optimize_stats->set_index_restoration_mode( + OptimizeStatsProto::FULL_INDEX_REBUILD); + ICING_LOG(WARNING) << "Resetting the entire index!"; + libtextclassifier3::Status index_reset_status = index_->Reset(); + if (!index_reset_status.ok()) { + status = absl_ports::Annotate( + absl_ports::InternalError("Failed to reset index."), + index_reset_status.error_message()); + TransformStatus(status, result_status); + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); + return result_proto; + } - IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded(); - optimize_stats->set_index_restoration_latency_ms( - optimize_index_timer->GetElapsedMilliseconds()); - // DATA_LOSS means that we have successfully re-added content to the index. - // Some indexed content was lost, but otherwise the index is in a valid state - // and can be queried. - if (!index_restoration_status.status.ok() && - !absl_ports::IsDataLoss(index_restoration_status.status)) { - status = absl_ports::Annotate( - absl_ports::InternalError( - "Failed to reindex documents after optimization."), - index_restoration_status.status.error_message()); + IndexRestorationResult index_restoration_status = RestoreIndexIfNeeded(); + // DATA_LOSS means that we have successfully re-added content to the index. + // Some indexed content was lost, but otherwise the index is in a valid + // state and can be queried. + if (!index_restoration_status.status.ok() && + !absl_ports::IsDataLoss(index_restoration_status.status)) { + status = absl_ports::Annotate( + absl_ports::InternalError( + "Failed to reindex documents after optimization."), + index_restoration_status.status.error_message()); - TransformStatus(status, result_status); - return result_proto; + TransformStatus(status, result_status); + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); + return result_proto; + } } + optimize_stats->set_index_restoration_latency_ms( + optimize_index_timer->GetElapsedMilliseconds()); // Read the optimize status to get the time that we last ran. 
std::string optimize_status_filename = @@ -1279,12 +1328,18 @@ OptimizeResultProto IcingSearchEngine::Optimize() { optimize_status->set_last_successful_optimize_run_time_ms(current_time); optimize_status_file.Write(std::move(optimize_status)); + // Flushes data to disk after doing optimization + status = InternalPersistToDisk(PersistType::FULL); + if (!status.ok()) { + TransformStatus(status, result_status); + return result_proto; + } + int64_t after_size = filesystem_->GetDiskUsage(options_.base_dir().c_str()); optimize_stats->set_storage_size_after( Filesystem::SanitizeFileSize(after_size)); - optimize_stats->set_latency_ms(optimize_timer->GetElapsedMilliseconds()); - TransformStatus(optimization_status, result_status); + TransformStatus(document_id_old_to_new_or.status(), result_status); return result_proto; } @@ -1442,7 +1497,9 @@ SearchResultProto IcingSearchEngine::Search( QueryStatsProto* query_stats = result_proto.mutable_query_stats(); query_stats->set_query_length(search_spec.query().length()); - std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); + ScopedTimer overall_timer(clock_->GetNewTimer(), [query_stats](int64_t t) { + query_stats->set_latency_ms(t); + }); libtextclassifier3::Status status = ValidateResultSpec(result_spec); if (!status.ok()) { @@ -1470,6 +1527,8 @@ SearchResultProto IcingSearchEngine::Search( document_store_.get(), schema_store_.get()); if (!query_processor_or.ok()) { TransformStatus(query_processor_or.status(), result_status); + query_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<QueryProcessor> query_processor = @@ -1478,6 +1537,8 @@ SearchResultProto IcingSearchEngine::Search( auto query_results_or = query_processor->ParseSearch(search_spec); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); + query_stats->set_parse_query_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } QueryProcessor::QueryResults query_results = @@ -1498,6 +1559,8 @@ SearchResultProto IcingSearchEngine::Search( scoring_spec, document_store_.get(), schema_store_.get()); if (!scoring_processor_or.ok()) { TransformStatus(scoring_processor_or.status(), result_status); + query_stats->set_scoring_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } std::unique_ptr<ScoringProcessor> scoring_processor = @@ -1517,62 +1580,62 @@ SearchResultProto IcingSearchEngine::Search( } component_timer = clock_->GetNewTimer(); - // Ranks and paginates results - libtextclassifier3::StatusOr<PageResultState> page_result_state_or = - result_state_manager_->RankAndPaginate(ResultState( - std::move(result_document_hits), std::move(query_results.query_terms), - search_spec, scoring_spec, result_spec, *document_store_)); - if (!page_result_state_or.ok()) { - TransformStatus(page_result_state_or.status(), result_status); - return result_proto; - } - PageResultState page_result_state = - std::move(page_result_state_or).ValueOrDie(); + // Ranks results + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(result_document_hits), + /*is_descending=*/scoring_spec.order_by() == + ScoringSpecProto::Order::DESC); query_stats->set_ranking_latency_ms( component_timer->GetElapsedMilliseconds()); component_timer = clock_->GetNewTimer(); - // Retrieves the document protos and snippets if requested + // RanksAndPaginates and retrieves the document protos and snippets if + // 
requested auto result_retriever_or = - ResultRetriever::Create(document_store_.get(), schema_store_.get(), - language_segmenter_.get(), normalizer_.get()); + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get()); if (!result_retriever_or.ok()) { - result_state_manager_->InvalidateResultState( - page_result_state.next_page_token); TransformStatus(result_retriever_or.status(), result_status); + query_stats->set_document_retrieval_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } - std::unique_ptr<ResultRetriever> result_retriever = + std::unique_ptr<ResultRetrieverV2> result_retriever = std::move(result_retriever_or).ValueOrDie(); - libtextclassifier3::StatusOr<std::vector<SearchResultProto::ResultProto>> - results_or = result_retriever->RetrieveResults(page_result_state); - if (!results_or.ok()) { - result_state_manager_->InvalidateResultState( - page_result_state.next_page_token); - TransformStatus(results_or.status(), result_status); + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> + page_result_info_or = result_state_manager_->CacheAndRetrieveFirstPage( + std::move(ranker), std::move(query_results.query_terms), search_spec, + scoring_spec, result_spec, *document_store_, *result_retriever); + if (!page_result_info_or.ok()) { + TransformStatus(page_result_info_or.status(), result_status); + query_stats->set_document_retrieval_latency_ms( + component_timer->GetElapsedMilliseconds()); return result_proto; } - std::vector<SearchResultProto::ResultProto> results = - std::move(results_or).ValueOrDie(); + std::pair<uint64_t, PageResult> page_result_info = + std::move(page_result_info_or).ValueOrDie(); // Assembles the final search result proto - result_proto.mutable_results()->Reserve(results.size()); - for (SearchResultProto::ResultProto& result : results) { + result_proto.mutable_results()->Reserve( + page_result_info.second.results.size()); + for (SearchResultProto::ResultProto& result : + page_result_info.second.results) { result_proto.mutable_results()->Add(std::move(result)); } + result_status->set_code(StatusProto::OK); - if (page_result_state.next_page_token != kInvalidNextPageToken) { - result_proto.set_next_page_token(page_result_state.next_page_token); + if (page_result_info.first != kInvalidNextPageToken) { + result_proto.set_next_page_token(page_result_info.first); } + query_stats->set_document_retrieval_latency_ms( component_timer->GetElapsedMilliseconds()); - query_stats->set_latency_ms(overall_timer->GetElapsedMilliseconds()); query_stats->set_num_results_returned_current_page( result_proto.results_size()); query_stats->set_num_results_with_snippets( - std::min(result_proto.results_size(), - result_spec.snippet_spec().num_to_snippet())); + page_result_info.second.num_results_with_snippets); return result_proto; } @@ -1593,53 +1656,46 @@ SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { query_stats->set_is_first_page(false); std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); - libtextclassifier3::StatusOr<PageResultState> page_result_state_or = - result_state_manager_->GetNextPage(next_page_token); - - if (!page_result_state_or.ok()) { - if (absl_ports::IsNotFound(page_result_state_or.status())) { - // NOT_FOUND means an empty result. - result_status->set_code(StatusProto::OK); - } else { - // Real error, pass up. 
- TransformStatus(page_result_state_or.status(), result_status); - } - return result_proto; - } - - PageResultState page_result_state = - std::move(page_result_state_or).ValueOrDie(); - query_stats->set_requested_page_size(page_result_state.requested_page_size); - - // Retrieves the document protos. auto result_retriever_or = - ResultRetriever::Create(document_store_.get(), schema_store_.get(), - language_segmenter_.get(), normalizer_.get()); + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get()); if (!result_retriever_or.ok()) { TransformStatus(result_retriever_or.status(), result_status); return result_proto; } - std::unique_ptr<ResultRetriever> result_retriever = + std::unique_ptr<ResultRetrieverV2> result_retriever = std::move(result_retriever_or).ValueOrDie(); - libtextclassifier3::StatusOr<std::vector<SearchResultProto::ResultProto>> - results_or = result_retriever->RetrieveResults(page_result_state); - if (!results_or.ok()) { - TransformStatus(results_or.status(), result_status); + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> + page_result_info_or = result_state_manager_->GetNextPage( + next_page_token, *result_retriever); + if (!page_result_info_or.ok()) { + if (absl_ports::IsNotFound(page_result_info_or.status())) { + // NOT_FOUND means an empty result. + result_status->set_code(StatusProto::OK); + } else { + // Real error, pass up. + TransformStatus(page_result_info_or.status(), result_status); + } return result_proto; } - std::vector<SearchResultProto::ResultProto> results = - std::move(results_or).ValueOrDie(); + + std::pair<uint64_t, PageResult> page_result_info = + std::move(page_result_info_or).ValueOrDie(); + query_stats->set_requested_page_size( + page_result_info.second.requested_page_size); // Assembles the final search result proto - result_proto.mutable_results()->Reserve(results.size()); - for (SearchResultProto::ResultProto& result : results) { + result_proto.mutable_results()->Reserve( + page_result_info.second.results.size()); + for (SearchResultProto::ResultProto& result : + page_result_info.second.results) { result_proto.mutable_results()->Add(std::move(result)); } result_status->set_code(StatusProto::OK); - if (page_result_state.next_page_token != kInvalidNextPageToken) { - result_proto.set_next_page_token(page_result_state.next_page_token); + if (page_result_info.first != kInvalidNextPageToken) { + result_proto.set_next_page_token(page_result_info.first); } // The only thing that we're doing is document retrieval. 
So document @@ -1650,12 +1706,8 @@ SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { query_stats->set_latency_ms(overall_timer->GetElapsedMilliseconds()); query_stats->set_num_results_returned_current_page( result_proto.results_size()); - int num_left_to_snippet = - std::max(page_result_state.snippet_context.snippet_spec.num_to_snippet() - - page_result_state.num_previously_returned, - 0); query_stats->set_num_results_with_snippets( - std::min(result_proto.results_size(), num_left_to_snippet)); + page_result_info.second.num_results_with_snippets); return result_proto; } @@ -1668,8 +1720,8 @@ void IcingSearchEngine::InvalidateNextPageToken(uint64_t next_page_token) { result_state_manager_->InvalidateResultState(next_page_token); } -libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( - OptimizeStatsProto* optimize_stats) { +libtextclassifier3::StatusOr<std::vector<DocumentId>> +IcingSearchEngine::OptimizeDocumentStore(OptimizeStatsProto* optimize_stats) { // Gets the current directory path and an empty tmp directory path for // document store optimization. const std::string current_document_dir = @@ -1685,15 +1737,16 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( } // Copies valid document data to tmp directory - auto optimize_status = document_store_->OptimizeInto( - temporary_document_dir, language_segmenter_.get(), optimize_stats); + libtextclassifier3::StatusOr<std::vector<DocumentId>> + document_id_old_to_new_or = document_store_->OptimizeInto( + temporary_document_dir, language_segmenter_.get(), optimize_stats); // Handles error if any - if (!optimize_status.ok()) { + if (!document_id_old_to_new_or.ok()) { filesystem_->DeleteDirectoryRecursively(temporary_document_dir.c_str()); return absl_ports::Annotate( absl_ports::AbortedError("Failed to optimize document store"), - optimize_status.error_message()); + document_id_old_to_new_or.status().error_message()); } // result_state_manager_ depends on document_store_. So we need to reset it at @@ -1768,7 +1821,7 @@ libtextclassifier3::Status IcingSearchEngine::OptimizeDocumentStore( ICING_LOG(ERROR) << "Document store has been optimized, but it failed to " "delete temporary file directory"; } - return libtextclassifier3::Status::OK; + return document_id_old_to_new_or; } IcingSearchEngine::IndexRestorationResult diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index 6a06fb9..2eda803 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -20,13 +20,13 @@ #include <string> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/mutex.h" #include "icing/absl_ports/thread_annotations.h" #include "icing/file/filesystem.h" #include "icing/index/index.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/performance-configuration.h" #include "icing/proto/document.pb.h" @@ -582,14 +582,16 @@ class IcingSearchEngine { // would need call Initialize() to reinitialize everything into a valid state. // // Returns: - // OK on success + // On success, a vector that maps from old document id to new document id. A + // value of kInvalidDocumentId indicates that the old document id has been + // deleted. 
// ABORTED_ERROR if any error happens before the actual optimization, the // original document store should be still available // DATA_LOSS_ERROR on errors that could potentially cause data loss, // document store is still available // INTERNAL_ERROR on any IO errors or other errors that we can't recover // from - libtextclassifier3::Status OptimizeDocumentStore( + libtextclassifier3::StatusOr<std::vector<DocumentId>> OptimizeDocumentStore( OptimizeStatsProto* optimize_stats) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index f922b98..2ac456e 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -20,13 +20,13 @@ #include <string> #include <utility> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/file/mock-filesystem.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/portable/endian.h" #include "icing/portable/equals-proto.h" @@ -2274,7 +2274,12 @@ TEST_F(IcingSearchEngineTest, SearchReturnsScoresCreationTimestamp) { } TEST_F(IcingSearchEngineTest, SearchReturnsOneResult) { - IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + auto fake_clock = std::make_unique<FakeClock>(); + fake_clock->SetTimerElapsedMilliseconds(1000); + TestIcingSearchEngine icing(GetDefaultIcingOptions(), + std::make_unique<Filesystem>(), + std::make_unique<IcingFilesystem>(), + std::move(fake_clock), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); @@ -2299,6 +2304,15 @@ TEST_F(IcingSearchEngineTest, SearchReturnsOneResult) { SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + + EXPECT_THAT(search_result_proto.query_stats().latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().parse_query_latency_ms(), + Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().scoring_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), + Eq(1000)); + // The token is a random number so we don't verify it. 
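The Eq(1000) latency expectations above hold because the FakeClock makes every timer report 1000 ms and the stats fields are now populated unconditionally, even on early-return paths, by ScopedTimer. That class is not shown in this hunk; a minimal sketch of the RAII idea, assuming a Timer exposing GetElapsedMilliseconds():

// Sketch only: run the callback with the elapsed time when the timer goes out
// of scope, so every exit path records a latency value.
class ScopedTimer {
 public:
  ScopedTimer(std::unique_ptr<Timer> timer,
              std::function<void(int64_t)> callback)
      : timer_(std::move(timer)), callback_(std::move(callback)) {}
  ~ScopedTimer() { callback_(timer_->GetElapsedMilliseconds()); }

 private:
  std::unique_ptr<Timer> timer_;
  std::function<void(int64_t)> callback_;
};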
expected_search_result_proto.set_next_page_token( search_result_proto.next_page_token()); @@ -2347,6 +2361,30 @@ TEST_F(IcingSearchEngineTest, SearchNegativeResultLimitReturnsInvalidArgument) { expected_search_result_proto)); } +TEST_F(IcingSearchEngineTest, + SearchNonPositivePageTotalBytesLimitReturnsInvalidArgument) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query(""); + + ResultSpecProto result_spec; + result_spec.set_num_total_bytes_per_page_threshold(-1); + + SearchResultProto actual_results1 = + icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + EXPECT_THAT(actual_results1.status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); + + result_spec.set_num_total_bytes_per_page_threshold(0); + SearchResultProto actual_results2 = + icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + EXPECT_THAT(actual_results2.status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + TEST_F(IcingSearchEngineTest, SearchWithPersistenceReturnsValidResults) { IcingSearchEngineOptions icing_options = GetDefaultIcingOptions(); @@ -2403,7 +2441,12 @@ TEST_F(IcingSearchEngineTest, SearchWithPersistenceReturnsValidResults) { } TEST_F(IcingSearchEngineTest, SearchShouldReturnEmpty) { - IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + auto fake_clock = std::make_unique<FakeClock>(); + fake_clock->SetTimerElapsedMilliseconds(1000); + TestIcingSearchEngine icing(GetDefaultIcingOptions(), + std::make_unique<Filesystem>(), + std::make_unique<IcingFilesystem>(), + std::move(fake_clock), GetTestJniCache()); ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); @@ -2418,6 +2461,15 @@ TEST_F(IcingSearchEngineTest, SearchShouldReturnEmpty) { SearchResultProto search_result_proto = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); + EXPECT_THAT(search_result_proto.status(), ProtoIsOk()); + + EXPECT_THAT(search_result_proto.query_stats().latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().parse_query_latency_ms(), + Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().scoring_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(0)); + EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), + Eq(0)); EXPECT_THAT(search_result_proto, EqualsSearchResultIgnoreStatsAndScores( expected_search_result_proto)); @@ -2894,10 +2946,11 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { DocumentProto document1 = CreateMessageDocument("namespace", "uri1"); DocumentProto document2 = CreateMessageDocument("namespace", "uri2"); DocumentProto document3 = CreateMessageDocument("namespace", "uri3"); + DocumentProto document4 = CreateMessageDocument("namespace", "uri4"); + DocumentProto document5 = CreateMessageDocument("namespace", "uri5"); GetResultProto expected_get_result_proto; expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); - *expected_get_result_proto.mutable_document() = document1; { IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); @@ -2905,27 +2958,49 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); 
ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Delete("namespace", "uri2").status(), ProtoIsOk()); ASSERT_THAT(icing.Optimize().status(), ProtoIsOk()); // Validates that Get() and Put() are good right after Optimize() + *expected_get_result_proto.mutable_document() = document1; EXPECT_THAT( icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); - EXPECT_THAT(icing.Put(document2).status(), ProtoIsOk()); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + *expected_get_result_proto.mutable_document() = document3; + EXPECT_THAT( + icing.Get("namespace", "uri3", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + EXPECT_THAT(icing.Put(document4).status(), ProtoIsOk()); } // Destroys IcingSearchEngine to make sure nothing is cached. IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); EXPECT_THAT(icing.Initialize().status(), ProtoIsOk()); + *expected_get_result_proto.mutable_document() = document1; EXPECT_THAT( icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); - - *expected_get_result_proto.mutable_document() = document2; EXPECT_THAT( - icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()), + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + *expected_get_result_proto.mutable_document() = document3; + EXPECT_THAT( + icing.Get("namespace", "uri3", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + *expected_get_result_proto.mutable_document() = document4; + EXPECT_THAT( + icing.Get("namespace", "uri4", GetResultSpecProto::default_instance()), EqualsProto(expected_get_result_proto)); - EXPECT_THAT(icing.Put(document3).status(), ProtoIsOk()); + EXPECT_THAT(icing.Put(document5).status(), ProtoIsOk()); } TEST_F(IcingSearchEngineTest, DeleteShouldWorkAfterOptimization) { @@ -3821,8 +3896,11 @@ TEST_F(IcingSearchEngineTest, ProtoIsOk()); // Optimize() fails due to filesystem error - EXPECT_THAT(icing.Optimize().status(), - ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + OptimizeResultProto result = icing.Optimize(); + EXPECT_THAT(result.status(), ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + // Should rebuild the index for data loss. + EXPECT_THAT(result.optimize_stats().index_restoration_mode(), + Eq(OptimizeStatsProto::FULL_INDEX_REBUILD)); // Document is not found because original file directory is missing GetResultProto expected_get_result_proto; @@ -3895,8 +3973,11 @@ TEST_F(IcingSearchEngineTest, OptimizationShouldRecoverIfDataFilesAreMissing) { ProtoIsOk()); // Optimize() fails due to filesystem error - EXPECT_THAT(icing.Optimize().status(), - ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + OptimizeResultProto result = icing.Optimize(); + EXPECT_THAT(result.status(), ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + // Should rebuild the index for data loss. 
+ EXPECT_THAT(result.optimize_stats().index_restoration_mode(), + Eq(OptimizeStatsProto::FULL_INDEX_REBUILD)); // Document is not found because original files are missing GetResultProto expected_get_result_proto; @@ -7867,6 +7948,7 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { expected.set_num_original_documents(3); expected.set_num_deleted_documents(1); expected.set_num_expired_documents(1); + expected.set_index_restoration_mode(OptimizeStatsProto::INDEX_TRANSLATION); // Run Optimize OptimizeResultProto result = icing->Optimize(); @@ -7899,6 +7981,7 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { expected.set_num_deleted_documents(0); expected.set_num_expired_documents(0); expected.set_time_since_last_optimize_ms(10000); + expected.set_index_restoration_mode(OptimizeStatsProto::INDEX_TRANSLATION); // Run Optimize result = icing->Optimize(); @@ -7907,6 +7990,29 @@ TEST_F(IcingSearchEngineTest, OptimizeStatsProtoTest) { result.mutable_optimize_stats()->clear_storage_size_before(); result.mutable_optimize_stats()->clear_storage_size_after(); EXPECT_THAT(result.optimize_stats(), EqualsProto(expected)); + + // Delete the last document. + ASSERT_THAT(icing->Delete(document3.namespace_(), document3.uri()).status(), + ProtoIsOk()); + + expected = OptimizeStatsProto(); + expected.set_latency_ms(5); + expected.set_document_store_optimize_latency_ms(5); + expected.set_index_restoration_latency_ms(5); + expected.set_num_original_documents(1); + expected.set_num_deleted_documents(1); + expected.set_num_expired_documents(0); + expected.set_time_since_last_optimize_ms(0); + // Should rebuild the index since all documents are removed. + expected.set_index_restoration_mode(OptimizeStatsProto::FULL_INDEX_REBUILD); + + // Run Optimize + result = icing->Optimize(); + EXPECT_THAT(result.optimize_stats().storage_size_before(), + Ge(result.optimize_stats().storage_size_after())); + result.mutable_optimize_stats()->clear_storage_size_before(); + result.mutable_optimize_stats()->clear_storage_size_after(); + EXPECT_THAT(result.optimize_stats(), EqualsProto(expected)); } TEST_F(IcingSearchEngineTest, StorageInfoTest) { diff --git a/icing/index/hit/hit.cc b/icing/index/hit/hit.cc index 887e6e4..ce1c366 100644 --- a/icing/index/hit/hit.cc +++ b/icing/index/hit/hit.cc @@ -97,6 +97,11 @@ bool Hit::is_in_prefix_section() const { return bit_util::BitfieldGet(value(), kInPrefixSection, 1); } +Hit Hit::TranslateHit(Hit old_hit, DocumentId new_document_id) { + return Hit(old_hit.section_id(), new_document_id, old_hit.term_frequency(), + old_hit.is_in_prefix_section(), old_hit.is_prefix_hit()); +} + bool Hit::EqualsDocumentIdAndSectionId::operator()(const Hit& hit1, const Hit& hit2) const { return (hit1.value() >> kNumFlags) == (hit2.value() >> kNumFlags); diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index ee1f64b..f8cbd78 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -77,6 +77,9 @@ class Hit { bool is_prefix_hit() const; bool is_in_prefix_section() const; + // Creates a new hit based on old_hit but with new_document_id set. 
+ static Hit TranslateHit(Hit old_hit, DocumentId new_document_id); + bool operator<(const Hit& h2) const { return value() < h2.value(); } bool operator==(const Hit& h2) const { return value() == h2.value(); } diff --git a/icing/index/index.cc b/icing/index/index.cc index 02ba699..1d863cc 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -264,6 +264,14 @@ IndexStorageInfoProto Index::GetStorageInfo() const { return main_index_->GetStorageInfo(std::move(storage_info)); } +libtextclassifier3::Status Index::Optimize( + const std::vector<DocumentId>& document_id_old_to_new) { + if (main_index_->last_added_document_id() != kInvalidDocumentId) { + ICING_RETURN_IF_ERROR(main_index_->Optimize(document_id_old_to_new)); + } + return lite_index_->Optimize(document_id_old_to_new, term_id_codec_.get()); +} + libtextclassifier3::Status Index::Editor::BufferTerm(const char* term) { // Step 1: See if this term is already in the lexicon uint32_t tvi; diff --git a/icing/index/index.h b/icing/index/index.h index f101a91..748acb0 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -263,6 +263,15 @@ class Index { return lite_index_->Reset(); } + // Reduces internal file sizes by reclaiming space of deleted documents. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new); + private: Index(const Options& options, std::unique_ptr<TermIdCodec> term_id_codec, std::unique_ptr<LiteIndex> lite_index, diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 2eb3b59..7323603 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -14,6 +14,7 @@ #include "icing/index/index.h" +#include <algorithm> #include <cstdint> #include <limits> #include <memory> @@ -48,6 +49,7 @@ namespace lib { namespace { +using ::testing::ContainerEq; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Ge; @@ -79,6 +81,23 @@ class IndexTest : public Test { icing_filesystem_.DeleteDirectoryRecursively(index_dir_.c_str()); } + std::vector<DocHitInfo> GetHits( + std::unique_ptr<DocHitInfoIterator> iterator) { + std::vector<DocHitInfo> infos; + while (iterator->Advance().ok()) { + infos.push_back(iterator->doc_hit_info()); + } + return infos; + } + + libtextclassifier3::StatusOr<std::vector<DocHitInfo>> GetHits( + std::string term, TermMatchType::Code match_type) { + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<DocHitInfoIterator> itr, + index_->GetIterator(term, kSectionIdMaskAll, match_type)); + return GetHits(std::move(itr)); + } + Filesystem filesystem_; IcingFilesystem icing_filesystem_; std::string index_dir_; @@ -97,14 +116,6 @@ constexpr DocumentId kDocumentId8 = 8; constexpr SectionId kSectionId2 = 2; constexpr SectionId kSectionId3 = 3; -std::vector<DocHitInfo> GetHits(std::unique_ptr<DocHitInfoIterator> iterator) { - std::vector<DocHitInfo> infos; - while (iterator->Advance().ok()) { - infos.push_back(iterator->doc_hit_info()); - } - return infos; -} - MATCHER_P2(EqualsDocHitInfo, document_id, sections, "") { const DocHitInfo& actual = arg; SectionIdMask section_mask = kSectionIdMaskNone; @@ -249,6 +260,66 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterMerge) { kDocumentId0, std::vector<SectionId>{kSectionId2}))); } +TEST_F(IndexTest, SingleHitSingleTermIndexAfterOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, 
/*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId2, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Mapping to a different docid will translate the hit + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Mapping to kInvalidDocumentId will remove the hit. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + +TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Merge()); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId2, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Mapping to a different docid will translate the hit + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Mapping to kInvalidDocumentId will remove the hit. 
+ ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + TEST_F(IndexTest, SingleHitMultiTermIndex) { Index::Editor edit = index_->Edit( kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); @@ -281,6 +352,112 @@ TEST_F(IndexTest, SingleHitMultiTermIndexAfterMerge) { kDocumentId0, std::vector<SectionId>{kSectionId2}))); } +TEST_F(IndexTest, MultiHitMultiTermIndexAfterOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("bar"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId2, kSectionId3, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId2, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Delete document id 1, and document id 2 is translated to 1. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1})); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Delete all the rest documents. 
+ ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + +TEST_F(IndexTest, MultiHitMultiTermIndexAfterMergeAndOptimize) { + Index::Editor edit = index_->Edit( + kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId1, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("bar"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + + edit = index_->Edit(kDocumentId2, kSectionId3, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("foo"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + index_->set_last_added_document_id(kDocumentId2); + + ICING_ASSERT_OK(index_->Merge()); + + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId2, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre(EqualsDocHitInfo( + kDocumentId1, std::vector<SectionId>{kSectionId2})))); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId2); + + // Delete document id 1, and document id 2 is translated to 1. + ICING_ASSERT_OK( + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1})); + EXPECT_THAT( + GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(ElementsAre( + EqualsDocHitInfo(kDocumentId1, std::vector<SectionId>{kSectionId3}), + EqualsDocHitInfo(kDocumentId0, + std::vector<SectionId>{kSectionId2})))); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId1); + + // Delete all the rest documents. + ICING_ASSERT_OK(index_->Optimize( + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId})); + EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), + IsOkAndHolds(IsEmpty())); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); +} + TEST_F(IndexTest, NoHitMultiTermIndex) { Index::Editor edit = index_->Edit( kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, /*namespace_id=*/0); @@ -807,6 +984,114 @@ TEST_F(IndexTest, FullIndexMerge) { EXPECT_THAT(last_itr->doc_hit_info().document_id(), Eq(document_id + 1)); } +TEST_F(IndexTest, OptimizeShouldWorkForEmptyIndex) { + // Optimize an empty index should succeed, but have no effects. 
+ ICING_ASSERT_OK(index_->Optimize(std::vector<DocumentId>())); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<DocHitInfoIterator> itr, + index_->GetIterator("", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); + EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); + + ICING_ASSERT_OK_AND_ASSIGN( + itr, index_->GetIterator("", kSectionIdMaskAll, TermMatchType::PREFIX)); + EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); +} + +TEST_F(IndexTest, IndexOptimize) { + std::string prefix = "prefix"; + std::default_random_engine random; + std::vector<std::string> query_terms; + // Add 1024 hits to main index, and 1024 hits to lite index. + for (int i = 0; i < 2048; ++i) { + if (i == 1024) { + ICING_ASSERT_OK(index_->Merge()); + } + // Generate a unique term for document i. + query_terms.push_back(prefix + RandomString("abcdefg", 5, &random) + + std::to_string(i)); + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + SectionId section_id = i % 5; + if (section_id == 2) { + // Make section 2 an exact section. + term_match_type = TermMatchType::EXACT_ONLY; + } + Index::Editor edit = index_->Edit(/*document_id=*/i, section_id, + term_match_type, /*namespace_id=*/0); + ICING_ASSERT_OK(edit.BufferTerm(query_terms.at(i).c_str())); + ICING_ASSERT_OK(edit.IndexAllBufferedTerms()); + index_->set_last_added_document_id(i); + } + + // Delete one document for every three documents. + DocumentId document_id = 0; + DocumentId new_last_added_document_id = kInvalidDocumentId; + std::vector<DocumentId> document_id_old_to_new; + for (int i = 0; i < 2048; ++i) { + if (i % 3 == 0) { + document_id_old_to_new.push_back(kInvalidDocumentId); + } else { + new_last_added_document_id = document_id++; + document_id_old_to_new.push_back(new_last_added_document_id); + } + } + + std::vector<DocHitInfo> exp_prefix_hits; + for (int i = 0; i < 2048; ++i) { + if (document_id_old_to_new[i] == kInvalidDocumentId) { + continue; + } + if (i % 5 == 2) { + // Section 2 is an exact section, so we should not see any hits in + // prefix search. + continue; + } + exp_prefix_hits.push_back(DocHitInfo(document_id_old_to_new[i])); + exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 5, + /*hit_term_frequency=*/1); + } + std::reverse(exp_prefix_hits.begin(), exp_prefix_hits.end()); + + // Check that optimize is correct + ICING_ASSERT_OK(index_->Optimize(document_id_old_to_new)); + EXPECT_EQ(index_->last_added_document_id(), new_last_added_document_id); + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<DocHitInfo> hits, + GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + // Check exact search. + for (int i = 0; i < 2048; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(query_terms[i], TermMatchType::EXACT_ONLY)); + if (document_id_old_to_new[i] == kInvalidDocumentId) { + EXPECT_THAT(hits, IsEmpty()); + } else { + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + document_id_old_to_new[i], + std::vector<SectionId>{(SectionId)(i % 5)}))); + } + } + + // Check that optimize does not block merge. + ICING_ASSERT_OK(index_->Merge()); + EXPECT_EQ(index_->last_added_document_id(), new_last_added_document_id); + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(hits, GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + // Check exact search. 
+ for (int i = 0; i < 2048; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(query_terms[i], TermMatchType::EXACT_ONLY)); + if (document_id_old_to_new[i] == kInvalidDocumentId) { + EXPECT_THAT(hits, IsEmpty()); + } else { + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + document_id_old_to_new[i], + std::vector<SectionId>{(SectionId)(i % 5)}))); + } + } +} + TEST_F(IndexTest, IndexCreateIOFailure) { // Create the index with mock filesystem. By default, Mock will return false, // so the first attempted file operation will fail. diff --git a/icing/legacy/index/icing-lite-index-header.h b/icing/index/lite/lite-index-header.h index ac2d3c0..dd6a0a8 100644 --- a/icing/legacy/index/icing-lite-index-header.h +++ b/icing/index/lite/lite-index-header.h @@ -16,15 +16,15 @@ #define ICING_LEGACY_INDEX_ICING_LITE_INDEX_HEADER_H_ #include "icing/legacy/core/icing-string-util.h" -#include "icing/legacy/index/icing-common-types.h" +#include "icing/store/document-id.h" namespace icing { namespace lib { // A wrapper around the actual mmapped header data. -class IcingLiteIndex_Header { +class LiteIndex_Header { public: - virtual ~IcingLiteIndex_Header() = default; + virtual ~LiteIndex_Header() = default; // Returns true if the magic of the header matches the hard-coded magic // value associated with this header format. @@ -47,7 +47,7 @@ class IcingLiteIndex_Header { virtual void Reset() = 0; }; -class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { +class LiteIndex_HeaderImpl : public LiteIndex_Header { public: struct HeaderData { static const uint32_t kMagic = 0x6dfba6a0; @@ -66,7 +66,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { uint32_t searchable_end; }; - explicit IcingLiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {} + explicit LiteIndex_HeaderImpl(HeaderData *hdr) : hdr_(hdr) {} bool check_magic() const override { return hdr_->magic == HeaderData::kMagic; @@ -97,7 +97,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { void Reset() override { hdr_->lite_index_crc = 0; hdr_->magic = HeaderData::kMagic; - hdr_->last_added_docid = kIcingInvalidDocId; + hdr_->last_added_docid = kInvalidDocumentId; hdr_->cur_size = 0; hdr_->searchable_end = 0; } @@ -105,7 +105,7 @@ class IcingLiteIndex_HeaderImpl : public IcingLiteIndex_Header { private: HeaderData *hdr_; }; -static_assert(24 == sizeof(IcingLiteIndex_HeaderImpl::HeaderData), +static_assert(24 == sizeof(LiteIndex_HeaderImpl::HeaderData), "sizeof(HeaderData) != 24"); } // namespace lib diff --git a/icing/legacy/index/icing-lite-index-options.cc b/icing/index/lite/lite-index-options.cc index 4bf0d38..29075f8 100644 --- a/icing/legacy/index/icing-lite-index-options.cc +++ b/icing/index/lite/lite-index-options.cc @@ -12,13 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/legacy/index/icing-lite-index-options.h" +#include "icing/index/lite/lite-index-options.h" + +#include "icing/index/lite/term-id-hit-pair.h" namespace icing { namespace lib { namespace { +constexpr int kIcingMaxVariantsPerToken = 10; // Maximum number of variants + +constexpr size_t kIcingMaxSearchableDocumentSize = (1u << 16) - 1; // 64K +// Max num tokens per document. 64KB is our original maximum (searchable) +// document size. We clip if document exceeds this. 
+constexpr uint32_t kIcingMaxNumTokensPerDoc = + kIcingMaxSearchableDocumentSize / 5; +constexpr uint32_t kIcingMaxNumHitsPerDocument = + kIcingMaxNumTokensPerDoc * kIcingMaxVariantsPerToken; + uint32_t CalculateHitBufferSize(uint32_t hit_buffer_want_merge_bytes) { constexpr uint32_t kHitBufferSlopMult = 2; @@ -27,7 +39,7 @@ uint32_t CalculateHitBufferSize(uint32_t hit_buffer_want_merge_bytes) { // TODO(b/111690435) Move LiteIndex::Element to a separate file so that this // can use sizeof(LiteIndex::Element) uint32_t hit_capacity_elts_with_slop = - hit_buffer_want_merge_bytes / sizeof(uint64_t); + hit_buffer_want_merge_bytes / sizeof(TermIdHitPair); // Add some slop for index variants on top of max num tokens. hit_capacity_elts_with_slop += kIcingMaxNumHitsPerDocument; hit_capacity_elts_with_slop *= kHitBufferSlopMult; @@ -51,8 +63,8 @@ IcingDynamicTrie::Options CalculateTrieOptions(uint32_t hit_buffer_size) { } // namespace -IcingLiteIndexOptions::IcingLiteIndexOptions( - const std::string& filename_base, uint32_t hit_buffer_want_merge_bytes) +LiteIndexOptions::LiteIndexOptions(const std::string& filename_base, + uint32_t hit_buffer_want_merge_bytes) : filename_base(filename_base), hit_buffer_want_merge_bytes(hit_buffer_want_merge_bytes) { hit_buffer_size = CalculateHitBufferSize(hit_buffer_want_merge_bytes); diff --git a/icing/legacy/index/icing-lite-index-options.h b/icing/index/lite/lite-index-options.h index 2922621..ae58802 100644 --- a/icing/legacy/index/icing-lite-index-options.h +++ b/icing/index/lite/lite-index-options.h @@ -15,20 +15,19 @@ #ifndef ICING_LEGACY_INDEX_ICING_LITE_INDEX_OPTIONS_H_ #define ICING_LEGACY_INDEX_ICING_LITE_INDEX_OPTIONS_H_ -#include "icing/legacy/index/icing-common-types.h" #include "icing/legacy/index/icing-dynamic-trie.h" namespace icing { namespace lib { -struct IcingLiteIndexOptions { - IcingLiteIndexOptions() = default; - // Creates IcingLiteIndexOptions based off of the specified parameters. All +struct LiteIndexOptions { + LiteIndexOptions() = default; + // Creates LiteIndexOptions based off of the specified parameters. All // other fields are calculated based on the value of // hit_buffer_want_merge_bytes and the logic in CalculateHitBufferSize and // CalculateTrieOptions. 
- IcingLiteIndexOptions(const std::string& filename_base, - uint32_t hit_buffer_want_merge_bytes); + LiteIndexOptions(const std::string& filename_base, + uint32_t hit_buffer_want_merge_bytes); IcingDynamicTrie::Options lexicon_options; IcingDynamicTrie::Options display_mappings_options; diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index fc40225..9622ff4 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -23,6 +23,7 @@ #include <memory> #include <string> #include <string_view> +#include <unordered_set> #include <utility> #include <vector> @@ -33,13 +34,13 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index-header.h" #include "icing/index/term-property-id.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/core/icing-timer.h" #include "icing/legacy/index/icing-array-storage.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" -#include "icing/legacy/index/icing-lite-index-header.h" #include "icing/legacy/index/icing-mmapper.h" #include "icing/proto/term.pb.h" #include "icing/schema/section.h" @@ -60,7 +61,7 @@ std::string MakeHitBufferFilename(const std::string& filename_base) { return filename_base + "hb"; } -size_t header_size() { return sizeof(IcingLiteIndex_HeaderImpl::HeaderData); } +size_t header_size() { return sizeof(LiteIndex_HeaderImpl::HeaderData); } } // namespace @@ -156,8 +157,8 @@ libtextclassifier3::Status LiteIndex::Initialize() { // Set up header. header_mmap_.Remap(hit_buffer_fd_.get(), 0, header_size()); - header_ = std::make_unique<IcingLiteIndex_HeaderImpl>( - reinterpret_cast<IcingLiteIndex_HeaderImpl::HeaderData*>( + header_ = std::make_unique<LiteIndex_HeaderImpl>( + reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( header_mmap_.address())); header_->Reset(); @@ -171,8 +172,8 @@ libtextclassifier3::Status LiteIndex::Initialize() { UpdateChecksum(); } else { header_mmap_.Remap(hit_buffer_fd_.get(), 0, header_size()); - header_ = std::make_unique<IcingLiteIndex_HeaderImpl>( - reinterpret_cast<IcingLiteIndex_HeaderImpl::HeaderData*>( + header_ = std::make_unique<LiteIndex_HeaderImpl>( + reinterpret_cast<LiteIndex_HeaderImpl::HeaderData*>( header_mmap_.address())); if (!hit_buffer_.Init(hit_buffer_fd_.get(), header_padded_size, true, @@ -197,8 +198,7 @@ libtextclassifier3::Status LiteIndex::Initialize() { } } - ICING_VLOG(2) << IcingStringUtil::StringPrintf("Lite index init ok in %.3fms", - timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index init ok in " << timer.Elapsed() * 1000 << "ms"; return status; error: @@ -230,8 +230,7 @@ Crc32 LiteIndex::ComputeChecksum() { Crc32 all_crc(header_->CalculateHeaderCrc()); all_crc.Append(std::string_view(reinterpret_cast<const char*>(dependent_crcs), sizeof(dependent_crcs))); - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Lite index crc computed in %.3fms", timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index crc computed in " << timer.Elapsed() * 1000 << "ms"; return all_crc; } @@ -246,8 +245,7 @@ libtextclassifier3::Status LiteIndex::Reset() { header_->Reset(); UpdateChecksum(); - ICING_VLOG(2) << IcingStringUtil::StringPrintf("Lite index clear in %.3fms", - timer.Elapsed() * 1000); + ICING_VLOG(2) << "Lite index clear in " << timer.Elapsed() * 1000 << "ms"; return libtextclassifier3::Status::OK; } @@ -439,34 +437,38 @@ IndexStorageInfoProto LiteIndex::GetStorageInfo( return 
storage_info; } -uint32_t LiteIndex::Seek(uint32_t term_id) { +void LiteIndex::SortHits() { // Make searchable by sorting by hit buffer. uint32_t sort_len = header_->cur_size() - header_->searchable_end(); - if (sort_len > 0) { - IcingTimer timer; - - auto* array_start = - hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size()); - TermIdHitPair::Value* sort_start = array_start + header_->searchable_end(); - std::sort(sort_start, array_start + header_->cur_size()); - - // Now merge with previous region. Since the previous region is already - // sorted and deduplicated, optimize the merge by skipping everything less - // than the new region's smallest value. - if (header_->searchable_end() > 0) { - std::inplace_merge(array_start, array_start + header_->searchable_end(), - array_start + header_->cur_size()); - } - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Lite index sort and merge %u into %u in %.3fms", sort_len, - header_->searchable_end(), timer.Elapsed() * 1000); - - // Now the entire array is sorted. - header_->set_searchable_end(header_->cur_size()); + if (sort_len <= 0) { + return; + } + IcingTimer timer; - // Update crc in-line. - UpdateChecksum(); + auto* array_start = + hit_buffer_.GetMutableMem<TermIdHitPair::Value>(0, header_->cur_size()); + TermIdHitPair::Value* sort_start = array_start + header_->searchable_end(); + std::sort(sort_start, array_start + header_->cur_size()); + + // Now merge with previous region. Since the previous region is already + // sorted and deduplicated, optimize the merge by skipping everything less + // than the new region's smallest value. + if (header_->searchable_end() > 0) { + std::inplace_merge(array_start, array_start + header_->searchable_end(), + array_start + header_->cur_size()); } + ICING_VLOG(2) << "Lite index sort and merge " << sort_len << " into " + << header_->searchable_end() << " in " << timer.Elapsed() * 1000 << "ms"; + + // Now the entire array is sorted. + header_->set_searchable_end(header_->cur_size()); + + // Update crc in-line. + UpdateChecksum(); +} + +uint32_t LiteIndex::Seek(uint32_t term_id) { + SortHits(); // Binary search for our term_id. Make sure we get the first // element. Using kBeginSortValue ensures this for the hit value. @@ -480,5 +482,86 @@ uint32_t LiteIndex::Seek(uint32_t term_id) { return ptr - array; } +libtextclassifier3::Status LiteIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + const TermIdCodec* term_id_codec) { + if (header_->cur_size() == 0) { + return libtextclassifier3::Status::OK; + } + // Sort the hits so that hits with the same term id will be grouped together, + // which helps later to determine which terms will be unused after compaction. + SortHits(); + uint32_t new_size = 0; + // The largest document id after translating hits. + DocumentId largest_document_id = kInvalidDocumentId; + uint32_t curr_term_id = 0; + uint32_t curr_tvi = 0; + std::unordered_set<uint32_t> tvi_to_delete; + for (uint32_t idx = 0; idx < header_->cur_size(); ++idx) { + TermIdHitPair term_id_hit_pair( + hit_buffer_.array_cast<TermIdHitPair>()[idx]); + if (idx == 0 || term_id_hit_pair.term_id() != curr_term_id) { + curr_term_id = term_id_hit_pair.term_id(); + ICING_ASSIGN_OR_RETURN(TermIdCodec::DecodedTermInfo term_info, + term_id_codec->DecodeTermInfo(curr_term_id)); + curr_tvi = term_info.tvi; + // Mark the property of the current term as not having hits in prefix + // section. The property will be set below if there are any valid hits + // from a prefix section. 
+ lexicon_.ClearProperty(curr_tvi, GetHasHitsInPrefixSectionPropertyId()); + // Add curr_tvi to tvi_to_delete. It will be removed from tvi_to_delete + // below if there are any valid hits pointing to that termid. + tvi_to_delete.insert(curr_tvi); + } + DocumentId new_document_id = + document_id_old_to_new[term_id_hit_pair.hit().document_id()]; + if (new_document_id == kInvalidDocumentId) { + continue; + } + if (largest_document_id == kInvalidDocumentId || + new_document_id > largest_document_id) { + largest_document_id = new_document_id; + } + if (term_id_hit_pair.hit().is_in_prefix_section()) { + lexicon_.SetProperty(curr_tvi, GetHasHitsInPrefixSectionPropertyId()); + } + tvi_to_delete.erase(curr_tvi); + TermIdHitPair new_term_id_hit_pair( + term_id_hit_pair.term_id(), + Hit::TranslateHit(term_id_hit_pair.hit(), new_document_id)); + // Rewriting the hit_buffer in place. + // new_size is weakly less than idx so we are okay to overwrite the entry at + // new_size, and valp should never be nullptr since it is within the already + // allocated region of hit_buffer_. + TermIdHitPair::Value* valp = + hit_buffer_.GetMutableMem<TermIdHitPair::Value>(new_size++, 1); + *valp = new_term_id_hit_pair.value(); + } + header_->set_cur_size(new_size); + header_->set_searchable_end(new_size); + header_->set_last_added_docid(largest_document_id); + + // Delete unused terms. + std::unordered_set<std::string> terms_to_delete; + for (IcingDynamicTrie::Iterator term_iter(lexicon_, /*prefix=*/""); + term_iter.IsValid(); term_iter.Advance()) { + if (tvi_to_delete.find(term_iter.GetValueIndex()) != tvi_to_delete.end()) { + terms_to_delete.insert(term_iter.GetKey()); + } + } + for (const std::string& term : terms_to_delete) { + // Mark "term" as deleted. This won't actually free space in the lexicon. It + // will simply make it impossible to Find "term" in subsequent calls (which + // saves an unnecessary search through the hit buffer). This is acceptable + // because the free space will eventually be reclaimed the next time that + // the lite index is merged with the main index. + if (!lexicon_.Delete(term)) { + return absl_ports::InternalError( + "Could not delete invalid terms in lite lexicon during compaction."); + } + } + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/index/lite/lite-index.h b/icing/index/lite/lite-index.h index 42d69f8..64b5881 100644 --- a/icing/index/lite/lite-index.h +++ b/icing/index/lite/lite-index.h @@ -30,12 +30,13 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/hit/hit.h" +#include "icing/index/lite/lite-index-header.h" +#include "icing/index/lite/lite-index-options.h" #include "icing/index/lite/term-id-hit-pair.h" +#include "icing/index/term-id-codec.h" #include "icing/legacy/index/icing-array-storage.h" #include "icing/legacy/index/icing-dynamic-trie.h" #include "icing/legacy/index/icing-filesystem.h" -#include "icing/legacy/index/icing-lite-index-header.h" -#include "icing/legacy/index/icing-lite-index-options.h" #include "icing/legacy/index/icing-mmapper.h" #include "icing/proto/debug.pb.h" #include "icing/proto/storage.pb.h" @@ -53,7 +54,7 @@ namespace lib { class LiteIndex { public: // An entry in the hit buffer. - using Options = IcingLiteIndexOptions; + using Options = LiteIndexOptions; // Updates checksum of subcomponents. 
~LiteIndex(); @@ -260,6 +261,16 @@ class LiteIndex { IndexStorageInfoProto GetStorageInfo( IndexStorageInfoProto storage_info) const; + // Reduces internal file sizes by reclaiming space of deleted documents. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new, + const TermIdCodec* term_id_codec); + private: static IcingDynamicTrie::RuntimeOptions MakeTrieRuntimeOptions(); @@ -279,6 +290,9 @@ class LiteIndex { // Sets the computed checksum in the header void UpdateChecksum(); + // Sort hits stored in the index. + void SortHits(); + // Returns the position of the first element with term_id, or the size of the // hit buffer if term_id is not present. uint32_t Seek(uint32_t term_id); @@ -301,7 +315,7 @@ class LiteIndex { IcingMMapper header_mmap_; // Wrapper around the mmapped header that contains stats on the lite index. - std::unique_ptr<IcingLiteIndex_Header> header_; + std::unique_ptr<LiteIndex_Header> header_; // Options used to initialize the LiteIndex. const Options options_; diff --git a/icing/index/main/flash-index-storage.cc b/icing/index/main/flash-index-storage.cc index dabff28..33dacf9 100644 --- a/icing/index/main/flash-index-storage.cc +++ b/icing/index/main/flash-index-storage.cc @@ -133,9 +133,7 @@ bool FlashIndexStorage::CreateHeader() { posting_list_bytes /= 2) { uint32_t aligned_posting_list_bytes = (posting_list_bytes / sizeof(Hit) * sizeof(Hit)); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Block size %u: %u", header_block_->header()->num_index_block_infos, - aligned_posting_list_bytes); + ICING_VLOG(1) << "Block size " << header_block_->header()->num_index_block_infos << ": " << aligned_posting_list_bytes; // Initialize free list to empty. 
HeaderBlock::Header::IndexBlockInfo* block_info = @@ -169,23 +167,18 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { return false; } if (file_size % read_header.header()->block_size != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index size %" PRIu64 " not a multiple of block size %u", file_size, - read_header.header()->block_size); + ICING_LOG(ERROR) << "Index size " << file_size << " not a multiple of block size " << read_header.header()->block_size; return false; } if (file_size < static_cast<int64_t>(read_header.header()->block_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index size %" PRIu64 " shorter than block size %u", file_size, - read_header.header()->block_size); + ICING_LOG(ERROR) << "Index size " << file_size << " shorter than block size " << read_header.header()->block_size; return false; } if (read_header.header()->block_size % getpagesize() != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Block size %u is not a multiple of page size %d", - read_header.header()->block_size, getpagesize()); + ICING_LOG(ERROR) << "Block size " << read_header.header()->block_size + << " is not a multiple of page size " << getpagesize(); return false; } num_blocks_ = file_size / read_header.header()->block_size; @@ -215,11 +208,10 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { int posting_list_bytes = header_block_->header()->index_block_infos[i].posting_list_bytes; if (posting_list_bytes % sizeof(Hit) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Posting list size misaligned, index %u, size %u, hit %zu, " - "file_size %" PRIu64, - i, header_block_->header()->index_block_infos[i].posting_list_bytes, - sizeof(Hit), file_size); + ICING_LOG(ERROR) << "Posting list size misaligned, index " << i + << ", size " + << header_block_->header()->index_block_infos[i].posting_list_bytes + << ", hit " << sizeof(Hit) << ", file_size " << file_size; return false; } } @@ -229,8 +221,7 @@ bool FlashIndexStorage::OpenHeader(int64_t file_size) { bool FlashIndexStorage::PersistToDisk() { // First, write header. 
if (!header_block_->Write(block_fd_.get())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Write index header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Write index header failed: " << strerror(errno); return false; } @@ -456,8 +447,7 @@ void FlashIndexStorage::FreePostingList(PostingListHolder holder) { int FlashIndexStorage::GrowIndex() { if (num_blocks_ >= kMaxBlockIndex) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Reached max block index %u", - kMaxBlockIndex); + ICING_VLOG(1) << "Reached max block index " << kMaxBlockIndex; return kInvalidBlockIndex; } @@ -465,8 +455,7 @@ int FlashIndexStorage::GrowIndex() { if (!filesystem_->Grow( block_fd_.get(), static_cast<uint64_t>(num_blocks_ + 1) * block_size())) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Error growing index file: %s", strerror(errno)); + ICING_VLOG(1) << "Error growing index file: " << strerror(errno); return kInvalidBlockIndex; } diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 158c287..9f591c0 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -17,9 +17,11 @@ #include <cstring> #include <memory> #include <string> +#include <unordered_set> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" +#include "icing/file/destructible-directory.h" #include "icing/index/main/index-block.h" #include "icing/index/term-id-codec.h" #include "icing/index/term-property-id.h" @@ -84,35 +86,40 @@ FindTermResult FindShortestValidTermWithPrefixHits( } // namespace +MainIndex::MainIndex(const std::string& index_directory, + const Filesystem* filesystem, + const IcingFilesystem* icing_filesystem) + : base_dir_(index_directory), + filesystem_(filesystem), + icing_filesystem_(icing_filesystem) {} + libtextclassifier3::StatusOr<std::unique_ptr<MainIndex>> MainIndex::Create( const std::string& index_directory, const Filesystem* filesystem, const IcingFilesystem* icing_filesystem) { ICING_RETURN_ERROR_IF_NULL(filesystem); ICING_RETURN_ERROR_IF_NULL(icing_filesystem); - auto main_index = std::make_unique<MainIndex>(); - ICING_RETURN_IF_ERROR( - main_index->Init(index_directory, filesystem, icing_filesystem)); + std::unique_ptr<MainIndex> main_index( + new MainIndex(index_directory, filesystem, icing_filesystem)); + ICING_RETURN_IF_ERROR(main_index->Init()); return main_index; } // TODO(b/139087650) : Migrate off of IcingFilesystem. 
-libtextclassifier3::Status MainIndex::Init( - const std::string& index_directory, const Filesystem* filesystem, - const IcingFilesystem* icing_filesystem) { - if (!filesystem->CreateDirectoryRecursively(index_directory.c_str())) { +libtextclassifier3::Status MainIndex::Init() { + if (!filesystem_->CreateDirectoryRecursively(base_dir_.c_str())) { return absl_ports::InternalError("Unable to create main index directory."); } - std::string flash_index_file = index_directory + "/main_index"; + std::string flash_index_file = base_dir_ + "/main_index"; ICING_ASSIGN_OR_RETURN( FlashIndexStorage flash_index, - FlashIndexStorage::Create(flash_index_file, filesystem)); + FlashIndexStorage::Create(flash_index_file, filesystem_)); flash_index_storage_ = std::make_unique<FlashIndexStorage>(std::move(flash_index)); - std::string lexicon_file = index_directory + "/main-lexicon"; + std::string lexicon_file = base_dir_ + "/main-lexicon"; IcingDynamicTrie::RuntimeOptions runtime_options; main_lexicon_ = std::make_unique<IcingDynamicTrie>( - lexicon_file, runtime_options, icing_filesystem); + lexicon_file, runtime_options, icing_filesystem_); IcingDynamicTrie::Options lexicon_options; if (!main_lexicon_->CreateIfNotExist(lexicon_options) || !main_lexicon_->Init()) { @@ -490,8 +497,7 @@ libtextclassifier3::Status MainIndex::AddHits( } // Now copy remaining backfills. - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Remaining backfills %zu", - backfill_map.size()); + ICING_VLOG(1) << "Remaining backfills " << backfill_map.size(); for (auto other_tvi_main_tvi_pair : backfill_map) { PostingListIdentifier backfill_posting_list_id = PostingListIdentifier::kInvalid; @@ -524,9 +530,7 @@ libtextclassifier3::Status MainIndex::AddHitsForTerm( std::unique_ptr<PostingListAccessor> pl_accessor; if (posting_list_id.is_valid()) { if (posting_list_id.block_index() >= flash_index_storage_->num_blocks()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Index dropped hits. Invalid block index %u >= %u", - posting_list_id.block_index(), flash_index_storage_->num_blocks()); + ICING_LOG(ERROR) << "Index dropped hits. Invalid block index " << posting_list_id.block_index() << " >= " << flash_index_storage_->num_blocks(); // TODO(b/159918304) : Consider revising the checksumming strategy in the // main index. 
Providing some mechanism to check for corruption - either // during initialization or some later time would allow us to avoid @@ -633,5 +637,142 @@ std::string MainIndex::GetDebugInfo(DebugInfoVerbosity::Code verbosity) const { return res; } +libtextclassifier3::Status MainIndex::Optimize( + const std::vector<DocumentId>& document_id_old_to_new) { + std::string temporary_index_dir_path = base_dir_ + "_temp"; + if (!filesystem_->DeleteDirectoryRecursively( + temporary_index_dir_path.c_str())) { + ICING_LOG(ERROR) << "Recursively deleting " << temporary_index_dir_path; + return absl_ports::InternalError( + "Unable to delete temp directory to prepare to build new index."); + } + + DestructibleDirectory temporary_index_dir( + filesystem_, std::move(temporary_index_dir_path)); + if (!temporary_index_dir.is_valid()) { + return absl_ports::InternalError( + "Unable to create temp directory to build new index."); + } + + ICING_ASSIGN_OR_RETURN(std::unique_ptr<MainIndex> new_index, + MainIndex::Create(temporary_index_dir.dir(), + filesystem_, icing_filesystem_)); + ICING_RETURN_IF_ERROR(TransferIndex(document_id_old_to_new, new_index.get())); + ICING_RETURN_IF_ERROR(new_index->PersistToDisk()); + new_index = nullptr; + flash_index_storage_ = nullptr; + main_lexicon_ = nullptr; + + if (!filesystem_->SwapFiles(temporary_index_dir.dir().c_str(), + base_dir_.c_str())) { + return absl_ports::InternalError( + "Unable to apply new index due to failed swap!"); + } + + // Reinitialize the index so that flash_index_storage_ and main_lexicon_ are + // properly updated. + return Init(); +} + +libtextclassifier3::StatusOr<DocumentId> MainIndex::TransferAndAddHits( + const std::vector<DocumentId>& document_id_old_to_new, const char* term, + PostingListAccessor& old_pl_accessor, MainIndex* new_index) { + std::vector<Hit> new_hits; + bool has_no_exact_hits = true; + bool has_hits_in_prefix_section = false; + // The largest document id after translating hits. + DocumentId largest_document_id = kInvalidDocumentId; + ICING_ASSIGN_OR_RETURN(std::vector<Hit> tmp, + old_pl_accessor.GetNextHitsBatch()); + while (!tmp.empty()) { + for (const Hit& hit : tmp) { + DocumentId new_document_id = document_id_old_to_new[hit.document_id()]; + // Transfer the document id of the hit, if the document is not deleted + // or outdated. + if (new_document_id != kInvalidDocumentId) { + if (hit.is_in_prefix_section()) { + has_hits_in_prefix_section = true; + } + if (!hit.is_prefix_hit()) { + has_no_exact_hits = false; + } + if (largest_document_id == kInvalidDocumentId || + new_document_id > largest_document_id) { + largest_document_id = new_document_id; + } + new_hits.push_back(Hit::TranslateHit(hit, new_document_id)); + } + } + ICING_ASSIGN_OR_RETURN(tmp, old_pl_accessor.GetNextHitsBatch()); + } + // A term without exact hits indicates that it is a purely backfill term. If + // the term is not branching in the new trie, it means backfilling is no + // longer necessary, so that we can skip. 
+ if (new_hits.empty() || + (has_no_exact_hits && !new_index->main_lexicon_->IsBranchingTerm(term))) { + return largest_document_id; + } + + ICING_ASSIGN_OR_RETURN( + PostingListAccessor hit_accum, + PostingListAccessor::Create(new_index->flash_index_storage_.get())); + for (auto itr = new_hits.rbegin(); itr != new_hits.rend(); ++itr) { + ICING_RETURN_IF_ERROR(hit_accum.PrependHit(*itr)); + } + PostingListAccessor::FinalizeResult result = + PostingListAccessor::Finalize(std::move(hit_accum)); + uint32_t tvi; + if (!result.id.is_valid() || + !new_index->main_lexicon_->Insert(term, &result.id, &tvi, + /*replace=*/false)) { + return absl_ports::InternalError( + absl_ports::StrCat("Could not transfer main index for term: ", term)); + } + if (has_no_exact_hits && !new_index->main_lexicon_->SetProperty( + tvi, GetHasNoExactHitsPropertyId())) { + return absl_ports::InternalError("Setting prefix prop failed"); + } + if (has_hits_in_prefix_section && + !new_index->main_lexicon_->SetProperty( + tvi, GetHasHitsInPrefixSectionPropertyId())) { + return absl_ports::InternalError("Setting prefix prop failed"); + } + return largest_document_id; +} + +libtextclassifier3::Status MainIndex::TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + MainIndex* new_index) { + DocumentId largest_document_id = kInvalidDocumentId; + for (IcingDynamicTrie::Iterator term_itr(*main_lexicon_, /*prefix=*/"", + /*reverse=*/true); + term_itr.IsValid(); term_itr.Advance()) { + PostingListIdentifier posting_list_id = PostingListIdentifier::kInvalid; + memcpy(&posting_list_id, term_itr.GetValue(), sizeof(posting_list_id)); + if (posting_list_id == PostingListIdentifier::kInvalid) { + // Why? + ICING_LOG(ERROR) + << "Got invalid posting_list_id from previous main index"; + continue; + } + ICING_ASSIGN_OR_RETURN(PostingListAccessor pl_accessor, + PostingListAccessor::CreateFromExisting( + flash_index_storage_.get(), posting_list_id)); + ICING_ASSIGN_OR_RETURN( + DocumentId curr_largest_document_id, + TransferAndAddHits(document_id_old_to_new, term_itr.GetKey(), + pl_accessor, new_index)); + if (curr_largest_document_id == kInvalidDocumentId) { + continue; + } + if (largest_document_id == kInvalidDocumentId || + curr_largest_document_id > largest_document_id) { + largest_document_id = curr_largest_document_id; + } + } + new_index->flash_index_storage_->set_last_indexed_docid(largest_document_id); + return libtextclassifier3::Status::OK; +} + } // namespace lib } // namespace icing diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index d6f7d5f..15030b0 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -188,10 +188,20 @@ class MainIndex { // postings lists. std::string GetDebugInfo(DebugInfoVerbosity::Code verbosity) const; + // Reduces internal file sizes by reclaiming space of deleted documents. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error, this indicates that the index may be in an + // invalid state and should be cleared. + libtextclassifier3::Status Optimize( + const std::vector<DocumentId>& document_id_old_to_new); + private: - libtextclassifier3::Status Init(const std::string& index_directory, - const Filesystem* filesystem, - const IcingFilesystem* icing_filesystem); + MainIndex(const std::string& index_directory, const Filesystem* filesystem, + const IcingFilesystem* icing_filesystem); + + libtextclassifier3::Status Init(); // Helpers for merging the lexicon // Add all 'backfill' branch points. 
Backfill branch points are prefix @@ -287,6 +297,27 @@ class MainIndex { PostingListIdentifier backfill_posting_list_id, PostingListAccessor* hit_accum); + // Transfer hits from old_pl_accessor to new_index for term. + // + // Returns: + // largest document id added to the translated posting list, on success + // INTERNAL_ERROR on IO error + static libtextclassifier3::StatusOr<DocumentId> TransferAndAddHits( + const std::vector<DocumentId>& document_id_old_to_new, const char* term, + PostingListAccessor& old_pl_accessor, MainIndex* new_index); + + // Transfer hits from the current main index to new_index. + // + // Returns: + // OK on success + // INTERNAL_ERROR on IO error + libtextclassifier3::Status TransferIndex( + const std::vector<DocumentId>& document_id_old_to_new, + MainIndex* new_index); + + std::string base_dir_; + const Filesystem* filesystem_; + const IcingFilesystem* icing_filesystem_; std::unique_ptr<FlashIndexStorage> flash_index_storage_; std::unique_ptr<IcingDynamicTrie> main_lexicon_; }; diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index 17bb059..c9e7127 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -17,12 +17,11 @@ #include <string> #include <utility> +#include <google/protobuf/message_lite.h> +#include "icing/icing-search-engine.h" #include "icing/jni/jni-cache.h" #include "icing/jni/scoped-primitive-array-critical.h" #include "icing/jni/scoped-utf-chars.h" -#include <google/protobuf/message_lite.h> -#include "icing/absl_ports/status_imports.h" -#include "icing/icing-search-engine.h" #include "icing/proto/document.pb.h" #include "icing/proto/initialize.pb.h" #include "icing/proto/optimize.pb.h" diff --git a/icing/jni/jni-cache.cc b/icing/jni/jni-cache.cc index 9b75db6..1804b9a 100644 --- a/icing/jni/jni-cache.cc +++ b/icing/jni/jni-cache.cc @@ -159,8 +159,7 @@ libtextclassifier3::StatusOr<std::unique_ptr<JniCache>> JniCache::Create( // BreakIteratorBatcher ICING_GET_CLASS_OR_RETURN_NULL( - breakiterator, - "com/google/android/icing/BreakIteratorBatcher"); + breakiterator, "com/google/android/icing/BreakIteratorBatcher"); ICING_GET_METHOD(breakiterator, constructor, "<init>", "(Ljava/util/Locale;)V"); ICING_GET_METHOD(breakiterator, settext, "setText", "(Ljava/lang/String;)V"); diff --git a/icing/jni/scoped-primitive-array-critical_test.cc b/icing/jni/scoped-primitive-array-critical_test.cc new file mode 100644 index 0000000..3655378 --- /dev/null +++ b/icing/jni/scoped-primitive-array-critical_test.cc @@ -0,0 +1,140 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/jni/scoped-primitive-array-critical.h" + +#include <jni.h> + +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/java/mock_jni_env.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Return; +using util::java::test::MockJNIEnv; + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayNull) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), /*array=*/nullptr); + EXPECT_THAT(scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(scoped_primitive_array.size(), Eq(0)); + + // Move construct a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> moved_scoped_primitive_array( + std::move(scoped_primitive_array)); + EXPECT_THAT(moved_scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(moved_scoped_primitive_array.size(), Eq(0)); + + // Move assign a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> move_assigned_scoped_primitive_array = + std::move(moved_scoped_primitive_array); + EXPECT_THAT(move_assigned_scoped_primitive_array.data(), IsNull()); + EXPECT_THAT(move_assigned_scoped_primitive_array.size(), Eq(0)); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + jarray fake_jarray = reinterpret_cast<jarray>(-303); + uint8_t fake_array[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray), IsNull())) + .WillByDefault(Return(fake_array)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray))).WillByDefault(Return(4)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), + /*array=*/fake_jarray); + EXPECT_THAT(scoped_primitive_array.data(), Eq(fake_array)); + EXPECT_THAT(scoped_primitive_array.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray), + Eq(fake_array), Eq(0))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayMoveConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. 
+ jarray fake_jarray = reinterpret_cast<jarray>(-303); + uint8_t fake_array[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray), IsNull())) + .WillByDefault(Return(fake_array)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray))).WillByDefault(Return(4)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array( + env_mock.get(), + /*array=*/fake_jarray); + + // Move construct a scoped utf chars + ScopedPrimitiveArrayCritical<uint8_t> moved_scoped_primitive_array( + std::move(scoped_primitive_array)); + EXPECT_THAT(moved_scoped_primitive_array.data(), Eq(fake_array)); + EXPECT_THAT(moved_scoped_primitive_array.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray), + Eq(fake_array), Eq(0))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedPrimitiveArrayMoveAssignment) { + // Setup the mock to return: + // {1, 8, 63, 90} for jstring (-303) + // {5, 9, 82} for jstring (-505) + auto env_mock = std::make_unique<MockJNIEnv>(); + jarray fake_jarray1 = reinterpret_cast<jarray>(-303); + uint8_t fake_array1[] = {1, 8, 63, 90}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray1), IsNull())) + .WillByDefault(Return(fake_array1)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray1))).WillByDefault(Return(4)); + + jarray fake_jarray2 = reinterpret_cast<jarray>(-505); + uint8_t fake_array2[] = {5, 9, 82}; + ON_CALL(*env_mock, GetPrimitiveArrayCritical(Eq(fake_jarray2), IsNull())) + .WillByDefault(Return(fake_array2)); + ON_CALL(*env_mock, GetArrayLength(Eq(fake_jarray2))).WillByDefault(Return(3)); + + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array1( + env_mock.get(), + /*array=*/fake_jarray1); + ScopedPrimitiveArrayCritical<uint8_t> scoped_primitive_array2( + env_mock.get(), + /*array=*/fake_jarray2); + + // Move assign a scoped utf chars + scoped_primitive_array2 = std::move(scoped_primitive_array1); + EXPECT_THAT(scoped_primitive_array2.data(), Eq(fake_array1)); + EXPECT_THAT(scoped_primitive_array2.size(), Eq(4)); + + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray1), + Eq(fake_array1), Eq(0))) + .Times(1); + EXPECT_CALL(*env_mock, ReleasePrimitiveArrayCritical(Eq(fake_jarray2), + Eq(fake_array2), Eq(0))) + .Times(1); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/jni/scoped-utf-chars.h b/icing/jni/scoped-utf-chars.h index 2dafcc1..5a3ac6a 100644 --- a/icing/jni/scoped-utf-chars.h +++ b/icing/jni/scoped-utf-chars.h @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. - #ifndef ICING_JNI_SCOPED_UTF_CHARS_H_ #define ICING_JNI_SCOPED_UTF_CHARS_H_ diff --git a/icing/jni/scoped-utf-chars_test.cc b/icing/jni/scoped-utf-chars_test.cc new file mode 100644 index 0000000..d249f69 --- /dev/null +++ b/icing/jni/scoped-utf-chars_test.cc @@ -0,0 +1,126 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/jni/scoped-utf-chars.h" + +#include <jni.h> + +#include <string> +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "util/java/mock_jni_env.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::Eq; +using ::testing::IsNull; +using ::testing::Return; +using util::java::test::MockJNIEnv; + +TEST(ScopedJniClassesTest, ScopedUtfCharsNull) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/nullptr); + EXPECT_THAT(scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(scoped_utf_chars.size(), Eq(0)); + + // Move construct a scoped utf chars + ScopedUtfChars moved_scoped_utf_chars(std::move(scoped_utf_chars)); + EXPECT_THAT(moved_scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(moved_scoped_utf_chars.size(), Eq(0)); + + // Move assign a scoped utf chars + ScopedUtfChars move_assigned_scoped_utf_chars = + std::move(moved_scoped_utf_chars); + EXPECT_THAT(move_assigned_scoped_utf_chars.c_str(), IsNull()); + EXPECT_THAT(move_assigned_scoped_utf_chars.size(), Eq(0)); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. + jstring fake_jstring = reinterpret_cast<jstring>(-303); + std::string fake_string = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring), IsNull())) + .WillByDefault(Return(fake_string.c_str())); + + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/fake_jstring); + EXPECT_THAT(scoped_utf_chars.c_str(), Eq(fake_string.c_str())); + EXPECT_THAT(scoped_utf_chars.size(), Eq(3)); + + EXPECT_CALL(*env_mock, + ReleaseStringUTFChars(Eq(fake_jstring), Eq(fake_string.c_str()))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsMoveConstruction) { + auto env_mock = std::make_unique<MockJNIEnv>(); + // Construct a scoped utf chars normally. 
+ jstring fake_jstring = reinterpret_cast<jstring>(-303); + std::string fake_string = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring), IsNull())) + .WillByDefault(Return(fake_string.c_str())); + + ScopedUtfChars scoped_utf_chars(env_mock.get(), /*s=*/fake_jstring); + + // Move construct a scoped utf chars + ScopedUtfChars moved_scoped_utf_chars(std::move(scoped_utf_chars)); + EXPECT_THAT(moved_scoped_utf_chars.c_str(), Eq(fake_string.c_str())); + EXPECT_THAT(moved_scoped_utf_chars.size(), Eq(3)); + + EXPECT_CALL(*env_mock, + ReleaseStringUTFChars(Eq(fake_jstring), Eq(fake_string.c_str()))) + .Times(1); +} + +TEST(ScopedJniClassesTest, ScopedUtfCharsMoveAssignment) { + // Setup the mock to return: + // "foo" for jstring (-303) + // "bar baz" for jstring (-505) + auto env_mock = std::make_unique<MockJNIEnv>(); + jstring fake_jstring1 = reinterpret_cast<jstring>(-303); + std::string fake_string1 = "foo"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring1), IsNull())) + .WillByDefault(Return(fake_string1.c_str())); + + jstring fake_jstring2 = reinterpret_cast<jstring>(-505); + std::string fake_string2 = "bar baz"; + ON_CALL(*env_mock, GetStringUTFChars(Eq(fake_jstring2), IsNull())) + .WillByDefault(Return(fake_string2.c_str())); + + ScopedUtfChars scoped_utf_chars1(env_mock.get(), /*s=*/fake_jstring1); + ScopedUtfChars scoped_utf_chars2(env_mock.get(), /*s=*/fake_jstring2); + + // Move assign a scoped utf chars + scoped_utf_chars2 = std::move(scoped_utf_chars1); + EXPECT_THAT(scoped_utf_chars2.c_str(), Eq(fake_string1.c_str())); + EXPECT_THAT(scoped_utf_chars2.size(), Eq(3)); + + EXPECT_CALL(*env_mock, ReleaseStringUTFChars(Eq(fake_jstring1), + Eq(fake_string1.c_str()))) + .Times(1); + EXPECT_CALL(*env_mock, ReleaseStringUTFChars(Eq(fake_jstring2), + Eq(fake_string2.c_str()))) + .Times(1); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/legacy/index/icing-array-storage.cc b/icing/legacy/index/icing-array-storage.cc index 4d2ef67..de5178a 100644 --- a/icing/legacy/index/icing-array-storage.cc +++ b/icing/legacy/index/icing-array-storage.cc @@ -65,17 +65,13 @@ bool IcingArrayStorage::Init(int fd, size_t fd_offset, bool map_shared, return false; } if (file_size < fd_offset) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage file size %" PRIu64 " less than offset %zu", file_size, - fd_offset); + ICING_LOG(ERROR) << "Array storage file size " << file_size << " less than offset " << fd_offset; return false; } uint32_t capacity_num_elts = (file_size - fd_offset) / elt_size; if (capacity_num_elts < num_elts) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage num elts %u > capacity num elts %u", num_elts, - capacity_num_elts); + ICING_LOG(ERROR) << "Array storage num elts " << num_elts << " > capacity num elts " << capacity_num_elts; return false; } @@ -108,8 +104,7 @@ bool IcingArrayStorage::Init(int fd, size_t fd_offset, bool map_shared, if (init_crc) { *crc_ptr_ = crc; } else if (crc != *crc_ptr_) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Array storage bad crc %u vs %u", crc, *crc_ptr_); + ICING_LOG(ERROR) << "Array storage bad crc " << crc << " vs " << *crc_ptr_; goto failed; } } @@ -276,9 +271,9 @@ void IcingArrayStorage::UpdateCrc() { cur_offset += change.elt_len * elt_size_; } if (!changes_.empty()) { - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Array update partial crcs %d truncated %d overlapped %d duplicate %d", - num_partial_crcs, num_truncated, num_overlapped, 
num_duplicate); + ICING_VLOG(2) << "Array update partial crcs " << num_partial_crcs + << " truncated " << num_truncated << " overlapped " << num_overlapped + << " duplicate " << num_duplicate; } // Now update with grown area. @@ -286,8 +281,7 @@ void IcingArrayStorage::UpdateCrc() { cur_crc = IcingStringUtil::UpdateCrc32( cur_crc, array_cast<char>() + changes_end_ * elt_size_, (cur_num_ - changes_end_) * elt_size_); - ICING_VLOG(2) << IcingStringUtil::StringPrintf( - "Array update tail crc offset %u -> %u", changes_end_, cur_num_); + ICING_VLOG(2) << "Array update tail crc offset " << changes_end_ << " -> " << cur_num_; } // Clear, now that we've applied changes. @@ -341,8 +335,7 @@ uint32_t IcingArrayStorage::Sync() { if (pwrite(fd_, array() + dirty_start, dirty_end - dirty_start, fd_offset_ + dirty_start) != static_cast<ssize_t>(dirty_end - dirty_start)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing pages failed (%u, %u)", dirty_start, dirty_end); + ICING_LOG(ERROR) << "Flushing pages failed (" << dirty_start << ", " << dirty_end << ")"; } in_dirty = false; } else if (!in_dirty && is_dirty) { @@ -361,8 +354,7 @@ uint32_t IcingArrayStorage::Sync() { if (pwrite(fd_, array() + dirty_start, dirty_end - dirty_start, fd_offset_ + dirty_start) != static_cast<ssize_t>(dirty_end - dirty_start)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing pages failed (%u, %u)", dirty_start, dirty_end); + ICING_LOG(ERROR) << "Flushing pages failed (" << dirty_start << ", " << dirty_end << ")"; } } @@ -377,9 +369,7 @@ uint32_t IcingArrayStorage::Sync() { } if (num_flushed > 0) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Flushing %u/%u %u contiguous pages in %.3fms", num_flushed, - dirty_pages_size, num_contiguous, timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Flushing " << num_flushed << "/" << dirty_pages_size << " " << num_contiguous << " contiguous pages in " << timer.Elapsed() * 1000 << "ms."; } return num_flushed; diff --git a/icing/legacy/index/icing-common-types.h b/icing/legacy/index/icing-common-types.h deleted file mode 100644 index 592b549..0000000 --- a/icing/legacy/index/icing-common-types.h +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2014 Google Inc. All Rights Reserved. -// Author: sbanacho@google.com (Scott Banachowski) -// Author: csyoung@google.com (C. Sean Young) - -#ifndef ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ -#define ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ - -#include "icing/legacy/core/icing-core-types.h" - -// Protocol buffers are shared across several components. 
-namespace com { -namespace google { -namespace android { -namespace gms { -namespace icing { -namespace lib { - -class ClientFileGroup; -class Document; -class Document_Section; -class DocumentStoreStatusProto; -class IMEUpdate; -class IMEUpdateResponse; -class IndexCorpusScoringConfig; -class IndexCorpusScoringConfig_Section; -class IndexScoringConfig; -class InitStatus; -class InitStatus_CorpusInitInfo; -class PendingDeleteUsageReport; -class PhraseAffinityRequest; -class QueryResponse; -class QueryResponse_Corpus; -class QueryResponse_Corpus_Section; -class QueryResponse_Corpus_Tag; -class QueryRequestSpec; -class QueryRequestSpec_CorpusSpec; -class QueryRequestSpec_SectionSpec; -class ResponseDebugInfo; -class ResultDebugInfo; -class SectionConfig; -class SuggestionResponse; -class SuggestionResponse_Suggestion; -class UsageReportsResponse; -class UsageStats; -class UsageStats_Corpus; - -} // namespace lib -} // namespace icing -} // namespace gms -} // namespace android -} // namespace google -} // namespace com - -namespace icing { -namespace lib { - -// Typedefs. -using IcingDocId = uint32_t; - -using IcingSectionId = uint32_t; - -using IcingCorpusId = uint16_t; -using IcingSectionIdMask = uint16_t; - -using IcingTagsCount = uint16_t; - -using IcingSequenceNumber = int64_t; - -using IcingScore = uint64_t; - -constexpr size_t kIcingMaxTokenLen = 30; // default shared between query - // processor and indexer -constexpr int kIcingQueryTermLimit = 50; // Maximum number of terms in a query -constexpr int kIcingMaxVariantsPerToken = 10; // Maximum number of variants - -// LINT.IfChange -constexpr int kIcingDocIdBits = 20; // 1M docs -constexpr IcingDocId kIcingInvalidDocId = (1u << kIcingDocIdBits) - 1; -constexpr IcingDocId kIcingMaxDocId = kIcingInvalidDocId - 1; -// LINT.ThenChange(//depot/google3/wireless/android/icing/plx/google_sql_common_macros.sql) - -constexpr int kIcingDocScoreBits = 32; - -constexpr int kIcingSectionIdBits = 4; // 4 bits for 16 values -constexpr IcingSectionId kIcingMaxSectionId = (1u << kIcingSectionIdBits) - 1; -constexpr IcingSectionId kIcingInvalidSectionId = kIcingMaxSectionId + 1; -constexpr IcingSectionIdMask kIcingSectionIdMaskAll = ~IcingSectionIdMask{0}; -constexpr IcingSectionIdMask kIcingSectionIdMaskNone = IcingSectionIdMask{0}; - -constexpr int kIcingCorpusIdBits = 15; // 32K -constexpr IcingCorpusId kIcingInvalidCorpusId = (1u << kIcingCorpusIdBits) - 1; -constexpr IcingCorpusId kIcingMaxCorpusId = kIcingInvalidCorpusId - 1; - -constexpr size_t kIcingMaxSearchableDocumentSize = (1u << 16) - 1; // 64K -// Max num tokens per document. 64KB is our original maximum (searchable) -// document size. We clip if document exceeds this. -constexpr uint32_t kIcingMaxNumTokensPerDoc = - kIcingMaxSearchableDocumentSize / 5; -constexpr uint32_t kIcingMaxNumHitsPerDocument = - kIcingMaxNumTokensPerDoc * kIcingMaxVariantsPerToken; - -constexpr IcingTagsCount kIcingInvalidTagCount = ~IcingTagsCount{0}; -constexpr IcingTagsCount kIcingMaxTagCount = kIcingInvalidTagCount - 1; - -// Location refers to document storage. -constexpr uint64_t kIcingInvalidLocation = ~uint64_t{0}; -constexpr uint64_t kIcingMaxDocStoreWriteLocation = uint64_t{1} - << 32; // 4bytes. - -// Dump symbols in the proto namespace. 
-using namespace ::com::google::android::gms::icing; // NOLINT(build/namespaces) -} // namespace lib -} // namespace icing - -#endif // ICING_LEGACY_INDEX_ICING_COMMON_TYPES_H_ diff --git a/icing/legacy/index/icing-dynamic-trie.cc b/icing/legacy/index/icing-dynamic-trie.cc index 4428599..c6816ad 100644 --- a/icing/legacy/index/icing-dynamic-trie.cc +++ b/icing/legacy/index/icing-dynamic-trie.cc @@ -460,8 +460,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Init() { if (i == 0) { // Header. if (file_size != IcingMMapper::system_page_size()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie hdr wrong size: %" PRIu64, file_size); + ICING_LOG(ERROR) << "Trie hdr wrong size: " << file_size; goto failed; } @@ -522,8 +521,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Init() { sizeof(char), hdr_.hdr.suffixes_size(), hdr_.hdr.max_suffixes_size(), &crcs_->array_crcs[SUFFIX], init_crcs)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie mmap suffix failed"); + ICING_LOG(ERROR) << "Trie mmap suffix failed"; goto failed; } @@ -671,8 +669,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Sync() { } if (!WriteHeader()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing trie header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Flushing trie header failed: " << strerror(errno); success = false; } @@ -686,8 +683,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Sync() { } if (total_flushed > 0) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Flushing %u pages of trie", - total_flushed); + ICING_VLOG(1) << "Flushing " << total_flushed << " pages of trie"; } return success; @@ -817,8 +813,7 @@ uint32_t IcingDynamicTrie::IcingDynamicTrieStorage::UpdateCrc() { uint32_t IcingDynamicTrie::IcingDynamicTrieStorage::UpdateCrcInternal( bool write_hdr) { if (write_hdr && !WriteHeader()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flushing trie header failed: %s", strerror(errno)); + ICING_LOG(ERROR) << "Flushing trie header failed: " << strerror(errno); } crcs_->header_crc = GetHeaderCrc(); @@ -912,8 +907,7 @@ bool IcingDynamicTrie::IcingDynamicTrieStorage::Header::SerializeToArray( bool IcingDynamicTrie::IcingDynamicTrieStorage::Header::Verify() { // Check version. 
if (hdr.version() != kCurVersion) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Trie version %u mismatch", hdr.version()); + ICING_LOG(ERROR) << "Trie version " << hdr.version() << " mismatch"; return false; } @@ -1155,9 +1149,8 @@ bool IcingDynamicTrie::Sync() { Warm(); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Syncing dynamic trie %s took %.3fms", filename_base_.c_str(), - timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Syncing dynamic trie " << filename_base_.c_str() + << " took " << timer.Elapsed() * 1000 << "ms"; return success; } @@ -1207,8 +1200,7 @@ std::unique_ptr<IcingFlashBitmap> IcingDynamicTrie::OpenAndInitBitmap( const IcingFilesystem *filesystem) { auto bitmap = std::make_unique<IcingFlashBitmap>(filename, filesystem); if (!bitmap->Init() || (verify && !bitmap->Verify())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Init of %s failed", - filename.c_str()); + ICING_LOG(ERROR) << "Init of " << filename.c_str() << " failed"; return nullptr; } return bitmap; @@ -1238,16 +1230,14 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { vector<std::string> files; if (!filesystem_->GetMatchingFiles((property_bitmaps_prefix_ + "*").c_str(), &files)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not get files at prefix %s", property_bitmaps_prefix_.c_str()); + ICING_LOG(ERROR) << "Could not get files at prefix " << property_bitmaps_prefix_; goto failed; } for (size_t i = 0; i < files.size(); i++) { // Decode property id from filename. size_t property_id_start_idx = files[i].rfind('.'); if (property_id_start_idx == std::string::npos) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Malformed filename %s", - files[i].c_str()); + ICING_LOG(ERROR) << "Malformed filename " << files[i]; continue; } property_id_start_idx++; // skip dot @@ -1255,8 +1245,7 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { uint32_t property_id = strtol(files[i].c_str() + property_id_start_idx, &end, 10); // NOLINT if (!end || end != (files[i].c_str() + files[i].size())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Malformed filename %s", - files[i].c_str()); + ICING_LOG(ERROR) << "Malformed filename " << files[i]; continue; } std::unique_ptr<IcingFlashBitmap> bitmap = OpenAndInitBitmap( @@ -1264,8 +1253,7 @@ bool IcingDynamicTrie::InitPropertyBitmaps() { runtime_options_.storage_policy == RuntimeOptions::kMapSharedWithCrc, filesystem_); if (!bitmap) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Open prop bitmap failed: %s", files[i].c_str()); + ICING_LOG(ERROR) << "Open prop bitmap failed: " << files[i]; goto failed; } bitmap->Truncate(truncate_idx); @@ -1795,11 +1783,12 @@ bool IcingDynamicTrie::Find(const char *key, void *value, } IcingDynamicTrie::Iterator::Iterator(const IcingDynamicTrie &trie, - const char *prefix) + const char *prefix, bool reverse) : cur_key_(prefix), cur_suffix_(nullptr), cur_suffix_len_(0), single_leaf_match_(false), + reverse_(reverse), trie_(trie) { if (!trie.is_initialized()) { ICING_LOG(FATAL) << "DynamicTrie not initialized"; @@ -1808,19 +1797,29 @@ IcingDynamicTrie::Iterator::Iterator(const IcingDynamicTrie &trie, Reset(); } -void IcingDynamicTrie::Iterator::LeftBranchToLeaf(uint32_t node_index) { +void IcingDynamicTrie::Iterator::BranchToLeaf(uint32_t node_index, + BranchType branch_type) { // Go down the trie, following the left-most child until we hit a // leaf. Push to stack and cur_key nodes and chars as we go. 
- for (; !trie_.storage_->GetNode(node_index)->is_leaf(); - node_index = - trie_.storage_ - ->GetNext(trie_.storage_->GetNode(node_index)->next_index(), 0) - ->node_index()) { - branch_stack_.push_back(Branch(node_index)); - cur_key_.push_back( - trie_.storage_ - ->GetNext(trie_.storage_->GetNode(node_index)->next_index(), 0) - ->val()); + // When reverse_ is true, the method will follow the right-most child. + const Node *node = trie_.storage_->GetNode(node_index); + while (!node->is_leaf()) { + const Next *next_start = trie_.storage_->GetNext(node->next_index(), 0); + int child_idx; + if (branch_type == BranchType::kRightMost) { + uint32_t next_array_size = 1u << node->log2_num_children(); + child_idx = trie_.GetValidNextsSize(next_start, next_array_size) - 1; + } else { + // node isn't a leaf. So it must have >0 children. + // 0 is the left-most child. + child_idx = 0; + } + const Next &child_next = next_start[child_idx]; + branch_stack_.push_back(Branch(node_index, child_idx)); + cur_key_.push_back(child_next.val()); + + node_index = child_next.node_index(); + node = trie_.storage_->GetNode(node_index); } // We're at a leaf. @@ -1856,7 +1855,7 @@ void IcingDynamicTrie::Iterator::Reset() { // Two cases/states: // // - Found an intermediate node. If we matched all of prefix - // (cur_key_), LeftBranchToLeaf. + // (cur_key_), BranchToLeaf. // // - Found a leaf node, which is the ONLY matching key for this // prefix. Check that suffix matches the prefix. Then we set @@ -1879,7 +1878,9 @@ void IcingDynamicTrie::Iterator::Reset() { cur_suffix_len_ = strlen(cur_suffix_); single_leaf_match_ = true; } else if (static_cast<size_t>(key_offset) == cur_key_.size()) { - LeftBranchToLeaf(node_index); + BranchType branch_type = + (reverse_) ? BranchType::kRightMost : BranchType::kLeftMost; + BranchToLeaf(node_index, branch_type); } } @@ -1906,19 +1907,25 @@ bool IcingDynamicTrie::Iterator::Advance() { while (!branch_stack_.empty()) { Branch *branch = &branch_stack_.back(); const Node *node = trie_.storage_->GetNode(branch->node_idx); - branch->child_idx++; - if (branch->child_idx < (1 << node->log2_num_children()) && - trie_.storage_->GetNext(node->next_index(), branch->child_idx) - ->node_index() != kInvalidNodeIndex) { - // Successfully incremented to the next child. Update the char - // value at this depth. - cur_key_[cur_key_.size() - 1] = - trie_.storage_->GetNext(node->next_index(), branch->child_idx)->val(); - // We successfully found a sub-trie to explore. - LeftBranchToLeaf( - trie_.storage_->GetNext(node->next_index(), branch->child_idx) - ->node_index()); - return true; + if (reverse_) { + branch->child_idx--; + } else { + branch->child_idx++; + } + if (branch->child_idx >= 0 && + branch->child_idx < (1 << node->log2_num_children())) { + const Next *child_next = + trie_.storage_->GetNext(node->next_index(), branch->child_idx); + if (child_next->node_index() != kInvalidNodeIndex) { + // Successfully incremented to the next child. Update the char + // value at this depth. + cur_key_[cur_key_.size() - 1] = child_next->val(); + // We successfully found a sub-trie to explore. + BranchType branch_type = + (reverse_) ? 
BranchType::kRightMost : BranchType::kLeftMost; + BranchToLeaf(child_next->node_index(), branch_type); + return true; + } } branch_stack_.pop_back(); cur_key_.resize(cur_key_.size() - 1); @@ -2108,7 +2115,8 @@ const IcingDynamicTrie::Next *IcingDynamicTrie::GetNextByChar( } int IcingDynamicTrie::GetValidNextsSize( - IcingDynamicTrie::Next *next_array_start, int next_array_length) const { + const IcingDynamicTrie::Next *next_array_start, + int next_array_length) const { // Only searching for key char 0xff is not sufficient, as 0xff can be a valid // character. We must also specify kInvalidNodeIndex as the target node index // when searching the next array. @@ -2295,15 +2303,16 @@ bool IcingDynamicTrie::IsBranchingTerm(const char *key) const { return false; } - // key is not present in the trie. + // There is no intermediate node for key in the trie. if (key[key_offset] != '\0') { return false; } // Found key as an intermediate node, but key is not a valid term stored in - // the trie. + // the trie. In this case, we need at least two children for key to be a + // branching term. if (GetNextByChar(cur_node, '\0') == nullptr) { - return false; + return cur_node->log2_num_children() >= 1; } // The intermediate node for key must have more than two children for key to @@ -2320,8 +2329,7 @@ void IcingDynamicTrie::GetDebugInfo(int verbosity, std::string *out) const { vector<std::string> files; if (!filesystem_->GetMatchingFiles((property_bitmaps_prefix_ + "*").c_str(), &files)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not get files at prefix %s", property_bitmaps_prefix_.c_str()); + ICING_LOG(ERROR) << "Could not get files at prefix " << property_bitmaps_prefix_; return; } for (size_t i = 0; i < files.size(); i++) { @@ -2393,8 +2401,7 @@ IcingFlashBitmap *IcingDynamicTrie::OpenOrCreatePropertyBitmap( } if (property_id > kMaxPropertyId) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Property id %u out of range", property_id); + ICING_LOG(ERROR) << "Property id " << property_id << " out of range"; return nullptr; } @@ -2567,8 +2574,7 @@ bool IcingDynamicTrie::ClearPropertyForAllValues(uint32_t property_id) { PropertyReadersAll readers(*this); if (!readers.Exists(property_id)) { - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Properties for id %u don't exist", property_id); + ICING_VLOG(1) << "Properties for id " << property_id << " don't exist"; return true; } diff --git a/icing/legacy/index/icing-dynamic-trie.h b/icing/legacy/index/icing-dynamic-trie.h index ec8b31a..b172632 100644 --- a/icing/legacy/index/icing-dynamic-trie.h +++ b/icing/legacy/index/icing-dynamic-trie.h @@ -405,9 +405,6 @@ class IcingDynamicTrie : public IIcingStorage { // key is a branching term, if and only if there exists terms s1 and s2 in the // trie such that key is the maximum common prefix of s1 and s2, but s1 and s2 // are not prefixes of each other. - // - // The function assumes that key is already present in the trie. Otherwise, - // false will be returned. bool IsBranchingTerm(const char *key) const; void GetDebugInfo(int verbosity, std::string *out) const override; @@ -520,7 +517,8 @@ class IcingDynamicTrie : public IIcingStorage { // Change in underlying trie invalidates iterator. class Iterator { public: - Iterator(const IcingDynamicTrie &trie, const char *prefix); + Iterator(const IcingDynamicTrie &trie, const char *prefix, + bool reverse = false); void Reset(); bool Advance(); @@ -537,9 +535,10 @@ class IcingDynamicTrie : public IIcingStorage { Iterator(); // Copy is ok. 
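To make the intent of the new reverse flag concrete, here is a minimal usage sketch (editorial illustration only, not introduced by this commit; it relies solely on the public Iterator API declared above and the sample keys used in the tests below):

  // Walk every term stored under the prefix "ab" in descending
  // lexicographic order.
  IcingDynamicTrie::Iterator it(trie, /*prefix=*/"ab", /*reverse=*/true);
  for (; it.IsValid(); it.Advance()) {
    uint32_t value;
    memcpy(&value, it.GetValue(), sizeof(value));
    // With the test data in icing-dynamic-trie_test.cc this visits
    // "abd", "abcdefg", "abbb", "ab".
  }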
- // Helper function that takes the left-most branch down - // intermediate nodes to a leaf. - void LeftBranchToLeaf(uint32_t node_index); + enum class BranchType { kLeftMost = 0, kRightMost = 1 }; + // Helper function that takes the left-most or the right-most branch down + // intermediate nodes to a leaf, based on branch_type. + void BranchToLeaf(uint32_t node_index, BranchType branch_type); std::string cur_key_; const char *cur_suffix_; @@ -548,10 +547,12 @@ class IcingDynamicTrie : public IIcingStorage { uint32_t node_idx; int child_idx; - explicit Branch(uint32_t ni) : node_idx(ni), child_idx(0) {} + explicit Branch(uint32_t node_index, int child_index) + : node_idx(node_index), child_idx(child_index) {} }; std::vector<Branch> branch_stack_; bool single_leaf_match_; + bool reverse_; const IcingDynamicTrie &trie_; }; @@ -625,7 +626,7 @@ class IcingDynamicTrie : public IIcingStorage { const Next *LowerBound(const Next *start, const Next *end, uint8_t key_char, uint32_t node_index = 0) const; // Returns the number of valid nexts in the array. - int GetValidNextsSize(IcingDynamicTrie::Next *next_array_start, + int GetValidNextsSize(const IcingDynamicTrie::Next *next_array_start, int next_array_length) const; void FindBestNode(const char *key, uint32_t *best_node_index, int *key_offset, bool prefix, bool utf8 = false) const; diff --git a/icing/legacy/index/icing-dynamic-trie_test.cc b/icing/legacy/index/icing-dynamic-trie_test.cc index b69ee64..850fcdc 100644 --- a/icing/legacy/index/icing-dynamic-trie_test.cc +++ b/icing/legacy/index/icing-dynamic-trie_test.cc @@ -39,6 +39,7 @@ namespace { using testing::ContainerEq; using testing::ElementsAre; +using testing::StrEq; constexpr std::string_view kKeys[] = { "", "ab", "ac", "abd", "bac", "bb", "bacd", "abbb", "abcdefg", @@ -109,6 +110,17 @@ class IcingDynamicTrieTest : public ::testing::Test { std::string trie_files_prefix_; }; +std::vector<std::pair<std::string, int>> RetrieveKeyValuePairs( + IcingDynamicTrie::Iterator& term_iter) { + std::vector<std::pair<std::string, int>> key_value; + for (; term_iter.IsValid(); term_iter.Advance()) { + uint32_t val; + memcpy(&val, term_iter.GetValue(), sizeof(val)); + key_value.push_back(std::make_pair(term_iter.GetKey(), val)); + } + return key_value; +} + constexpr std::string_view kCommonEnglishWords[] = { "that", "was", "for", "on", "are", "with", "they", "be", "at", "one", "have", "this", "from", "word", "but", "what", "some", "you", @@ -161,7 +173,6 @@ TEST_F(IcingDynamicTrieTest, Init) { TEST_F(IcingDynamicTrieTest, Iterator) { // Test iterator. IcingFilesystem filesystem; - uint32_t val; IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), &filesystem); ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); @@ -171,104 +182,161 @@ TEST_F(IcingDynamicTrieTest, Iterator) { ASSERT_TRUE(trie.Insert(kKeys[i].data(), &i)); } - // We try everything twice to test that Reset also works. - // Should get the entire trie. 
+ std::vector<std::pair<std::string, int>> exp_key_values = { + {"", 0}, {"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}, + {"ac", 2}, {"bac", 4}, {"bacd", 6}, {"bb", 5}}; IcingDynamicTrie::Iterator it_all(trie, ""); - for (int i = 0; i < 2; i++) { - uint32_t count = 0; - for (; it_all.IsValid(); it_all.Advance()) { - uint32_t val_idx = it_all.GetValueIndex(); - EXPECT_EQ(it_all.GetValue(), trie.GetValueAtIndex(val_idx)); - count++; - } - EXPECT_EQ(count, kNumKeys); - it_all.Reset(); - } + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it_all.Reset(); + key_values = RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); // Get everything under "a". + exp_key_values = { + {"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}, {"ac", 2}}; IcingDynamicTrie::Iterator it1(trie, "a"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "ab"); - static const uint32_t kOne = 1; - ASSERT_TRUE(it1.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it1.GetValue(), &kOne, sizeof(kOne))); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abbb"); + // Should get same results after calling Reset + it1.Reset(); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abcdefg"); + // Now "b". + exp_key_values = {{"bac", 4}, {"bacd", 6}, {"bb", 5}}; + IcingDynamicTrie::Iterator it2(trie, "b"); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "abd"); + // Should get same results after calling Reset + it2.Reset(); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - ASSERT_TRUE(it1.Advance()); - ASSERT_TRUE(it1.IsValid()); - EXPECT_STREQ(it1.GetKey(), "ac"); + // Get everything under "ab". + exp_key_values = {{"ab", 1}, {"abbb", 7}, {"abcdefg", 8}, {"abd", 3}}; + IcingDynamicTrie::Iterator it3(trie, "ab"); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - EXPECT_FALSE(it1.Advance()); - EXPECT_FALSE(it1.IsValid()); + // Should get same results after calling Reset + it3.Reset(); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - it1.Reset(); + // Should match only one key exactly. 
+ constexpr std::string_view kOneMatch[] = { + "abd", + "abcd", + "abcdef", + "abcdefg", + }; + // With the following match: + constexpr std::string_view kOneMatchMatched[] = { + "abd", + "abcdefg", + "abcdefg", + "abcdefg", + }; + + for (size_t k = 0; k < ABSL_ARRAYSIZE(kOneMatch); k++) { + IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data()); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; + + // Should get same results after calling Reset + it_single.Reset(); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; } - // Now "b". - IcingDynamicTrie::Iterator it2(trie, "b"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bac"); - val = 1; - ASSERT_TRUE(it1.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it1.GetValue(), &val, sizeof(val))); - val = 4; - ASSERT_TRUE(it2.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it2.GetValue(), &val, sizeof(val))); - - ASSERT_TRUE(it2.Advance()); - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bacd"); - - ASSERT_TRUE(it2.Advance()); - ASSERT_TRUE(it2.IsValid()); - EXPECT_STREQ(it2.GetKey(), "bb"); - - EXPECT_FALSE(it2.Advance()); - EXPECT_FALSE(it2.IsValid()); - - it2.Reset(); + // Matches nothing. + constexpr std::string_view kNoMatch[] = { + "abbd", + "abcdeg", + "abcdefh", + }; + for (size_t k = 0; k < ABSL_ARRAYSIZE(kNoMatch); k++) { + IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data()); + EXPECT_FALSE(it_empty.IsValid()); + it_empty.Reset(); + EXPECT_FALSE(it_empty.IsValid()); } - // Get everything under "ab". - IcingDynamicTrie::Iterator it3(trie, "ab"); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "ab"); - val = 1; - ASSERT_TRUE(it3.GetValue() != nullptr); - EXPECT_TRUE(!memcmp(it3.GetValue(), &val, sizeof(val))); + // Clear. + trie.Clear(); + EXPECT_FALSE(IcingDynamicTrie::Iterator(trie, "").IsValid()); + EXPECT_EQ(0u, trie.size()); + EXPECT_EQ(1.0, trie.min_free_fraction()); +} - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abbb"); +TEST_F(IcingDynamicTrieTest, IteratorReverse) { + // Test iterator. + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abcdefg"); + for (uint32_t i = 0; i < kNumKeys; i++) { + ASSERT_TRUE(trie.Insert(kKeys[i].data(), &i)); + } - ASSERT_TRUE(it3.Advance()); - ASSERT_TRUE(it3.IsValid()); - EXPECT_STREQ(it3.GetKey(), "abd"); + // Should get the entire trie. 
+ std::vector<std::pair<std::string, int>> exp_key_values = { + {"bb", 5}, {"bacd", 6}, {"bac", 4}, {"ac", 2}, {"abd", 3}, + {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}, {"", 0}}; + IcingDynamicTrie::Iterator it_all(trie, "", /*reverse=*/true); + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + it_all.Reset(); + key_values = RetrieveKeyValuePairs(it_all); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - EXPECT_FALSE(it3.Advance()); - EXPECT_FALSE(it3.IsValid()); + // Get everything under "a". + exp_key_values = { + {"ac", 2}, {"abd", 3}, {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}}; + IcingDynamicTrie::Iterator it1(trie, "a", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); - it3.Reset(); - } + // Should get same results after calling Reset + it1.Reset(); + key_values = RetrieveKeyValuePairs(it1); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Now "b". + exp_key_values = {{"bb", 5}, {"bacd", 6}, {"bac", 4}}; + IcingDynamicTrie::Iterator it2(trie, "b", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it2.Reset(); + key_values = RetrieveKeyValuePairs(it2); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Get everything under "ab". + exp_key_values = {{"abd", 3}, {"abcdefg", 8}, {"abbb", 7}, {"ab", 1}}; + IcingDynamicTrie::Iterator it3(trie, "ab", /*reverse=*/true); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Should get same results after calling Reset + it3.Reset(); + key_values = RetrieveKeyValuePairs(it3); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); // Should match only one key exactly. constexpr std::string_view kOneMatch[] = { @@ -286,15 +354,19 @@ TEST_F(IcingDynamicTrieTest, Iterator) { }; for (size_t k = 0; k < ABSL_ARRAYSIZE(kOneMatch); k++) { - IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data()); - for (int i = 0; i < 2; i++) { - ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; - EXPECT_STREQ(it_single.GetKey(), kOneMatchMatched[k].data()); - EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; - EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; - - it_single.Reset(); - } + IcingDynamicTrie::Iterator it_single(trie, kOneMatch[k].data(), + /*reverse=*/true); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; + + // Should get same results after calling Reset + it_single.Reset(); + ASSERT_TRUE(it_single.IsValid()) << kOneMatch[k]; + EXPECT_THAT(it_single.GetKey(), StrEq(kOneMatchMatched[k].data())); + EXPECT_FALSE(it_single.Advance()) << kOneMatch[k]; + EXPECT_FALSE(it_single.IsValid()) << kOneMatch[k]; } // Matches nothing. @@ -304,21 +376,65 @@ TEST_F(IcingDynamicTrieTest, Iterator) { "abcdefh", }; for (size_t k = 0; k < ABSL_ARRAYSIZE(kNoMatch); k++) { - IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data()); - for (int i = 0; i < 2; i++) { - EXPECT_FALSE(it_empty.IsValid()); - - it_empty.Reset(); - } + IcingDynamicTrie::Iterator it_empty(trie, kNoMatch[k].data(), + /*reverse=*/true); + EXPECT_FALSE(it_empty.IsValid()); + it_empty.Reset(); + EXPECT_FALSE(it_empty.IsValid()); } // Clear. 
trie.Clear(); - EXPECT_FALSE(IcingDynamicTrie::Iterator(trie, "").IsValid()); + EXPECT_FALSE( + IcingDynamicTrie::Iterator(trie, "", /*reverse=*/true).IsValid()); EXPECT_EQ(0u, trie.size()); EXPECT_EQ(1.0, trie.min_free_fraction()); } +TEST_F(IcingDynamicTrieTest, IteratorLoadTest) { + IcingFilesystem filesystem; + IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), + &filesystem); + ASSERT_TRUE(trie.CreateIfNotExist(IcingDynamicTrie::Options())); + ASSERT_TRUE(trie.Init()); + + std::default_random_engine random; + ICING_LOG(ERROR) << "Seed: " << std::default_random_engine::default_seed; + + std::vector<std::pair<std::string, int>> exp_key_values; + // Randomly generate 1024 terms. + for (int i = 0; i < 1024; ++i) { + std::string term = RandomString("abcdefg", 5, &random) + std::to_string(i); + ASSERT_TRUE(trie.Insert(term.c_str(), &i)); + exp_key_values.push_back(std::make_pair(term, i)); + } + // Lexicographically sort the expected keys. + std::sort(exp_key_values.begin(), exp_key_values.end()); + + // Check that the iterator works. + IcingDynamicTrie::Iterator term_iter(trie, /*prefix=*/""); + std::vector<std::pair<std::string, int>> key_values = + RetrieveKeyValuePairs(term_iter); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Check that Reset works. + term_iter.Reset(); + key_values = RetrieveKeyValuePairs(term_iter); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + std::reverse(exp_key_values.begin(), exp_key_values.end()); + // Check that the reverse iterator works. + IcingDynamicTrie::Iterator term_iter_reverse(trie, /*prefix=*/"", + /*reverse=*/true); + key_values = RetrieveKeyValuePairs(term_iter_reverse); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); + + // Check that Reset works. + term_iter_reverse.Reset(); + key_values = RetrieveKeyValuePairs(term_iter_reverse); + EXPECT_THAT(key_values, ContainerEq(exp_key_values)); +} + TEST_F(IcingDynamicTrieTest, Persistence) { // Test persistence on the English dictionary. IcingFilesystem filesystem; @@ -1233,7 +1349,7 @@ TEST_F(IcingDynamicTrieTest, BitmapsClosedWhenInitFails) { ASSERT_EQ(0, trie.property_bitmaps_.size()); } -TEST_F(IcingDynamicTrieTest, IsBranchingTerm) { +TEST_F(IcingDynamicTrieTest, IsBranchingTermShouldWorkForExistingTerms) { IcingFilesystem filesystem; IcingDynamicTrie trie(trie_files_prefix_, IcingDynamicTrie::RuntimeOptions(), &filesystem); @@ -1319,34 +1435,52 @@ TEST_F(IcingDynamicTrieTest, IsBranchingTermShouldWorkForNonExistingTerms) { EXPECT_FALSE(trie.IsBranchingTerm("")); EXPECT_FALSE(trie.IsBranchingTerm("a")); EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); ASSERT_TRUE(trie.Insert("aa", &value)); EXPECT_FALSE(trie.IsBranchingTerm("")); EXPECT_FALSE(trie.IsBranchingTerm("a")); - - ASSERT_TRUE(trie.Insert("", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("a")); - - ASSERT_TRUE(trie.Insert("ab", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); ASSERT_TRUE(trie.Insert("ac", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("")); + // "a" does not exist in the trie, but now it branches to "aa" and "ac". 
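// (Editorial recap, not part of this change: with "aa" and "ac" present, "a"
// is the maximum common prefix of two terms that are not prefixes of each
// other, which is exactly the branching-term definition in
// icing-dynamic-trie.h, so the checks below now return true even though "a"
// itself was never inserted.)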
+ EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); ASSERT_TRUE(trie.Insert("ad", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); + EXPECT_FALSE(trie.IsBranchingTerm("abc")); ASSERT_TRUE(trie.Insert("abcd", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_FALSE(trie.IsBranchingTerm("ab")); EXPECT_FALSE(trie.IsBranchingTerm("abc")); - ASSERT_TRUE(trie.Insert("abce", &value)); + ASSERT_TRUE(trie.Insert("abd", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + // "ab" does not exist in the trie, but now it branches to "abcd" and "abd". + EXPECT_TRUE(trie.IsBranchingTerm("ab")); EXPECT_FALSE(trie.IsBranchingTerm("abc")); - ASSERT_TRUE(trie.Insert("abcf", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("abc")); + ASSERT_TRUE(trie.Insert("abce", &value)); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + // "abc" does not exist in the trie, but now it branches to "abcd" and "abce". + EXPECT_TRUE(trie.IsBranchingTerm("abc")); ASSERT_TRUE(trie.Insert("abc_suffix", &value)); - EXPECT_FALSE(trie.IsBranchingTerm("abc")); + EXPECT_FALSE(trie.IsBranchingTerm("")); + EXPECT_TRUE(trie.IsBranchingTerm("a")); + EXPECT_TRUE(trie.IsBranchingTerm("ab")); + EXPECT_TRUE(trie.IsBranchingTerm("abc")); EXPECT_FALSE(trie.IsBranchingTerm("abc_s")); EXPECT_FALSE(trie.IsBranchingTerm("abc_su")); EXPECT_FALSE(trie.IsBranchingTerm("abc_suffi")); diff --git a/icing/legacy/index/icing-filesystem.cc b/icing/legacy/index/icing-filesystem.cc index 4f5e571..fbf5a27 100644 --- a/icing/legacy/index/icing-filesystem.cc +++ b/icing/legacy/index/icing-filesystem.cc @@ -65,18 +65,15 @@ void LogOpenFileDescriptors() { constexpr int kMaxFileDescriptorsToStat = 4096; struct rlimit rlim = {0, 0}; if (getrlimit(RLIMIT_NOFILE, &rlim) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "getrlimit() failed (errno=%d)", errno); + ICING_LOG(ERROR) << "getrlimit() failed (errno=" << errno << ")"; return; } int fd_lim = rlim.rlim_cur; if (fd_lim > kMaxFileDescriptorsToStat) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Maximum number of file descriptors (%d) too large.", fd_lim); + ICING_LOG(ERROR) << "Maximum number of file descriptors (" << fd_lim << ") too large."; fd_lim = kMaxFileDescriptorsToStat; } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Listing up to %d file descriptors.", fd_lim); + ICING_LOG(ERROR) << "Listing up to " << fd_lim << " file descriptors."; // Verify that /proc/self/fd is a directory. If not, procfs is not mounted or // inaccessible for some other reason. In that case, there's no point trying @@ -98,15 +95,12 @@ void LogOpenFileDescriptors() { if (len >= 0) { // Zero-terminate the buffer, because readlink() won't. target[len < target_size ? len : target_size - 1] = '\0'; - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> \"%s\"", fd, - target); + ICING_LOG(ERROR) << "fd " << fd << " -> \"" << target << "\""; } else if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("fd %d -> ? (errno=%d)", - fd, errno); + ICING_LOG(ERROR) << "fd " << fd << " -> ? 
(errno=" << errno << ")"; } } - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "File descriptor list complete."); + ICING_LOG(ERROR) << "File descriptor list complete."; } // Logs an error formatted as: desc1 + file_name + desc2 + strerror(errnum). @@ -115,8 +109,7 @@ void LogOpenFileDescriptors() { // file descriptors (see LogOpenFileDescriptors() above). void LogOpenError(const char *desc1, const char *file_name, const char *desc2, int errnum) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "%s%s%s%s", desc1, file_name, desc2, strerror(errnum)); + ICING_LOG(ERROR) << desc1 << file_name << desc2 << strerror(errnum); if (errnum == EMFILE) { LogOpenFileDescriptors(); } @@ -157,8 +150,7 @@ bool ListDirectoryInternal(const char *dir_name, } } if (closedir(dir) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Error closing %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Error closing " << dir_name << ": " << strerror(errno); } return true; } @@ -181,12 +173,11 @@ void IcingScopedFd::reset(int fd) { const uint64_t IcingFilesystem::kBadFileSize; bool IcingFilesystem::DeleteFile(const char *file_name) const { - ICING_VLOG(1) << IcingStringUtil::StringPrintf("Deleting file %s", file_name); + ICING_VLOG(1) << "Deleting file " << file_name; int ret = unlink(file_name); bool success = (ret == 0) || (errno == ENOENT); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting file %s failed: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting file " << file_name << " failed: " << strerror(errno); } return success; } @@ -195,8 +186,7 @@ bool IcingFilesystem::DeleteDirectory(const char *dir_name) const { int ret = rmdir(dir_name); bool success = (ret == 0) || (errno == ENOENT); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Deleting directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Deleting directory " << dir_name << " failed: " << strerror(errno); } return success; } @@ -208,8 +198,7 @@ bool IcingFilesystem::DeleteDirectoryRecursively(const char *dir_name) const { if (errno == ENOENT) { return true; // If directory didn't exist, this was successful. 
} - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Stat " << dir_name << " failed: " << strerror(errno); return false; } vector<std::string> entries; @@ -222,8 +211,7 @@ bool IcingFilesystem::DeleteDirectoryRecursively(const char *dir_name) const { ++i) { std::string filename = std::string(dir_name) + '/' + *i; if (stat(filename.c_str(), &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Stat %s failed: %s", filename.c_str(), strerror(errno)); + ICING_LOG(ERROR) << "Stat " << filename << " failed: " << strerror(errno); success = false; } else if (S_ISDIR(st.st_mode)) { success = DeleteDirectoryRecursively(filename.c_str()) && success; @@ -246,8 +234,7 @@ bool IcingFilesystem::FileExists(const char *file_name) const { exists = S_ISREG(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", file_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << file_name << ": " << strerror(errno); } exists = false; } @@ -261,8 +248,7 @@ bool IcingFilesystem::DirectoryExists(const char *dir_name) const { exists = S_ISDIR(st.st_mode) != 0; } else { if (errno != ENOENT) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat directory %s: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat directory " << dir_name << ": " << strerror(errno); } exists = false; } @@ -317,8 +303,7 @@ bool IcingFilesystem::GetMatchingFiles(const char *glob, int basename_idx = GetBasenameIndex(glob); if (basename_idx == 0) { // We need a directory. - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Expected directory, no matching files for: %s", glob); + ICING_VLOG(1) << "Expected directory, no matching files for: " << glob; return true; } const char *basename_glob = glob + basename_idx; @@ -374,8 +359,7 @@ uint64_t IcingFilesystem::GetFileSize(int fd) const { struct stat st; uint64_t size = kBadFileSize; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); } else { size = st.st_size; } @@ -386,8 +370,7 @@ uint64_t IcingFilesystem::GetFileSize(const char *filename) const { struct stat st; uint64_t size = kBadFileSize; if (stat(filename, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to stat file %s: %s", filename, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file " << filename << ": " << strerror(errno); } else { size = st.st_size; } @@ -399,8 +382,7 @@ bool IcingFilesystem::Truncate(int fd, uint64_t new_size) const { if (ret == 0) { lseek(fd, new_size, SEEK_SET); } else { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to truncate file: %s", strerror(errno)); + ICING_LOG(ERROR) << "Unable to truncate file: " << strerror(errno); } return (ret == 0); } @@ -418,8 +400,7 @@ bool IcingFilesystem::Truncate(const char *filename, uint64_t new_size) const { bool IcingFilesystem::Grow(int fd, uint64_t new_size) const { int ret = ftruncate(fd, new_size); if (ret != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to grow file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to grow file: " << strerror(errno); } return (ret == 0); } @@ -431,8 +412,7 @@ bool IcingFilesystem::Write(int fd, const void *data, size_t data_size) const { size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t 
wrote = write(fd, data, chunk_size); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t *>(data) + wrote; @@ -449,8 +429,7 @@ bool IcingFilesystem::PWrite(int fd, off_t offset, const void *data, size_t chunk_size = std::min<size_t>(write_len, 64u * 1024); ssize_t wrote = pwrite(fd, data, chunk_size, offset); if (wrote < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Bad write: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Bad write: " << strerror(errno); return false; } data = static_cast<const uint8_t *>(data) + wrote; @@ -468,8 +447,7 @@ bool IcingFilesystem::DataSync(int fd) const { #endif if (result < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to sync data: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to sync data: " << strerror(errno); return false; } return true; @@ -478,9 +456,7 @@ bool IcingFilesystem::DataSync(int fd) const { bool IcingFilesystem::RenameFile(const char *old_name, const char *new_name) const { if (rename(old_name, new_name) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Unable to rename file %s to %s: %s", old_name, new_name, - strerror(errno)); + ICING_LOG(ERROR) << "Unable to rename file " << old_name << " to " << new_name << ": " << strerror(errno); return false; } return true; @@ -518,8 +494,7 @@ bool IcingFilesystem::CreateDirectory(const char *dir_name) const { if (mkdir(dir_name, S_IRUSR | S_IWUSR | S_IXUSR) == 0) { success = true; } else { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Creating directory %s failed: %s", dir_name, strerror(errno)); + ICING_LOG(ERROR) << "Creating directory " << dir_name << " failed: " << strerror(errno); } } return success; @@ -561,8 +536,7 @@ end: if (src_fd > 0) close(src_fd); if (dst_fd > 0) close(dst_fd); if (!success) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Couldn't copy file %s to %s", src, dst); + ICING_LOG(ERROR) << "Couldn't copy file " << src << " to " << dst; } return success; } @@ -583,8 +557,7 @@ bool IcingFilesystem::ComputeChecksum(int fd, uint32_t *checksum, uint64_t IcingFilesystem::GetDiskUsage(int fd) const { struct stat st; if (fstat(fd, &st) < 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat file: %s", - strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat file: " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -593,8 +566,7 @@ uint64_t IcingFilesystem::GetDiskUsage(int fd) const { uint64_t IcingFilesystem::GetFileDiskUsage(const char *path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } return st.st_blocks * kStatBlockSize; @@ -603,8 +575,7 @@ uint64_t IcingFilesystem::GetFileDiskUsage(const char *path) const { uint64_t IcingFilesystem::GetDiskUsage(const char *path) const { struct stat st; if (stat(path, &st) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Unable to stat %s: %s", - path, strerror(errno)); + ICING_LOG(ERROR) << "Unable to stat " << path << ": " << strerror(errno); return kBadFileSize; } uint64_t result = st.st_blocks * kStatBlockSize; diff --git a/icing/legacy/index/icing-flash-bitmap.cc b/icing/legacy/index/icing-flash-bitmap.cc index 56dec00..774308f 100644 --- 
a/icing/legacy/index/icing-flash-bitmap.cc +++ b/icing/legacy/index/icing-flash-bitmap.cc @@ -73,8 +73,7 @@ class IcingFlashBitmap::Accessor { bool IcingFlashBitmap::Verify() const { if (!is_initialized()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Can't verify unopened flash bitmap %s", filename_.c_str()); + ICING_LOG(ERROR) << "Can't verify unopened flash bitmap " << filename_; return false; } if (mmapper_ == nullptr) { @@ -83,26 +82,21 @@ bool IcingFlashBitmap::Verify() const { } Accessor accessor(mmapper_.get()); if (accessor.header()->magic != kMagic) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect magic header", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect magic header"; return false; } if (accessor.header()->version != kCurVersion) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect version", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect version"; return false; } if (accessor.header()->dirty) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s is dirty", filename_.c_str()); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " is dirty"; return false; } uint32_t crc = IcingStringUtil::UpdateCrc32(0, accessor.data(), accessor.data_size()); if (accessor.header()->crc != crc) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Flash bitmap %s has incorrect CRC32 %u %u", filename_.c_str(), - accessor.header()->crc, crc); + ICING_LOG(ERROR) << "Flash bitmap " << filename_ << " has incorrect CRC32 " << accessor.header()->crc << " " << crc; return false; } return true; @@ -265,17 +259,14 @@ uint32_t IcingFlashBitmap::UpdateCrc() const { bool IcingFlashBitmap::Grow(size_t new_file_size) { IcingScopedFd fd(filesystem_->OpenForWrite(filename_.c_str())); if (!filesystem_->Grow(fd.get(), new_file_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Grow %s to new size %zu failed", filename_.c_str(), new_file_size); + ICING_LOG(ERROR) << "Grow " << filename_ << " to new size " << new_file_size << " failed"; return false; } if (!mmapper_->Remap(fd.get(), 0, new_file_size)) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Remap of %s after grow failed", filename_.c_str()); + ICING_LOG(ERROR) << "Remap of " << filename_ << " after grow failed"; return false; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Grew %s new size %zu", filename_.c_str(), new_file_size); + ICING_VLOG(1) << "Grew " << filename_ << " new size " << new_file_size; Accessor accessor(mmapper_.get()); accessor.header()->dirty = true; return true; diff --git a/icing/legacy/index/icing-mmapper.cc b/icing/legacy/index/icing-mmapper.cc index 7946c82..d086da2 100644 --- a/icing/legacy/index/icing-mmapper.cc +++ b/icing/legacy/index/icing-mmapper.cc @@ -67,8 +67,7 @@ void IcingMMapper::DoMapping(int fd, uint64_t location, size_t size) { address_ = reinterpret_cast<uint8_t *>(mmap_result_) + alignment_adjustment; } else { const char *errstr = strerror(errno); - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "Could not mmap file for reading: %s", errstr); + ICING_LOG(ERROR) << "Could not mmap file for reading: " << errstr; mmap_result_ = nullptr; } } @@ -95,8 +94,7 @@ IcingMMapper::~IcingMMapper() { Unmap(); } bool IcingMMapper::Sync() { if (is_valid() && !read_only_) { if (msync(mmap_result_, mmap_len_, MS_SYNC) != 0) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("msync failed: %s", - 
strerror(errno)); + ICING_LOG(ERROR) << "msync failed: " << strerror(errno); return false; } } diff --git a/icing/legacy/index/icing-storage-file.cc b/icing/legacy/index/icing-storage-file.cc index 35a4418..bbc6b81 100644 --- a/icing/legacy/index/icing-storage-file.cc +++ b/icing/legacy/index/icing-storage-file.cc @@ -69,22 +69,18 @@ bool IcingStorageFile::Sync() { IcingTimer timer; if (!PreSync()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Pre-sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Pre-sync " << filename_ << " failed"; return false; } if (!filesystem_->DataSync(fd_.get())) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Sync " << filename_ << " failed"; return false; } if (!PostSync()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf("Post-sync %s failed", - filename_.c_str()); + ICING_LOG(ERROR) << "Post-sync " << filename_ << " failed"; return false; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "Syncing %s took %.3fms", filename_.c_str(), timer.Elapsed() * 1000.); + ICING_VLOG(1) << "Syncing " << filename_ << " took " << timer.Elapsed() * 1000 << "ms"; return true; } diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index a725213..d1cce87 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -17,7 +17,6 @@ #include <memory> #include <string> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -27,6 +26,7 @@ #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" +#include "icing/jni/jni-cache.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/portable/platform.h" #include "icing/proto/schema.pb.h" diff --git a/icing/result/result-retriever-v2.cc b/icing/result/result-retriever-v2.cc index 195f641..92ab048 100644 --- a/icing/result/result-retriever-v2.cc +++ b/icing/result/result-retriever-v2.cc @@ -110,6 +110,7 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage( // Retrieve info std::vector<SearchResultProto::ResultProto> results; + int32_t num_total_bytes = 0; while (results.size() < result_state.num_per_page() && !result_state.scored_document_hits_ranker->empty()) { ScoredDocumentHit next_best_document_hit = @@ -154,7 +155,17 @@ std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage( // Add the document, itself. *result.mutable_document() = std::move(document); result.set_score(next_best_document_hit.score()); + size_t result_bytes = result.ByteSizeLong(); results.push_back(std::move(result)); + + // Check if num_total_bytes + result_bytes reaches or exceeds + // num_total_bytes_per_page_threshold. Use subtraction to avoid integer + // overflow. 
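// (Editorial expansion of the comment above, illustration only and not part of
// this change: num_total_bytes never exceeds the threshold, so
// num_total_bytes_per_page_threshold() - num_total_bytes cannot go negative or
// overflow, whereas summing first could exceed INT32_MAX. For instance, with a
// threshold of 2000000000, num_total_bytes = 1900000000 and
// result_bytes = 300000000, the sum 2200000000 would overflow a signed 32-bit
// counter, but 300000000 >= 100000000 evaluates safely and triggers the break
// below.)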
+ if (result_bytes >= + result_state.num_total_bytes_per_page_threshold() - num_total_bytes) { + break; + } + num_total_bytes += result_bytes; } // Update numbers in ResultState diff --git a/icing/result/result-retriever-v2_group-result-limiter-test.cc b/icing/result/result-retriever-v2_group-result-limiter_test.cc index e4bfe09..e0a6c79 100644 --- a/icing/result/result-retriever-v2_group-result-limiter-test.cc +++ b/icing/result/result-retriever-v2_group-result-limiter_test.cc @@ -185,6 +185,142 @@ TEST_F(ResultRetrieverV2GroupResultLimiterTest, } TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingHasEmptyFirstPage) { + // Creates 2 documents and ensures the relationship in terms of document + // score is: document1 < document2 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score())}; + + // Create a ResultSpec that limits "namespace" to 0 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(0); + result_grouping->add_namespaces("namespace"); + + // Creates a ResultState with 2 ScoredDocumentHits. 
+ ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // First page: empty page + auto [page_result, has_more_results] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result.results, IsEmpty()); + EXPECT_FALSE(has_more_results); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, + ResultGroupingHasEmptyLastPage) { + // Creates 4 documents and ensures the relationship in terms of document + // score is: document1 < document2 < document3 < document4 + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace", "uri/1") + .SetSchema("Document") + .SetScore(1) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(document1)); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace", "uri/2") + .SetSchema("Document") + .SetScore(2) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(document2)); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace", "uri/3") + .SetSchema("Document") + .SetScore(3) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store_->Put(document3)); + + DocumentProto document4 = DocumentBuilder() + .SetKey("namespace", "uri/4") + .SetSchema("Document") + .SetScore(4) + .SetCreationTimestampMs(1000) + .Build(); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store_->Put(document4)); + + std::vector<ScoredDocumentHit> scored_document_hits = { + ScoredDocumentHit(document_id1, kSectionIdMaskNone, document1.score()), + ScoredDocumentHit(document_id2, kSectionIdMaskNone, document2.score()), + ScoredDocumentHit(document_id3, kSectionIdMaskNone, document3.score()), + ScoredDocumentHit(document_id4, kSectionIdMaskNone, document4.score())}; + + // Create a ResultSpec that limits "namespace" to 2 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace"); + + // Creates a ResultState with 4 ScoredDocumentHits. + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + *document_store_); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + // First page: document4 and document3 should be returned. 
+ auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + ASSERT_THAT(page_result1.results, SizeIs(2)); + EXPECT_THAT(page_result1.results.at(0).document(), EqualsProto(document4)); + EXPECT_THAT(page_result1.results.at(1).document(), EqualsProto(document3)); + EXPECT_TRUE(has_more_results1); + + // Second page: although there are valid document hits in result state, all of + // them will be filtered out by group result limiter, so we should get an + // empty page. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, SizeIs(0)); + EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2GroupResultLimiterTest, ResultGroupingDoesNotLimitOtherNamespaceResults) { // Creates 4 documents and ensures the relationship in terms of document // score is: document1 < document2 < document3 < document4 diff --git a/icing/result/result-retriever-v2_projection-test.cc b/icing/result/result-retriever-v2_projection_test.cc index bdd1715..bdd1715 100644 --- a/icing/result/result-retriever-v2_projection-test.cc +++ b/icing/result/result-retriever-v2_projection_test.cc diff --git a/icing/result/result-retriever-v2_snippet-test.cc b/icing/result/result-retriever-v2_snippet_test.cc index afb31cf..afb31cf 100644 --- a/icing/result/result-retriever-v2_snippet-test.cc +++ b/icing/result/result-retriever-v2_snippet_test.cc diff --git a/icing/result/result-retriever-v2_test.cc b/icing/result/result-retriever-v2_test.cc index f23a88a..0998754 100644 --- a/icing/result/result-retriever-v2_test.cc +++ b/icing/result/result-retriever-v2_test.cc @@ -56,6 +56,7 @@ using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::DoDefault; using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::Gt; using ::testing::IsEmpty; using ::testing::Pointee; using ::testing::Return; @@ -635,6 +636,179 @@ TEST_F(ResultRetrieverV2Test, ShouldUpdateNumTotalHits) { EXPECT_THAT(num_total_hits_, Eq(0)); } +TEST_F(ResultRetrieverV2Test, ShouldLimitNumTotalBytesPerPage) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/5}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(5); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(0); + + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(result1.ByteSizeLong()); + ResultStateV2 result_state( + 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store); + + // First page. Only result1 should be returned, since its byte size meets + // num_total_bytes_per_page_threshold and ResultRetriever should terminate + // early even though # of results is still below num_per_page. + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result1.results, ElementsAre(EqualsProto(result1))); + // Has more results. + EXPECT_TRUE(has_more_results1); + + // Second page, result2. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, ElementsAre(EqualsProto(result2))); + // No more results. + EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2Test, + ShouldReturnSingleLargeResultAboveNumTotalBytesPerPageThreshold) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/5}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(5); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(0); + + int threshold = 1; + ASSERT_THAT(result1.ByteSizeLong(), Gt(threshold)); + + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(threshold); + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store); + + // First page. Should return single result1 even though its byte size exceeds + // num_total_bytes_per_page_threshold. + auto [page_result1, has_more_results1] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result1.results, ElementsAre(EqualsProto(result1))); + // Has more results. + EXPECT_TRUE(has_more_results1); + + // Second page, result2. + auto [page_result2, has_more_results2] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result2.results, ElementsAre(EqualsProto(result2))); + // No more results. 
+ EXPECT_FALSE(has_more_results2); +} + +TEST_F(ResultRetrieverV2Test, + ShouldRetrieveNextResultWhenBelowNumTotalBytesPerPageThreshold) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, test_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + doc_store->Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + doc_store->Put(CreateDocument(/*id=*/2))); + + std::vector<SectionId> hit_section_ids = {GetSectionId("Email", "name"), + GetSectionId("Email", "body")}; + SectionIdMask hit_section_id_mask = CreateSectionIdMask(hit_section_ids); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, hit_section_id_mask, /*score=*/5}, + {document_id2, hit_section_id_mask, /*score=*/0}}; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(doc_store.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + + SearchResultProto::ResultProto result1; + *result1.mutable_document() = CreateDocument(/*id=*/1); + result1.set_score(5); + SearchResultProto::ResultProto result2; + *result2.mutable_document() = CreateDocument(/*id=*/2); + result2.set_score(0); + + int threshold = result1.ByteSizeLong() + 1; + ASSERT_THAT(result1.ByteSizeLong() + result2.ByteSizeLong(), Gt(threshold)); + + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(threshold); + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, *doc_store); + + // After retrieving result1, total bytes are still below the threshold and # + // of results is still below num_per_page, so ResultRetriever should continue + // the retrieval process and thus include result2 into this page, even though + // finally total bytes of result1 + result2 exceed the threshold. + auto [page_result, has_more_results] = + result_retriever->RetrieveNextPage(result_state); + EXPECT_THAT(page_result.results, + ElementsAre(EqualsProto(result1), EqualsProto(result2))); + // No more results. 
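Taken together, the three num_total_bytes_per_page_threshold tests in this file pin down the intended accounting: a result is always kept once it has been retrieved, the threshold only stops retrieval of further results for the same page, and the check is written as a subtraction so the running total cannot overflow. A standalone sketch of that accounting under those assumptions (hypothetical helper, not the Icing implementation):

    #include <cstdint>
    #include <vector>

    // Given the byte sizes of already-ranked results, return how many of them
    // end up on one page. The current result is always admitted; the threshold
    // only decides whether to keep retrieving, so a single oversized result
    // still yields a non-empty page.
    int NumResultsInPage(const std::vector<int64_t>& result_sizes,
                         int64_t num_total_bytes_per_page_threshold,
                         int num_per_page) {
      int64_t num_total_bytes = 0;
      int num_results = 0;
      for (int64_t result_bytes : result_sizes) {
        if (num_results >= num_per_page) {
          break;
        }
        ++num_results;  // this result is part of the page unconditionally
        // Equivalent to num_total_bytes + result_bytes >= threshold, written
        // as a subtraction to avoid overflowing int64_t.
        if (result_bytes >=
            num_total_bytes_per_page_threshold - num_total_bytes) {
          break;
        }
        num_total_bytes += result_bytes;
      }
      return num_results;
    }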
+ EXPECT_FALSE(has_more_results); +} + } // namespace } // namespace lib diff --git a/icing/result/result-state-manager.cc b/icing/result/result-state-manager.cc index 1057f9b..2783fe2 100644 --- a/icing/result/result-state-manager.cc +++ b/icing/result/result-state-manager.cc @@ -14,7 +14,16 @@ #include "icing/result/result-state-manager.h" +#include <memory> +#include <queue> +#include <utility> + #include "icing/proto/search.pb.h" +#include "icing/query/query-terms.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/util/clock.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -31,50 +40,66 @@ ResultStateManager::ResultStateManager(int max_total_hits, random_generator_(GetSteadyTimeNanoseconds()), clock_(*clock) {} -libtextclassifier3::StatusOr<PageResultState> -ResultStateManager::RankAndPaginate(ResultState result_state) { - if (!result_state.HasMoreResults()) { - return absl_ports::InvalidArgumentError("ResultState has no results"); +libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> +ResultStateManager::CacheAndRetrieveFirstPage( + std::unique_ptr<ScoredDocumentHitsRanker> ranker, + SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec, const DocumentStore& document_store, + const ResultRetrieverV2& result_retriever) { + if (ranker == nullptr) { + return absl_ports::InvalidArgumentError("Should not provide null ranker"); } - // Gets the number before calling GetNextPage() because num_returned() may - // change after returning more results. - int num_previously_returned = result_state.num_returned(); - int num_per_page = result_state.num_per_page(); - - std::vector<ScoredDocumentHit> page_result_document_hits = - result_state.GetNextPage(document_store_); - - SnippetContext snippet_context_copy = result_state.snippet_context(); - - std::unordered_map<std::string, ProjectionTree> projection_tree_map_copy = - result_state.projection_tree_map(); - if (!result_state.HasMoreResults()) { + // Create shared pointer of ResultState. + // ResultState should be created by ResultStateManager only. + std::shared_ptr<ResultStateV2> result_state = std::make_shared<ResultStateV2>( + std::move(ranker), std::move(query_terms), search_spec, scoring_spec, + result_spec, document_store); + + // Retrieve docs outside of ResultStateManager critical section. + // Will enter ResultState critical section inside ResultRetriever. 
+ auto [page_result, has_more_results] = + result_retriever.RetrieveNextPage(*result_state); + if (!has_more_results) { // No more pages, won't store ResultState, returns directly - return PageResultState( - std::move(page_result_document_hits), kInvalidNextPageToken, - std::move(snippet_context_copy), std::move(projection_tree_map_copy), - num_previously_returned, num_per_page); + return std::make_pair(kInvalidNextPageToken, std::move(page_result)); } - absl_ports::unique_lock l(&mutex_); - // ResultState has multiple pages, storing it - uint64_t next_page_token = Add(std::move(result_state)); + int num_hits_to_add = 0; + { + // ResultState critical section + absl_ports::unique_lock l(&result_state->mutex); + + result_state->scored_document_hits_ranker->TruncateHitsTo(max_total_hits_); + result_state->RegisterNumTotalHits(&num_total_hits_); + num_hits_to_add = result_state->scored_document_hits_ranker->size(); + } - return PageResultState(std::move(page_result_document_hits), next_page_token, - std::move(snippet_context_copy), - std::move(projection_tree_map_copy), - num_previously_returned, num_per_page); -} + // It is fine to exit ResultState critical section, since it is just created + // above and only this thread (this call stack) has access to it. Thus, it + // won't be changed during the gap before we enter ResultStateManager critical + // section. + uint64_t next_page_token = kInvalidNextPageToken; + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); + + // Remove expired result states first. + InternalInvalidateExpiredResultStates(kDefaultResultStateTtlInMs); + // Remove states to make room for this new state. + RemoveStatesIfNeeded(num_hits_to_add); + // Generate a new unique token and add it into result_state_map_. 
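The two eviction calls a few lines above run under the ResultStateManager lock right before the new state is published: InternalInvalidateExpiredResultStates drops states whose tokens are older than the TTL, and RemoveStatesIfNeeded then evicts oldest-first until num_total_hits_ fits the budget again (or clears everything when the incoming state alone exceeds max_total_hits_). A simplified standalone sketch of the TTL pass over the insertion-ordered token queue, with hypothetical names; the real code also erases the matching result_state_map_ entry for each dropped token:

    #include <cstdint>
    #include <queue>
    #include <utility>

    // Walk the (token, insertion-time) queue from the front and drop every
    // entry older than the TTL. Entries are in insertion order, so the walk
    // can stop at the first non-expired token.
    void DropExpiredTokens(
        std::queue<std::pair<uint64_t, int64_t>>& token_queue,
        int64_t current_time_ms, int64_t result_state_ttl_ms) {
      while (!token_queue.empty() &&
             current_time_ms - token_queue.front().second >=
                 result_state_ttl_ms) {
        token_queue.pop();
      }
    }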
+ next_page_token = Add(std::move(result_state)); + } -uint64_t ResultStateManager::Add(ResultState result_state) { - RemoveStatesIfNeeded(result_state); - result_state.TruncateHitsTo(max_total_hits_); + return std::make_pair(next_page_token, std::move(page_result)); +} +uint64_t ResultStateManager::Add(std::shared_ptr<ResultStateV2> result_state) { uint64_t new_token = GetUniqueToken(); - num_total_hits_ += result_state.num_remaining(); result_state_map_.emplace(new_token, std::move(result_state)); // Tracks the insertion order token_queue_.push( @@ -83,43 +108,40 @@ uint64_t ResultStateManager::Add(ResultState result_state) { return new_token; } -libtextclassifier3::StatusOr<PageResultState> ResultStateManager::GetNextPage( - uint64_t next_page_token) { - absl_ports::unique_lock l(&mutex_); +libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> +ResultStateManager::GetNextPage(uint64_t next_page_token, + const ResultRetrieverV2& result_retriever) { + std::shared_ptr<ResultStateV2> result_state = nullptr; + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); - const auto& state_iterator = result_state_map_.find(next_page_token); - if (state_iterator == result_state_map_.end()) { - return absl_ports::NotFoundError("next_page_token not found"); - } + // Remove expired result states before fetching + InternalInvalidateExpiredResultStates(kDefaultResultStateTtlInMs); - int num_returned = state_iterator->second.num_returned(); - int num_per_page = state_iterator->second.num_per_page(); - std::vector<ScoredDocumentHit> result_of_page = - state_iterator->second.GetNextPage(document_store_); - if (result_of_page.empty()) { - // This shouldn't happen, all our active states should contain results, but - // a sanity check here in case of any data inconsistency. - InternalInvalidateResultState(next_page_token); - return absl_ports::NotFoundError( - "No more results, token has been invalidated."); + const auto& state_iterator = result_state_map_.find(next_page_token); + if (state_iterator == result_state_map_.end()) { + return absl_ports::NotFoundError("next_page_token not found"); + } + result_state = state_iterator->second; } - // Copies the SnippetContext in case the ResultState is invalidated. - SnippetContext snippet_context_copy = - state_iterator->second.snippet_context(); + // Retrieve docs outside of ResultStateManager critical section. + // Will enter ResultState critical section inside ResultRetriever. 
+ auto [page_result, has_more_results] = + result_retriever.RetrieveNextPage(*result_state); - std::unordered_map<std::string, ProjectionTree> projection_tree_map_copy = - state_iterator->second.projection_tree_map(); + if (!has_more_results) { + { + // ResultStateManager critical section + absl_ports::unique_lock l(&mutex_); + + InternalInvalidateResultState(next_page_token); + } - if (!state_iterator->second.HasMoreResults()) { - InternalInvalidateResultState(next_page_token); next_page_token = kInvalidNextPageToken; } - - num_total_hits_ -= result_of_page.size(); - return PageResultState( - result_of_page, next_page_token, std::move(snippet_context_copy), - std::move(projection_tree_map_copy), num_returned, num_per_page); + return std::make_pair(next_page_token, std::move(page_result)); } void ResultStateManager::InvalidateResultState(uint64_t next_page_token) { @@ -137,17 +159,13 @@ void ResultStateManager::InvalidateAllResultStates() { InternalInvalidateAllResultStates(); } -void ResultStateManager::InvalidateExpiredResultStates( - int64_t result_state_ttl) { - absl_ports::unique_lock l(&mutex_); - InternalInvalidateExpiredResultStates(result_state_ttl); -} - void ResultStateManager::InternalInvalidateAllResultStates() { + // We don't have to reset num_total_hits_ (to 0) here, since clearing + // result_state_map_ will "eventually" invoke the destructor of ResultState + // (which decrements num_total_hits_) and num_total_hits_ will become 0. result_state_map_.clear(); invalidated_token_set_.clear(); token_queue_ = std::queue<std::pair<uint64_t, int64_t>>(); - num_total_hits_ = 0; } uint64_t ResultStateManager::GetUniqueToken() { @@ -163,14 +181,14 @@ uint64_t ResultStateManager::GetUniqueToken() { return new_token; } -void ResultStateManager::RemoveStatesIfNeeded(const ResultState& result_state) { +void ResultStateManager::RemoveStatesIfNeeded(int num_hits_to_add) { if (result_state_map_.empty() || token_queue_.empty()) { return; } // 1. Check if this new result_state would take up the entire result state // manager budget. - if (result_state.num_remaining() > max_total_hits_) { + if (num_hits_to_add > max_total_hits_) { // This single result state will exceed our budget. Drop everything else to // accomodate it. InternalInvalidateAllResultStates(); @@ -187,7 +205,13 @@ void ResultStateManager::RemoveStatesIfNeeded(const ResultState& result_state) { // 3. If we're over budget, remove states from oldest to newest until we fit // into our budget. - while (result_state.num_remaining() + num_total_hits_ > max_total_hits_) { + // Note: num_total_hits_ may not be decremented immediately after invalidating + // a result state, since other threads may still hold the shared pointer. + // Thus, we have to check if token_queue_ is empty or not, since it is + // possible that num_total_hits_ is non-zero and still greater than + // max_total_hits_ when token_queue_ is empty. Still "eventually" it will be + // decremented after the last thread releases the shared pointer. + while (!token_queue_.empty() && num_total_hits_ > max_total_hits_) { InternalInvalidateResultState(token_queue_.front().first); token_queue_.pop(); } @@ -201,7 +225,9 @@ void ResultStateManager::InternalInvalidateResultState(uint64_t token) { // remove the token in RemoveStatesIfNeeded(). 
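The "eventually" wording in the surrounding comments is carried by an RAII pattern: each cached ResultStateV2 registers the manager's atomic num_total_hits_ counter and gives its hits back in its own destructor, which only runs once the last shared_ptr holder (possibly a concurrent GetNextPage on another thread) releases the state. A minimal sketch of that pattern with hypothetical member names; whether the counter is bumped at registration time, and how the remaining-hit count shrinks as pages are retrieved, are assumptions of this sketch rather than details shown in the diff:

    #include <atomic>

    // The global counter is only guaranteed to drop after the last shared_ptr
    // to the state goes away, never at map-erase time.
    class CountedResultState {
     public:
      void RegisterNumTotalHits(std::atomic<int>* num_total_hits) {
        num_total_hits_ = num_total_hits;
        *num_total_hits_ += num_remaining_hits_;
      }

      ~CountedResultState() {
        if (num_total_hits_ != nullptr) {
          *num_total_hits_ -= num_remaining_hits_;
        }
      }

      int num_remaining_hits_ = 0;
      std::atomic<int>* num_total_hits_ = nullptr;
    };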
auto itr = result_state_map_.find(token); if (itr != result_state_map_.end()) { - num_total_hits_ -= itr->second.num_remaining(); + // We don't have to decrement num_total_hits_ here, since erasing the shared + // ptr instance will "eventually" invoke the destructor of ResultState and + // it will handle this. result_state_map_.erase(itr); invalidated_token_set_.insert(token); } @@ -214,7 +240,9 @@ void ResultStateManager::InternalInvalidateExpiredResultStates( current_time - token_queue_.front().second >= result_state_ttl) { auto itr = result_state_map_.find(token_queue_.front().first); if (itr != result_state_map_.end()) { - num_total_hits_ -= itr->second.num_remaining(); + // We don't have to decrement num_total_hits_ here, since erasing the + // shared ptr instance will "eventually" invoke the destructor of + // ResultState and it will handle this. result_state_map_.erase(itr); } else { // Since result_state_map_ and invalidated_token_set_ are mutually diff --git a/icing/result/result-state-manager.h b/icing/result/result-state-manager.h index 745b0ec..0684864 100644 --- a/icing/result/result-state-manager.h +++ b/icing/result/result-state-manager.h @@ -15,6 +15,8 @@ #ifndef ICING_RESULT_RESULT_STATE_MANAGER_H_ #define ICING_RESULT_RESULT_STATE_MANAGER_H_ +#include <atomic> +#include <memory> #include <queue> #include <random> #include <unordered_map> @@ -24,8 +26,11 @@ #include "icing/absl_ports/mutex.h" #include "icing/proto/scoring.pb.h" #include "icing/proto/search.pb.h" -#include "icing/result/page-result-state.h" -#include "icing/result/result-state.h" +#include "icing/query/query-terms.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-v2.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/util/clock.h" namespace icing { @@ -49,30 +54,46 @@ class ResultStateManager { ResultStateManager(const ResultStateManager&) = delete; ResultStateManager& operator=(const ResultStateManager&) = delete; - // Ranks the results and returns the first page of them. The result object - // PageResultState contains a next_page_token which can be used to fetch more - // pages later. It will be set to a default value 0 if there're no more pages. + // Creates a new result state, retrieves and returns PageResult for the first + // page. Also caches the new result state and returns a next_page_token which + // can be used to fetch more pages from the same result state later. Before + // caching the result state, adjusts (truncate) the size and evicts some old + // result states if exceeding the cache size limit. next_page_token will be + // set to a default value kInvalidNextPageToken if there're no more pages. // - // NOTE: it's caller's responsibility not to call this method with the same - // ResultState more than once, otherwise duplicate states will be stored - // internally. + // NOTE: it is possible to have empty result for the first page even if the + // ranker was not empty before the retrieval, since GroupResultLimiter + // may filter out all docs. In this case, the first page is also the + // last page and next_page_token will be set to kInvalidNextPageToken. 
// // Returns: - // A PageResultState on success - // INVALID_ARGUMENT if the input state contains no results - libtextclassifier3::StatusOr<PageResultState> RankAndPaginate( - ResultState result_state) ICING_LOCKS_EXCLUDED(mutex_); + // A token and PageResult wrapped by std::pair on success + // INVALID_ARGUMENT if the input ranker is null or contains no results + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> + CacheAndRetrieveFirstPage(std::unique_ptr<ScoredDocumentHitsRanker> ranker, + SectionRestrictQueryTermsMap query_terms, + const SearchSpecProto& search_spec, + const ScoringSpecProto& scoring_spec, + const ResultSpecProto& result_spec, + const DocumentStore& document_store, + const ResultRetrieverV2& result_retriever) + ICING_LOCKS_EXCLUDED(mutex_); - // Retrieves and returns the next page of results wrapped in PageResultState. + // Retrieves and returns PageResult for the next page. // The returned results won't exist in ResultStateManager anymore. If the // query has no more pages after this retrieval, the input token will be // invalidated. // + // NOTE: it is possible to have empty result for the last page even if the + // ranker was not empty before the retrieval, since GroupResultLimiter + // may filtered out all remaining docs. + // // Returns: - // PageResultState on success, guaranteed to have non-empty results + // A token and PageResult wrapped by std::pair on success // NOT_FOUND if failed to find any more results - libtextclassifier3::StatusOr<PageResultState> GetNextPage( - uint64_t next_page_token) ICING_LOCKS_EXCLUDED(mutex_); + libtextclassifier3::StatusOr<std::pair<uint64_t, PageResult>> GetNextPage( + uint64_t next_page_token, const ResultRetrieverV2& result_retriever) + ICING_LOCKS_EXCLUDED(mutex_); // Invalidates the result state associated with the given next-page token. void InvalidateResultState(uint64_t next_page_token) @@ -81,12 +102,6 @@ class ResultStateManager { // Invalidates all result states / tokens currently in ResultStateManager. void InvalidateAllResultStates() ICING_LOCKS_EXCLUDED(mutex_); - // Invalidates expired result states / tokens currently in ResultStateManager - // that were created before current_time - result_state_ttl. - void InvalidateExpiredResultStates( - int64_t result_state_ttl = kDefaultResultStateTtlInMs) - ICING_LOCKS_EXCLUDED(mutex_); - private: absl_ports::shared_mutex mutex_; @@ -100,10 +115,10 @@ class ResultStateManager { // The number of scored document hits that all result states currently held by // the result state manager have. - int num_total_hits_; + std::atomic<int> num_total_hits_; // A hash map of (next-page token -> result state) - std::unordered_map<uint64_t, ResultState> result_state_map_ + std::unordered_map<uint64_t, std::shared_ptr<ResultStateV2>> result_state_map_ ICING_GUARDED_BY(mutex_); // A queue used to track the insertion order of tokens with pushed timestamps. @@ -125,14 +140,16 @@ class ResultStateManager { // currently valid tokens. When the maximum number of result states is // reached, the oldest / firstly added result state will be removed to make // room for the new state. - uint64_t Add(ResultState result_state) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + uint64_t Add(std::shared_ptr<ResultStateV2> result_state) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Helper method to generate a next-page token that is unique among all // existing tokens in token_queue_. 
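Together, CacheAndRetrieveFirstPage and GetNextPage replace the old RankAndPaginate flow documented earlier in this header. A hedged caller-side sketch of the intended usage, not part of this commit (error handling trimmed; ValueOrDie() is assumed to be the libtextclassifier3::StatusOr accessor):

    // Drain every page of a query through the new two-call API.
    void DrainAllPages(ResultStateManager& result_state_manager,
                       std::unique_ptr<ScoredDocumentHitsRanker> ranker,
                       const SearchSpecProto& search_spec,
                       const ScoringSpecProto& scoring_spec,
                       const ResultSpecProto& result_spec,
                       const DocumentStore& document_store,
                       const ResultRetrieverV2& result_retriever) {
      // First page: caches the new ResultStateV2 and hands back a token.
      std::pair<uint64_t, PageResult> page_info =
          result_state_manager
              .CacheAndRetrieveFirstPage(std::move(ranker), /*query_terms=*/{},
                                         search_spec, scoring_spec, result_spec,
                                         document_store, result_retriever)
              .ValueOrDie();
      // Later pages: the token comes back as kInvalidNextPageToken once the
      // last (possibly empty) page has been returned.
      while (page_info.first != kInvalidNextPageToken) {
        page_info = result_state_manager
                        .GetNextPage(page_info.first, result_retriever)
                        .ValueOrDie();
        // ... consume page_info.second.results here ...
      }
    }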
uint64_t GetUniqueToken() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Helper method to remove old states to make room for incoming states. - void RemoveStatesIfNeeded(const ResultState& result_state) + // Helper method to remove old states to make room for incoming states with + // size num_hits_to_add. + void RemoveStatesIfNeeded(int num_hits_to_add) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); // Helper method to remove a result state from result_state_map_, the token diff --git a/icing/result/result-state-manager_test.cc b/icing/result/result-state-manager_test.cc index 251a736..7025c63 100644 --- a/icing/result/result-state-manager_test.cc +++ b/icing/result/result-state-manager_test.cc @@ -16,23 +16,39 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/portable/equals-proto.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hits-ranker.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" #include "icing/util/clock.h" +#include "unicode/uloc.h" namespace icing { namespace lib { namespace { + using ::icing::lib::portable_equals_proto::EqualsProto; -using ::testing::ElementsAre; using ::testing::Eq; -using ::testing::Gt; using ::testing::IsEmpty; +using ::testing::Not; +using ::testing::SizeIs; +using PageResultInfo = std::pair<uint64_t, PageResult>; + +// TODO(sungyc): Refactor helper functions below (builder classes or common test +// utility). ScoringSpecProto CreateScoringSpec() { ScoringSpecProto scoring_spec; @@ -46,102 +62,175 @@ ResultSpecProto CreateResultSpec(int num_per_page) { return result_spec; } -ScoredDocumentHit CreateScoredHit(DocumentId document_id) { - return ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1); +DocumentProto CreateDocument(int id) { + return DocumentBuilder() + .SetNamespace("namespace") + .SetUri(std::to_string(id)) + .SetSchema("Document") + .SetCreationTimestampMs(1574365086666 + id) + .SetScore(1) + .Build(); } class ResultStateManagerTest : public testing::Test { protected: + ResultStateManagerTest() : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + clock_ = std::make_unique<FakeClock>(); - schema_store_base_dir_ = GetTestTempDir() + "/schema_store"; - filesystem_.CreateDirectoryRecursively(schema_store_base_dir_.c_str()); + language_segmenter_factory::SegmenterOptions options(ULOC_US); ICING_ASSERT_OK_AND_ASSIGN( - schema_store_, SchemaStore::Create(&filesystem_, schema_store_base_dir_, - clock_.get())); + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, clock_.get())); SchemaProto schema; schema.add_types()->set_schema_type("Document"); ICING_ASSERT_OK(schema_store_->SetSchema(std::move(schema))); - doc_store_base_dir_ = GetTestTempDir() + "/document_store"; - filesystem_.CreateDirectoryRecursively(doc_store_base_dir_.c_str()); + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + ICING_ASSERT_OK_AND_ASSIGN( DocumentStore::CreateResult result, - DocumentStore::Create(&filesystem_, doc_store_base_dir_, clock_.get(), + DocumentStore::Create(&filesystem_, test_dir_, clock_.get(), schema_store_.get())); document_store_ = std::move(result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN( + result_retriever_, ResultRetrieverV2::Create( + document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); } void TearDown() override { - filesystem_.DeleteDirectoryRecursively(doc_store_base_dir_.c_str()); - filesystem_.DeleteDirectoryRecursively(schema_store_base_dir_.c_str()); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); clock_.reset(); } - ResultState CreateResultState( - const std::vector<ScoredDocumentHit>& scored_document_hits, - int num_per_page) { - return ResultState(scored_document_hits, /*query_terms=*/{}, - SearchSpecProto::default_instance(), CreateScoringSpec(), - CreateResultSpec(num_per_page), *document_store_); - } - - ScoredDocumentHit AddScoredDocument(DocumentId document_id) { + std::pair<ScoredDocumentHit, DocumentProto> AddScoredDocument( + DocumentId document_id) { DocumentProto document; document.set_namespace_("namespace"); document.set_uri(std::to_string(document_id)); document.set_schema("Document"); - document_store_->Put(std::move(document)); - return ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1); + document.set_creation_timestamp_ms(1574365086666 + document_id); + document_store_->Put(document); + return std::make_pair( + ScoredDocumentHit(document_id, kSectionIdMaskNone, /*score=*/1), + std::move(document)); + } + + std::pair<std::vector<ScoredDocumentHit>, std::vector<DocumentProto>> + AddScoredDocuments(const std::vector<DocumentId>& document_ids) { + std::vector<ScoredDocumentHit> scored_document_hits; + std::vector<DocumentProto> document_protos; + + for (DocumentId document_id : document_ids) { + std::pair<ScoredDocumentHit, DocumentProto> pair = + AddScoredDocument(document_id); + scored_document_hits.emplace_back(std::move(pair.first)); + document_protos.emplace_back(std::move(pair.second)); + } + + std::reverse(document_protos.begin(), document_protos.end()); + + return std::make_pair(std::move(scored_document_hits), + std::move(document_protos)); } FakeClock* clock() { return clock_.get(); } const FakeClock* clock() const { return clock_.get(); } + DocumentStore& document_store() { return *document_store_; } const DocumentStore& document_store() const { return *document_store_; } + 
const ResultRetrieverV2& result_retriever() const { + return *result_retriever_; + } + private: Filesystem filesystem_; + const std::string test_dir_; std::unique_ptr<FakeClock> clock_; - std::string doc_store_base_dir_; - std::string schema_store_base_dir_; - std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<ResultRetrieverV2> result_retriever_; }; -TEST_F(ResultStateManagerTest, ShouldRankAndPaginateOnePage) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/10); +TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageOnePage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}}; + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - result_state_manager.RankAndPaginate(std::move(original_result_state))); + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::move(ranker), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/10), + document_store(), result_retriever())); - EXPECT_THAT(page_result_state.next_page_token, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); - // Should get the original scored document hits - EXPECT_THAT( - page_result_state.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/2)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/1)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/0)))); + // Should get docs. 
+ ASSERT_THAT(page_result_info.second.results, SizeIs(3)); + EXPECT_THAT(page_result_info.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(page_result_info.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/2))); + EXPECT_THAT(page_result_info.second.results.at(2).document(), + EqualsProto(CreateDocument(/*id=*/1))); } -TEST_F(ResultStateManagerTest, ShouldRankAndPaginateMultiplePages) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/2); +TEST_F(ResultStateManagerTest, ShouldCacheAndRetrieveFirstPageMultiplePages) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store().Put(CreateDocument(/*id=*/5))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}, + {document_id4, kSectionIdMaskNone, /*score=*/1}, + {document_id5, kSectionIdMaskNone, /*score=*/1}}; + std::unique_ptr<ScoredDocumentHitsRanker> ranker = + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), @@ -149,977 +238,1132 @@ TEST_F(ResultStateManagerTest, ShouldRankAndPaginateMultiplePages) { // First page, 2 results ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - EXPECT_THAT( - page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/4)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/3)))); - - uint64_t next_page_token = page_result_state1.next_page_token; + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::move(ranker), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2), + document_store(), result_retriever())); + EXPECT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/5))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/4))); + + uint64_t next_page_token = page_result_info1.first; // Second page, 2 results - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT( - page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/2)), - EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/1)))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + 
result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info2.first, Eq(next_page_token)); + ASSERT_THAT(page_result_info2.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/3))); + EXPECT_THAT(page_result_info2.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/2))); // Third page, 1 result - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT( - page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(kInvalidNextPageToken)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/1))); // No results - EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT( + result_state_manager.GetNextPage(next_page_token, result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } -TEST_F(ResultStateManagerTest, EmptyStateShouldReturnError) { - ResultState empty_result_state = CreateResultState({}, /*num_per_page=*/1); +TEST_F(ResultStateManagerTest, NullRankerShouldReturnError) { + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + EXPECT_THAT(result_state_manager.CacheAndRetrieveFirstPage( + /*ranker=*/nullptr, + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever()), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} +TEST_F(ResultStateManagerTest, EmptyRankerShouldReturnEmptyFirstPage) { ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), clock()); - EXPECT_THAT( - result_state_manager.RankAndPaginate(std::move(empty_result_state)), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info.second.results, IsEmpty()); } -TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); +TEST_F(ResultStateManagerTest, ShouldAllowEmptyFirstPage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + 
document_store().Put(CreateDocument(/*id=*/2))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}}; + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Create a ResultSpec that limits "namespace" to 0 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(0); + result_grouping->add_namespaces("namespace"); + + // First page, no result. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), result_spec, document_store(), + result_retriever())); + // If the first page has no result, then it should be the last page. + EXPECT_THAT(page_result_info.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info.second.results, IsEmpty()); +} + +TEST_F(ResultStateManagerTest, ShouldAllowEmptyLastPage) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + std::vector<ScoredDocumentHit> scored_document_hits = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}, + {document_id4, kSectionIdMaskNone, /*score=*/1}}; + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Create a ResultSpec that limits "namespace" to 2 results. + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + ResultSpecProto::ResultGrouping* result_grouping = + result_spec.add_result_groupings(); + result_grouping->set_max_results(2); + result_grouping->add_namespaces("namespace"); + + // First page, 2 results. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), result_spec, document_store(), + result_retriever())); + EXPECT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/4))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(CreateDocument(/*id=*/3))); + + uint64_t next_page_token = page_result_info1.first; + + // Second page, all remaining documents will be filtered out by group result + // limiter, so we should get an empty page. 
+ ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + EXPECT_THAT(page_result_info2.first, Eq(kInvalidNextPageToken)); + EXPECT_THAT(page_result_info2.second.results, IsEmpty()); +} + +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenCacheAndRetrieveFirstPage) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), clock()); + + SectionRestrictQueryTermsMap query_terms; + SearchSpecProto search_spec; + ScoringSpecProto scoring_spec = CreateScoringSpec(); + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); + + // Set time as 1s and add state 1. + clock()->SetSystemTimeMilliseconds(1000); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + query_terms, search_spec, scoring_spec, result_spec, document_store(), + result_retriever())); + ASSERT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + + // Set time as 1hr1s and add state 2. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + query_terms, search_spec, scoring_spec, result_spec, document_store(), + result_retriever())); + + // Calling CacheAndRetrieveFirstPage() on state 2 should invalidate the + // expired state 1 internally. + // + // We test the behavior by setting time back to 1s, to make sure the + // invalidation of state 1 was done by the previous + // CacheAndRetrieveFirstPage() instead of the following GetNextPage(). + clock()->SetSystemTimeMilliseconds(1000); + // page_result_info1's token (page_result_info1.first) shouldn't be found. + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenGetNextPageOnOthers) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); + + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); + + // Set time as 1s and add state 1. 
+ clock()->SetSystemTimeMilliseconds(1000); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info1.first, Not(Eq(kInvalidNextPageToken))); + + // Set time as 2s and add state 2. + clock()->SetSystemTimeMilliseconds(2000); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info2.first, Not(Eq(kInvalidNextPageToken))); + + // 1. Set time as 1hr1s. + // 2. Call GetNextPage() on state 2. It should correctly invalidate the + // expired state 1. + // 3. Then calling GetNextPage() on state 1 shouldn't get anything. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + // page_result_info2's token (page_result_info2.first) should be found + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + // We test the behavior by setting time back to 2s, to make sure the + // invalidation of state 1 was done by the previous GetNextPage() instead of + // the following GetNextPage(). + clock()->SetSystemTimeMilliseconds(2000); + // page_result_info1's token (page_result_info1.first) shouldn't be found. + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} - result_state_manager.InvalidateResultState( - page_result_state1.next_page_token); +TEST_F(ResultStateManagerTest, + ShouldInvalidateExpiredTokensWhenGetNextPageOnItself) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + ResultStateManager result_state_manager( + /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), + clock()); - // page_result_state2.next_page_token() should still exist + // Set time as 1s and add state. 
+ clock()->SetSystemTimeMilliseconds(1000); ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT( - page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit(/*document_id=*/4)))); + PageResultInfo page_result_info, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ASSERT_THAT(page_result_info.first, Not(Eq(kInvalidNextPageToken))); + + // 1. Set time as 1hr1s. + // 2. Then calling GetNextPage() on the state shouldn't get anything. + clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1000); + // page_result_info's token (page_result_info.first) shouldn't be found. + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } -TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); +TEST_F(ResultStateManagerTest, ShouldInvalidateOneToken) { + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store().Put(CreateDocument(/*id=*/1))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store().Put(CreateDocument(/*id=*/2))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id3, + document_store().Put(CreateDocument(/*id=*/3))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id4, + document_store().Put(CreateDocument(/*id=*/4))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id5, + document_store().Put(CreateDocument(/*id=*/5))); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id6, + document_store().Put(CreateDocument(/*id=*/6))); + std::vector<ScoredDocumentHit> scored_document_hits1 = { + {document_id1, kSectionIdMaskNone, /*score=*/1}, + {document_id2, kSectionIdMaskNone, /*score=*/1}, + {document_id3, kSectionIdMaskNone, /*score=*/1}}; + std::vector<ScoredDocumentHit> scored_document_hits2 = { + {document_id4, kSectionIdMaskNone, /*score=*/1}, + {document_id5, kSectionIdMaskNone, /*score=*/1}, + {document_id6, kSectionIdMaskNone, /*score=*/1}}; ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + 
std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - result_state_manager.InvalidateAllResultStates(); + // Invalidate first result state by the token. + result_state_manager.InvalidateResultState(page_result_info1.first); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info1's token (page_result_info1.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state2.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info2's token (page_result_info2.first) should still exist + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + // Should get docs. + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(CreateDocument(/*id=*/5))); } -TEST_F(ResultStateManagerTest, ShouldInvalidateOldTokens) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); +TEST_F(ResultStateManagerTest, ShouldInvalidateAllTokens) { + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = AddScoredDocuments( + {/*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager( /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), clock()); - // Set time as 1s and add state 1. - clock()->SetSystemTimeMilliseconds(1000); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); - // Set time as 1hr2s and add state 2. 
- clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 2000); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - // Invalidates expired states with default ttl (1 hr). This should only - // invalidate state 1. - result_state_manager.InvalidateExpiredResultStates(); + result_state_manager.InvalidateAllResultStates(); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // page_result_info1's token (page_result_info1.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state2.next_page_token() should be found - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + // page_result_info2's token (page_result_info2.first) shouldn't be found + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } TEST_F(ResultStateManagerTest, ShouldRemoveOldestResultState) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); ResultStateManager result_state_manager(/*max_total_hits=*/2, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + 
document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + // Adding state 3 should cause state 1 to be removed. ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/2)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); } TEST_F(ResultStateManagerTest, InvalidatedResultStateShouldDecreaseCurrentHitsCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. 
Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates state 2, so that the number of hits current cached should be // decremented to 2. - result_state_manager.InvalidateResultState( - page_result_state2.next_page_token); + result_state_manager.InvalidateResultState(page_result_info2.first); // If invalidating state 2 correctly decremented the current hit count to 2, // then adding state 4 should still be within our budget and no other result // states should be evicted. 
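// [Editor's illustrative note, not part of the upstream diff] The hit budget
// described above is counted in ScoredDocumentHits that remain cached after
// the first page has been returned. With /*max_total_hits=*/3 and result sets
// of 2 hits at /*num_per_page=*/1, the expected accounting is:
//   state 1: 2 hits - 1 returned on the first page = 1 cached hit
//   state 2: 2 hits - 1 returned                   = 1 cached hit
//   state 3: 2 hits - 1 returned                   = 1 cached hit
//   cached total = 3 = max_total_hits
// so the cache sits exactly at budget, and any further addition must first
// free hits (via InvalidateResultState, GetNextPage, or eviction of the
// oldest state), which is what the assertions below verify.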
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); } TEST_F(ResultStateManagerTest, InvalidatedAllResultStatesShouldResetCurrentHitCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + 
auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates all states so that the current hit count will be 0. result_state_manager.InvalidateAllResultStates(); // If invalidating all states correctly reset the current hit count to 0, - // then the entirety of state 4 should still be within our budget and no other + // then adding state 4, 5, 6 should still be within our budget and no other // result states should be evicted. 
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ResultState result_state6 = - CreateResultState({AddScoredDocument(/*document_id=*/10), - AddScoredDocument(/*document_id=*/11)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state6, - result_state_manager.RankAndPaginate(std::move(result_state6))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state3.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state6, - result_state_manager.GetNextPage(page_result_state6.next_page_token)); - EXPECT_THAT(page_result_state6.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/10)))); -} - -TEST_F(ResultStateManagerTest, - InvalidatedOldResultStatesShouldDecreaseCurrentHitsCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - - // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // So state 1 ~ state 4 will take up 6 hits in total. - ResultStateManager result_state_manager(/*max_total_hits=*/6, - document_store(), clock()); - // Set time as 1000ms and add state 1. - clock()->SetSystemTimeMilliseconds(1000); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); - // Set time as 1001ms and add state 2. 
- clock()->SetSystemTimeMilliseconds(1001); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); - // Set time as 1002ms and add state 3. - clock()->SetSystemTimeMilliseconds(1002); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); - // Set time as 1003ms and add state 4. - clock()->SetSystemTimeMilliseconds(1003); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - - // Set time as kDefaultResultStateTtlInMs + 1001ms and invalidate expired - // states with default ttl (1 hr). This should invalidate state 1 and state 2. - clock()->SetSystemTimeMilliseconds(kDefaultResultStateTtlInMs + 1001); - result_state_manager.InvalidateExpiredResultStates(); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + auto [scored_document_hits6, document_protos6] = + AddScoredDocuments({/*document_id=*/10, /*document_id=*/11}); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info6, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits6), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state1.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state2.next_page_token() shouldn't be found - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // If invalidating state 1 and state 2 correctly decremented the current hit - // count by 4 (to 2), then adding state 5 should still be within our budget - // and no other result states should be evicted. 
- ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/10), - AddScoredDocument(/*document_id=*/11), - AddScoredDocument(/*document_id=*/12), - AddScoredDocument(/*document_id=*/13), - AddScoredDocument(/*document_id=*/14)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info3.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // page_result_state3.next_page_token() should be found since there is no - // eviction. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); - // page_result_state4.next_page_token() should be found since there is no - // eviction. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info6, + result_state_manager.GetNextPage( + page_result_info6.first, result_retriever())); + ASSERT_THAT(page_result_info6.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info6.second.results.at(0).document(), + EqualsProto(document_protos6.at(1))); } TEST_F( ResultStateManagerTest, InvalidatedResultStateShouldDecreaseCurrentHitsCountByExactStateHitCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. 
So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Invalidates state 2, so that the number of hits current cached should be // decremented to 2. - result_state_manager.InvalidateResultState( - page_result_state2.next_page_token); + result_state_manager.InvalidateResultState(page_result_info2.first); // If invalidating state 2 correctly decremented the current hit count to 2, // then adding state 4 should still be within our budget and no other result // states should be evicted. - ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // If invalidating result state 2 correctly decremented the current hit count // to 2 and adding state 4 correctly incremented it to 3, then adding this // result state should trigger the eviction of state 1. 
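// [Editor's illustrative note, not part of the upstream diff] The expected
// accounting for this test, inferred from the comments above:
//   cache states 1-3:                3 cached hits (budget of 3 is full)
//   InvalidateResultState(state 2):  3 - 1 = 2 cached hits
//   add state 4:                     2 + 1 = 3 cached hits (full again)
//   add state 5:                     3 + 1 would exceed the budget, so the
//                                    oldest remaining state (state 1) is
//                                    evicted before state 5 is cached.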
- ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); } TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCount) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - 
AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, /*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // GetNextPage for result state 1 should return its result and decrement the // number of cached hits to 2. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2, then adding state 4 should still be within our // budget and no other result states should be evicted. 
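// [Editor's illustrative note, not part of the upstream diff] GetNextPage also
// frees budget: returning a cached page removes those hits from the cache, and
// a state whose hits are exhausted is dropped (its token then returns
// NOT_FOUND, as asserted further below). For this test:
//   cache states 1-3:              3 cached hits (budget of 3 is full)
//   GetNextPage(state 1's token):  returns state 1's last hit -> 2 cached hits
//   add state 4:                   2 + 1 = 3, still within budget, no eviction.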
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/2)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); } TEST_F(ResultStateManagerTest, GetNextPageShouldDecreaseCurrentHitsCountByExactlyOnePage) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3)}, - /*num_per_page=*/1); - ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = + AddScoredDocuments({/*document_id=*/0, /*document_id=*/1}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/2, 
/*document_id=*/3}); + auto [scored_document_hits3, document_protos3] = + AddScoredDocuments({/*document_id=*/4, /*document_id=*/5}); // Add the first three states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1 and a result set of 2 hits. So each - // result will take up one hit of our three hit budget. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1 and a + // result set of 2 hits. So each result will take up one hit of our three hit + // budget. ResultStateManager result_state_manager(/*max_total_hits=*/3, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // GetNextPage for result state 1 should return its result and decrement the // number of cached hits to 2. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state1, - result_state_manager.GetNextPage(page_result_state1.next_page_token)); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/0)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info1, + result_state_manager.GetNextPage( + page_result_info1.first, result_retriever())); + ASSERT_THAT(page_result_info1.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos1.at(1))); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2, then adding state 4 should still be within our // budget and no other result states should be evicted. 
- ResultState result_state4 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); + auto [scored_document_hits4, document_protos4] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state4, - result_state_manager.RankAndPaginate(std::move(result_state4))); + PageResultInfo page_result_info4, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits4), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // If retrieving the next page for result state 1 correctly decremented the // current hit count to 2 and adding state 4 correctly incremented it to 3, // then adding this result state should trigger the eviction of state 2. - ResultState result_state5 = - CreateResultState({AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state5, - result_state_manager.RankAndPaginate(std::move(result_state5))); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/4)))); + auto [scored_document_hits5, document_protos5] = + AddScoredDocuments({/*document_id=*/8, /*document_id=*/9}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info5, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits5), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state4, - result_state_manager.GetNextPage(page_result_state4.next_page_token)); - EXPECT_THAT(page_result_state4.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state5, - result_state_manager.GetNextPage(page_result_state5.next_page_token)); - EXPECT_THAT(page_result_state5.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info3, + result_state_manager.GetNextPage( + page_result_info3.first, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info4, + result_state_manager.GetNextPage( + 
page_result_info4.first, result_retriever())); + ASSERT_THAT(page_result_info4.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info4.second.results.at(0).document(), + EqualsProto(document_protos4.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN(page_result_info5, + result_state_manager.GetNextPage( + page_result_info5.first, result_retriever())); + ASSERT_THAT(page_result_info5.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info5.second.results.at(0).document(), + EqualsProto(document_protos5.at(1))); } TEST_F(ResultStateManagerTest, AddingOverBudgetResultStateShouldEvictAllStates) { - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2)}, - /*num_per_page=*/1); - ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2}); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/3, /*document_id=*/4}); // Add the first two states. Remember, the first page for each result state - // won't be cached (since it is returned immediately from RankAndPaginate). - // Each result state has a page size of 1. So 3 hits will remain cached. + // won't be cached (since it is returned immediately from + // CacheAndRetrieveFirstPage). Each result state has a page size of 1. So 3 + // hits will remain cached. ResultStateManager result_state_manager(/*max_total_hits=*/4, document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Add a result state that is larger than the entire budget. This should // result in all previous result states being evicted, the first hit from // result state 3 being returned and the next four hits being cached (the last // hit should be dropped because it exceeds the max). 
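// [Editor's illustrative sketch, not part of the upstream diff] The assertions
// below walk the remaining cached pages one GetNextPage call at a time. The
// equivalent caller-side loop, using only names that appear in this test
// (PageResultInfo, GetNextPage, result_retriever(), kInvalidNextPageToken),
// would look roughly like this; treat it as an inferred usage pattern rather
// than upstream code:
uint64_t token = page_result_info3.first;
while (token != kInvalidNextPageToken) {
  ICING_ASSERT_OK_AND_ASSIGN(
      PageResultInfo info,
      result_state_manager.GetNextPage(token, result_retriever()));
  // ... consume info.second.results (one document per page here) ...
  token = info.first;  // becomes kInvalidNextPageToken once the cache is exhausted
}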
- ResultState result_state3 = - CreateResultState({AddScoredDocument(/*document_id=*/5), - AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7), - AddScoredDocument(/*document_id=*/8), - AddScoredDocument(/*document_id=*/9), - AddScoredDocument(/*document_id=*/10)}, - /*num_per_page=*/1); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state3, - result_state_manager.RankAndPaginate(std::move(result_state3))); + auto [scored_document_hits3, document_protos3] = AddScoredDocuments( + {/*document_id=*/5, /*document_id=*/6, /*document_id=*/7, + /*document_id=*/8, /*document_id=*/9, /*document_id=*/10}); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits3), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); + EXPECT_THAT(page_result_info3.first, Not(Eq(kInvalidNextPageToken))); // GetNextPage for result state 1 and 2 should return NOT_FOUND. - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state2.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info2.first, + result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); // Only the next four results in state 3 should be retrievable. - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/9)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/8)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/7)))); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state3, - result_state_manager.GetNextPage(page_result_state3.next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); - - // The final result should have been dropped because it exceeded the budget. 
+ uint64_t next_page_token3 = page_result_info3.first; + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(1))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(2))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + EXPECT_THAT(page_result_info3.first, Eq(next_page_token3)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(3))); + + ICING_ASSERT_OK_AND_ASSIGN( + page_result_info3, + result_state_manager.GetNextPage(next_page_token3, result_retriever())); + // The final document should have been dropped because it exceeded the budget, + // so the next page token of the second last round should be + // kInvalidNextPageToken. + EXPECT_THAT(page_result_info3.first, Eq(kInvalidNextPageToken)); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos3.at(4))); + + // Double check that next_page_token3 is not retrievable anymore. EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state3.next_page_token), + result_state_manager.GetNextPage(next_page_token3, result_retriever()), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } TEST_F(ResultStateManagerTest, AddingResultStateShouldEvictOverBudgetResultState) { - ResultStateManager result_state_manager(/*max_total_hits=*/4, - document_store(), clock()); // Add a result state that is larger than the entire budget. The entire result // state will still be cached - ResultState result_state1 = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4), - AddScoredDocument(/*document_id=*/5)}, - /*num_per_page=*/1); + auto [scored_document_hits1, document_protos1] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2, + /*document_id=*/3, /*document_id=*/4, /*document_id=*/5}); + + ResultStateManager result_state_manager(/*max_total_hits=*/4, + document_store(), clock()); + ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(result_state1))); + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits1), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // Add a result state. Because state2 + state1 is larger than the budget, // state1 should be evicted. 
- ResultState result_state2 = - CreateResultState({AddScoredDocument(/*document_id=*/6), - AddScoredDocument(/*document_id=*/7)}, - /*num_per_page=*/1); + auto [scored_document_hits2, document_protos2] = + AddScoredDocuments({/*document_id=*/6, /*document_id=*/7}); ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state2, - result_state_manager.RankAndPaginate(std::move(result_state2))); + PageResultInfo page_result_info2, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits2), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/1), + document_store(), result_retriever())); // state1 should have been evicted and state2 should still be retrievable. - EXPECT_THAT( - result_state_manager.GetNextPage(page_result_state1.next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - - ICING_ASSERT_OK_AND_ASSIGN( - page_result_state2, - result_state_manager.GetNextPage(page_result_state2.next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(CreateScoredHit( - /*document_id=*/6)))); -} - -TEST_F(ResultStateManagerTest, ShouldGetSnippetContext) { - ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); - result_spec.mutable_snippet_spec()->set_num_to_snippet(5); - result_spec.mutable_snippet_spec()->set_num_matches_per_property(5); - result_spec.mutable_snippet_spec()->set_max_window_utf32_length(5); - - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); - - SectionRestrictQueryTermsMap query_terms_map; - query_terms_map.emplace("term1", std::unordered_set<std::string>()); - - ResultState original_result_state = ResultState( - /*scored_document_hits=*/{AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - query_terms_map, search_spec, CreateScoringSpec(), result_spec, - document_store()); - - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), - clock()); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - - ASSERT_THAT(page_result_state.next_page_token, Gt(kInvalidNextPageToken)); - - EXPECT_THAT(page_result_state.snippet_context.match_type, - Eq(TermMatchType::EXACT_ONLY)); - EXPECT_TRUE(page_result_state.snippet_context.query_terms.find("term1") != - page_result_state.snippet_context.query_terms.end()); - EXPECT_THAT(page_result_state.snippet_context.snippet_spec, - EqualsProto(result_spec.snippet_spec())); -} - -TEST_F(ResultStateManagerTest, ShouldGetDefaultSnippetContext) { - ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/1); - // 0 indicates no snippeting - result_spec.mutable_snippet_spec()->set_num_to_snippet(0); - result_spec.mutable_snippet_spec()->set_num_matches_per_property(0); - result_spec.mutable_snippet_spec()->set_max_window_utf32_length(0); - - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::EXACT_ONLY); - - SectionRestrictQueryTermsMap query_terms_map; - query_terms_map.emplace("term1", std::unordered_set<std::string>()); - - ResultState original_result_state = ResultState( - /*scored_document_hits=*/{AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1)}, - query_terms_map, search_spec, CreateScoringSpec(), result_spec, - 
document_store()); - - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), - clock()); - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - - ASSERT_THAT(page_result_state.next_page_token, Gt(kInvalidNextPageToken)); - - EXPECT_THAT(page_result_state.snippet_context.query_terms, IsEmpty()); - EXPECT_THAT( - page_result_state.snippet_context.snippet_spec, - EqualsProto(ResultSpecProto::SnippetSpecProto::default_instance())); - EXPECT_THAT(page_result_state.snippet_context.match_type, - Eq(TermMatchType::UNKNOWN)); -} - -TEST_F(ResultStateManagerTest, ShouldGetCorrectNumPreviouslyReturned) { - ResultState original_result_state = - CreateResultState({AddScoredDocument(/*document_id=*/0), - AddScoredDocument(/*document_id=*/1), - AddScoredDocument(/*document_id=*/2), - AddScoredDocument(/*document_id=*/3), - AddScoredDocument(/*document_id=*/4)}, - /*num_per_page=*/2); - - ResultStateManager result_state_manager( - /*max_total_hits=*/std::numeric_limits<int>::max(), document_store(), - clock()); - - // First page, 2 results - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - ASSERT_THAT(page_result_state1.scored_document_hits.size(), Eq(2)); - - // No previously returned results - EXPECT_THAT(page_result_state1.num_previously_returned, Eq(0)); - - uint64_t next_page_token = page_result_state1.next_page_token; - - // Second page, 2 results - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - ASSERT_THAT(page_result_state2.scored_document_hits.size(), Eq(2)); - - // num_previously_returned = size of first page - EXPECT_THAT(page_result_state2.num_previously_returned, Eq(2)); - - // Third page, 1 result - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - ASSERT_THAT(page_result_state3.scored_document_hits.size(), Eq(1)); - - // num_previously_returned = size of first and second pages - EXPECT_THAT(page_result_state3.num_previously_returned, Eq(4)); - - // No more results - EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), + EXPECT_THAT(result_state_manager.GetNextPage(page_result_info1.first, + result_retriever()), StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); -} -TEST_F(ResultStateManagerTest, ShouldStoreAllHits) { - ScoredDocumentHit scored_hit_1 = AddScoredDocument(/*document_id=*/0); - ScoredDocumentHit scored_hit_2 = AddScoredDocument(/*document_id=*/1); - ScoredDocumentHit scored_hit_3 = AddScoredDocument(/*document_id=*/2); - ScoredDocumentHit scored_hit_4 = AddScoredDocument(/*document_id=*/3); - ScoredDocumentHit scored_hit_5 = AddScoredDocument(/*document_id=*/4); + ICING_ASSERT_OK_AND_ASSIGN(page_result_info2, + result_state_manager.GetNextPage( + page_result_info2.first, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos2.at(1))); +} - ResultState original_result_state = CreateResultState( - {scored_hit_1, scored_hit_2, scored_hit_3, scored_hit_4, scored_hit_5}, - /*num_per_page=*/2); +TEST_F(ResultStateManagerTest, + AddingResultStateShouldNotTruncatedAfterFirstPage) { + // Add a result state that is larger than the entire budget, but within the + // entire budget 
after the first page. The entire result state will still be + // cached and not truncated. + auto [scored_document_hits, document_protos] = AddScoredDocuments( + {/*document_id=*/0, /*document_id=*/1, /*document_id=*/2, + /*document_id=*/3, /*document_id=*/4}); ResultStateManager result_state_manager(/*max_total_hits=*/4, document_store(), clock()); @@ -1127,33 +1371,46 @@ TEST_F(ResultStateManagerTest, ShouldStoreAllHits) { // The 5 input scored document hits will not be truncated. The first page of // two hits will be returned immediately and the other three hits will fit // within our caching budget. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/true), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(/*num_per_page=*/2), + document_store(), result_retriever())); // First page, 2 results - ICING_ASSERT_OK_AND_ASSIGN( - PageResultState page_result_state1, - result_state_manager.RankAndPaginate(std::move(original_result_state))); - EXPECT_THAT(page_result_state1.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_5), - EqualsScoredDocumentHit(scored_hit_4))); + ASSERT_THAT(page_result_info1.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info1.second.results.at(0).document(), + EqualsProto(document_protos.at(0))); + EXPECT_THAT(page_result_info1.second.results.at(1).document(), + EqualsProto(document_protos.at(1))); - uint64_t next_page_token = page_result_state1.next_page_token; + uint64_t next_page_token = page_result_info1.first; // Second page, 2 results. - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state2, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT(page_result_state2.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_3), - EqualsScoredDocumentHit(scored_hit_2))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info2, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + ASSERT_THAT(page_result_info2.second.results, SizeIs(2)); + EXPECT_THAT(page_result_info2.second.results.at(0).document(), + EqualsProto(document_protos.at(2))); + EXPECT_THAT(page_result_info2.second.results.at(1).document(), + EqualsProto(document_protos.at(3))); // Third page, 1 result. - ICING_ASSERT_OK_AND_ASSIGN(PageResultState page_result_state3, - result_state_manager.GetNextPage(next_page_token)); - EXPECT_THAT(page_result_state3.scored_document_hits, - ElementsAre(EqualsScoredDocumentHit(scored_hit_1))); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info3, + result_state_manager.GetNextPage(next_page_token, result_retriever())); + ASSERT_THAT(page_result_info3.second.results, SizeIs(1)); + EXPECT_THAT(page_result_info3.second.results.at(0).document(), + EqualsProto(document_protos.at(4))); // Fourth page, 0 results. 
- EXPECT_THAT(result_state_manager.GetNextPage(next_page_token), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT( + result_state_manager.GetNextPage(next_page_token, result_retriever()), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); } } // namespace diff --git a/icing/result/result-state-v2.cc b/icing/result/result-state-v2.cc index dde50e3..9cb3838 100644 --- a/icing/result/result-state-v2.cc +++ b/icing/result/result-state-v2.cc @@ -52,6 +52,8 @@ ResultStateV2::ResultStateV2( snippet_context_(CreateSnippetContext(std::move(query_terms), search_spec, result_spec)), num_per_page_(result_spec.num_per_page()), + num_total_bytes_per_page_threshold_( + result_spec.num_total_bytes_per_page_threshold()), num_total_hits_(nullptr) { for (const TypePropertyMask& type_field_mask : result_spec.type_property_masks()) { diff --git a/icing/result/result-state-v2.h b/icing/result/result-state-v2.h index fc56936..97ff4b6 100644 --- a/icing/result/result-state-v2.h +++ b/icing/result/result-state-v2.h @@ -78,6 +78,11 @@ class ResultStateV2 { return num_per_page_; } + int32_t num_total_bytes_per_page_threshold() const + ICING_SHARED_LOCKS_REQUIRED(mutex) { + return num_total_bytes_per_page_threshold_; + } + absl_ports::shared_mutex mutex; // When evaluating the next top K hits from scored_document_hits_ranker, some @@ -113,8 +118,16 @@ class ResultStateV2 { // Number of results to return in each page. int num_per_page_ ICING_GUARDED_BY(mutex); - // Pointer to a global counter to sum up the size of - // scored_document_hits_ranker in all ResultStates. + // The threshold of total bytes of all documents to cutoff, in order to limit + // # of bytes in a single page. + // Note that it doesn't guarantee the result # of bytes will be smaller, equal + // to, or larger than the threshold. Instead, it is just a threshold to + // cutoff, and only guarantees total bytes of search results won't exceed the + // threshold too much. + int32_t num_total_bytes_per_page_threshold_ ICING_GUARDED_BY(mutex); + + // Pointer to a global counter to sum up the size of scored_document_hits in + // all ResultStates. // Does not own. 
std::atomic<int>* num_total_hits_ ICING_GUARDED_BY(mutex); }; diff --git a/icing/result/result-state-v2_test.cc b/icing/result/result-state-v2_test.cc index 8e6b29a..360e03a 100644 --- a/icing/result/result-state-v2_test.cc +++ b/icing/result/result-state-v2_test.cc @@ -122,6 +122,49 @@ class ResultStateV2Test : public ::testing::Test { std::atomic<int> num_total_hits_; }; +TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToSpecs) { + ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); + result_spec.set_num_total_bytes_per_page_threshold(4096); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + EXPECT_THAT(result_state.num_returned, Eq(0)); + EXPECT_THAT(result_state.num_per_page(), Eq(result_spec.num_per_page())); + EXPECT_THAT(result_state.num_total_bytes_per_page_threshold(), + Eq(result_spec.num_total_bytes_per_page_threshold())); +} + +TEST_F(ResultStateV2Test, ShouldInitializeValuesAccordingToDefaultSpecs) { + ResultSpecProto default_result_spec = ResultSpecProto::default_instance(); + ASSERT_THAT(default_result_spec.num_per_page(), Eq(10)); + ASSERT_THAT(default_result_spec.num_total_bytes_per_page_threshold(), + Eq(std::numeric_limits<int32_t>::max())); + + ResultStateV2 result_state( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::vector<ScoredDocumentHit>(), + /*is_descending=*/true), + /*query_terms=*/{}, CreateSearchSpec(TermMatchType::EXACT_ONLY), + CreateScoringSpec(/*is_descending_order=*/true), default_result_spec, + document_store()); + + absl_ports::shared_lock l(&result_state.mutex); + + EXPECT_THAT(result_state.num_returned, Eq(0)); + EXPECT_THAT(result_state.num_per_page(), + Eq(default_result_spec.num_per_page())); + EXPECT_THAT(result_state.num_total_bytes_per_page_threshold(), + Eq(default_result_spec.num_total_bytes_per_page_threshold())); +} + TEST_F(ResultStateV2Test, ShouldReturnSnippetContextAccordingToSpecs) { ResultSpecProto result_spec = CreateResultSpec(/*num_per_page=*/2); result_spec.mutable_snippet_spec()->set_num_to_snippet(5); diff --git a/icing/result/snippet-retriever.cc b/icing/result/snippet-retriever.cc index bd1524e..2391900 100644 --- a/icing/result/snippet-retriever.cc +++ b/icing/result/snippet-retriever.cc @@ -80,6 +80,20 @@ inline std::string AddIndexToPath(int values_size, int index, // is applied based on the Token's type. 
std::string NormalizeToken(const Normalizer& normalizer, const Token& token) { switch (token.type) { + case Token::Type::RFC822_NAME: + [[fallthrough]]; + case Token::Type::RFC822_COMMENT: + [[fallthrough]]; + case Token::Type::RFC822_LOCAL_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_HOST: + [[fallthrough]]; + case Token::Type::RFC822_TOKEN: + [[fallthrough]]; case Token::Type::REGULAR: return normalizer.NormalizeTerm(token.text); case Token::Type::VERBATIM: @@ -126,6 +140,20 @@ CharacterIterator FindMatchEnd(const Normalizer& normalizer, const Token& token, [[fallthrough]]; case Token::Type::QUERY_PROPERTY: [[fallthrough]]; + case Token::Type::RFC822_NAME: + [[fallthrough]]; + case Token::Type::RFC822_COMMENT: + [[fallthrough]]; + case Token::Type::RFC822_LOCAL_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL: + [[fallthrough]]; + case Token::Type::RFC822_ADDRESS_COMPONENT_HOST: + [[fallthrough]]; + case Token::Type::RFC822_TOKEN: + [[fallthrough]]; case Token::Type::INVALID: ICING_LOG(WARNING) << "Unexpected Token type " << static_cast<int>(token.type) diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 4b426a9..28ee2ba 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -20,7 +20,6 @@ #include <unordered_set> #include <vector> -#include "icing/absl_ports/str_cat.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/store/corpus-associated-scoring-data.h" @@ -116,9 +115,8 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, score += idf_weight * normalized_tf; } - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "BM25F: corpus_id:%d docid:%d score:%f\n", data.corpus_id(), - hit_info.document_id(), score); + ICING_VLOG(1) << "BM25F: corpus_id:" << data.corpus_id() << " docid:" + << hit_info.document_id() << " score:" << score; return score; } @@ -144,8 +142,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, // First, figure out corpus scoring data. auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); if (!status_or.ok()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "No scoring data for corpus [%d]", corpus_id); + ICING_LOG(ERROR) << "No scoring data for corpus [" << corpus_id << "]"; return 0; } CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); @@ -155,9 +152,8 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, float idf = nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, - std::string(term).c_str(), num_docs, nqi, idf); + ICING_VLOG(1) << "corpus_id:" << corpus_id << " term:" + << term << " N:" << num_docs << "nqi:" << nqi << " idf:" << idf; return idf; } @@ -176,8 +172,7 @@ float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { // First, figure out corpus scoring data. 
auto status_or = document_store_->GetCorpusAssociatedScoreData(corpus_id); if (!status_or.ok()) { - ICING_LOG(ERROR) << IcingStringUtil::StringPrintf( - "No scoring data for corpus [%d]", corpus_id); + ICING_LOG(ERROR) << "No scoring data for corpus [" << corpus_id << "]"; return 0; } CorpusAssociatedScoreData csdata = status_or.ValueOrDie(); @@ -205,9 +200,9 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); - ICING_VLOG(1) << IcingStringUtil::StringPrintf( - "corpus_id:%d docid:%d dl:%d avgdl:%f f_q:%f norm_tf:%f\n", - data.corpus_id(), hit_info.document_id(), dl, avgdl, f_q, normalized_tf); + ICING_VLOG(1) << "corpus_id:" << data.corpus_id() << " docid:" + << hit_info.document_id() << " dl:" << dl << " avgdl:" << avgdl << " f_q:" + << f_q << " norm_tf:" << normalized_tf; return normalized_tf; } @@ -240,8 +235,8 @@ SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { // GetDocumentFilterData is if the document_id is outside of the range of // allocated document_ids, which shouldn't be possible since we're getting // this document_id from the posting lists. - ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( - "No document filter data for document [%d]", document_id); + ICING_LOG(WARNING) << "No document filter data for document [" + << document_id << "]"; return kInvalidSchemaTypeId; } return filter_data_optional.value().schema_type_id(); diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.cc b/icing/scoring/priority-queue-scored-document-hits-ranker.cc index 13da0ae..691b088 100644 --- a/icing/scoring/priority-queue-scored-document-hits-ranker.cc +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.cc @@ -23,11 +23,9 @@ namespace icing { namespace lib { PriorityQueueScoredDocumentHitsRanker::PriorityQueueScoredDocumentHitsRanker( - const std::vector<ScoredDocumentHit>& scored_document_hits, - bool is_descending) + std::vector<ScoredDocumentHit>&& scored_document_hits, bool is_descending) : comparator_(/*is_ascending=*/!is_descending), - scored_document_hits_pq_(scored_document_hits.begin(), - scored_document_hits.end(), comparator_) {} + scored_document_hits_pq_(comparator_, std::move(scored_document_hits)) {} ScoredDocumentHit PriorityQueueScoredDocumentHitsRanker::PopNext() { ScoredDocumentHit ret = scored_document_hits_pq_.top(); diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.h b/icing/scoring/priority-queue-scored-document-hits-ranker.h index c104585..e0ae4b0 100644 --- a/icing/scoring/priority-queue-scored-document-hits-ranker.h +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.h @@ -29,7 +29,7 @@ namespace lib { class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { public: explicit PriorityQueueScoredDocumentHitsRanker( - const std::vector<ScoredDocumentHit>& scored_document_hits, + std::vector<ScoredDocumentHit>&& scored_document_hits, bool is_descending = true); ~PriorityQueueScoredDocumentHitsRanker() override = default; diff --git a/icing/store/document-store.cc b/icing/store/document-store.cc index aa3122b..8a79b6d 100644 --- a/icing/store/document-store.cc +++ b/icing/store/document-store.cc @@ -1489,8 +1489,11 @@ libtextclassifier3::Status DocumentStore::UpdateSchemaStore( // Update the SchemaTypeId for this entry ICING_ASSIGN_OR_RETURN(SchemaTypeId schema_type_id, schema_store_->GetSchemaTypeId(document.schema())); - 
filter_cache_->mutable_array()[document_id].set_schema_type_id( - schema_type_id); + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<DocumentFilterData>::MutableView + doc_filter_data_view, + filter_cache_->GetMutable(document_id)); + doc_filter_data_view.Get().set_schema_type_id(schema_type_id); } else { // Document is no longer valid with the new SchemaStore. Mark as // deleted @@ -1550,8 +1553,11 @@ libtextclassifier3::Status DocumentStore::OptimizedUpdateSchemaStore( ICING_ASSIGN_OR_RETURN( SchemaTypeId schema_type_id, schema_store_->GetSchemaTypeId(document.schema())); - filter_cache_->mutable_array()[document_id].set_schema_type_id( - schema_type_id); + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<DocumentFilterData>::MutableView + doc_filter_data_view, + filter_cache_->GetMutable(document_id)); + doc_filter_data_view.Get().set_schema_type_id(schema_type_id); } if (revalidate_document) { delete_document = !document_validator_.Validate(document).ok(); @@ -1576,9 +1582,10 @@ libtextclassifier3::Status DocumentStore::Optimize() { return libtextclassifier3::Status::OK; } -libtextclassifier3::Status DocumentStore::OptimizeInto( - const std::string& new_directory, const LanguageSegmenter* lang_segmenter, - OptimizeStatsProto* stats) { +libtextclassifier3::StatusOr<std::vector<DocumentId>> +DocumentStore::OptimizeInto(const std::string& new_directory, + const LanguageSegmenter* lang_segmenter, + OptimizeStatsProto* stats) { // Validates directory if (new_directory == base_dir_) { return absl_ports::InvalidArgumentError( @@ -1596,6 +1603,7 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( int num_deleted = 0; int num_expired = 0; UsageStore::UsageScores default_usage; + std::vector<DocumentId> document_id_old_to_new(size, kInvalidDocumentId); for (DocumentId document_id = 0; document_id < size; document_id++) { auto document_or = Get(document_id, /*clear_internal_fields=*/false); if (absl_ports::IsNotFound(document_or.status())) { @@ -1641,6 +1649,8 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( return new_document_id_or.status(); } + document_id_old_to_new[document_id] = new_document_id_or.ValueOrDie(); + // Copy over usage scores. ICING_ASSIGN_OR_RETURN(UsageStore::UsageScores usage_scores, usage_store_->GetUsageScores(document_id)); @@ -1659,7 +1669,7 @@ libtextclassifier3::Status DocumentStore::OptimizeInto( stats->set_num_expired_documents(num_expired); } ICING_RETURN_IF_ERROR(new_doc_store->PersistToDisk(PersistType::FULL)); - return libtextclassifier3::Status::OK; + return document_id_old_to_new; } libtextclassifier3::StatusOr<DocumentStore::OptimizeInfo> diff --git a/icing/store/document-store.h b/icing/store/document-store.h index 450b1b9..41dd6a9 100644 --- a/icing/store/document-store.h +++ b/icing/store/document-store.h @@ -388,10 +388,10 @@ class DocumentStore { // method based on device usage. 
// // Returns: - // OK on success + // A vector that maps from old document id to new document id on success // INVALID_ARGUMENT if new_directory is same as current base directory // INTERNAL_ERROR on IO error - libtextclassifier3::Status OptimizeInto( + libtextclassifier3::StatusOr<std::vector<DocumentId>> OptimizeInto( const std::string& new_directory, const LanguageSegmenter* lang_segmenter, OptimizeStatsProto* stats = nullptr); diff --git a/icing/store/document-store_test.cc b/icing/store/document-store_test.cc index 59e5d74..6f444cb 100644 --- a/icing/store/document-store_test.cc +++ b/icing/store/document-store_test.cc @@ -59,6 +59,7 @@ namespace { using ::icing::lib::portable_equals_proto::EqualsProto; using ::testing::_; +using ::testing::ElementsAre; using ::testing::Eq; using ::testing::Ge; using ::testing::Gt; @@ -1058,8 +1059,8 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // deleted ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(0, 1, 2))); int64_t optimized_size1 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_EQ(original_size, optimized_size1); @@ -1069,8 +1070,9 @@ TEST_F(DocumentStoreTest, OptimizeInto) { ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); ICING_ASSERT_OK(doc_store->Delete("namespace", "uri1")); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + // DocumentId 0 is removed. + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, 0, 1))); int64_t optimized_size2 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(original_size, Gt(optimized_size2)); @@ -1083,11 +1085,39 @@ TEST_F(DocumentStoreTest, OptimizeInto) { // expired ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); - ICING_ASSERT_OK( - doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get())); + // DocumentId 0 is removed, and DocumentId 2 is expired. + EXPECT_THAT( + doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, 0, kInvalidDocumentId))); int64_t optimized_size3 = filesystem_.GetFileSize(optimized_document_log.c_str()); EXPECT_THAT(optimized_size2, Gt(optimized_size3)); + + // Delete the last document + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); + ICING_ASSERT_OK(doc_store->Delete("namespace", "uri2")); + // DocumentId 0 and 1 is removed, and DocumentId 2 is expired. 
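OptimizeInto now reports, for every old DocumentId, either its position in the optimized store or kInvalidDocumentId when the document was deleted or expired, as the expectations above spell out. A caller-side sketch of consuming that mapping while rewriting hits during index compaction (RemapOrDrop is an illustrative helper, not an Icing API, and kInvalidDocumentId is assumed here to be a -1 sentinel):

#include <cstdint>
#include <vector>

using DocumentId = int32_t;
constexpr DocumentId kInvalidDocumentId = -1;  // Sentinel assumed for this sketch.

// Maps an old document id through the vector returned by OptimizeInto. Hits
// that point at deleted or expired documents come back invalid and can simply
// be skipped when copying posting lists into the compacted index.
DocumentId RemapOrDrop(DocumentId old_id,
                       const std::vector<DocumentId>& old_to_new) {
  if (old_id < 0 || old_id >= static_cast<DocumentId>(old_to_new.size())) {
    return kInvalidDocumentId;
  }
  return old_to_new[old_id];
}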
+ EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(ElementsAre(kInvalidDocumentId, kInvalidDocumentId, + kInvalidDocumentId))); + int64_t optimized_size4 = + filesystem_.GetFileSize(optimized_document_log.c_str()); + EXPECT_THAT(optimized_size3, Gt(optimized_size4)); +} + +TEST_F(DocumentStoreTest, OptimizeIntoForEmptyDocumentStore) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, document_store_dir_, &fake_clock_, + schema_store_.get())); + std::unique_ptr<DocumentStore> doc_store = + std::move(create_result.document_store); + std::string optimized_dir = document_store_dir_ + "_optimize"; + ASSERT_TRUE(filesystem_.DeleteDirectoryRecursively(optimized_dir.c_str())); + ASSERT_TRUE(filesystem_.CreateDirectoryRecursively(optimized_dir.c_str())); + EXPECT_THAT(doc_store->OptimizeInto(optimized_dir, lang_segmenter_.get()), + IsOkAndHolds(IsEmpty())); } TEST_F(DocumentStoreTest, ShouldRecoverFromDataLoss) { diff --git a/icing/tokenization/icu/icu-language-segmenter_test.cc b/icing/tokenization/icu/icu-language-segmenter_test.cc index 4098be5..71e04e2 100644 --- a/icing/tokenization/icu/icu-language-segmenter_test.cc +++ b/icing/tokenization/icu/icu-language-segmenter_test.cc @@ -15,12 +15,12 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" +#include "icing/jni/jni-cache.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-data-file-helper.h" #include "icing/testing/icu-i18n-test-utils.h" diff --git a/icing/tokenization/language-segmenter-factory.h b/icing/tokenization/language-segmenter-factory.h index cae3eee..2505a07 100644 --- a/icing/tokenization/language-segmenter-factory.h +++ b/icing/tokenization/language-segmenter-factory.h @@ -18,9 +18,8 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" - #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter.h" namespace icing { diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc index 8e1e563..dbd7f5a 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc @@ -21,11 +21,11 @@ #include <cmath> #include <map> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/text_classifier/lib3/utils/java/jni-base.h" #include "icing/text_classifier/lib3/utils/java/jni-helper.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/jni/jni-cache.h" #include "icing/util/status-macros.h" namespace icing { diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h index 41b470c..537666c 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.h @@ -20,8 +20,8 @@ #include <queue> #include <string> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/java/jni-base.h" +#include "icing/jni/jni-cache.h" namespace icing { namespace lib { diff --git 
a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc index 0da4c2d..a251f90 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter-factory.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "icing/jni/jni-cache.h" #include "icing/absl_ports/canonical_errors.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter-factory.h" #include "icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h" #include "icing/util/logging.h" diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h index f06dac9..29df4ee 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.h @@ -21,8 +21,8 @@ #include <string_view> #include <vector> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/jni/jni-cache.h" #include "icing/tokenization/language-segmenter.h" namespace icing { diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc index 8b13cd1..47a01fe 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc @@ -17,11 +17,11 @@ #include <memory> #include <string_view> -#include "icing/jni/jni-cache.h" #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "gmock/gmock.h" #include "icing/absl_ports/str_cat.h" +#include "icing/jni/jni-cache.h" #include "icing/testing/common-matchers.h" #include "icing/testing/icu-i18n-test-utils.h" #include "icing/testing/jni-test-helpers.h" diff --git a/icing/tokenization/rfc822-tokenizer.cc b/icing/tokenization/rfc822-tokenizer.cc new file mode 100644 index 0000000..4a96783 --- /dev/null +++ b/icing/tokenization/rfc822-tokenizer.cc @@ -0,0 +1,565 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/tokenization/rfc822-tokenizer.h" + +#include <algorithm> +#include <deque> +#include <queue> +#include <string_view> +#include <utility> + +#include "icing/tokenization/token.h" +#include "icing/tokenization/tokenizer.h" +#include "icing/util/character-iterator.h" +#include "icing/util/i18n-utils.h" +#include "icing/util/status-macros.h" +#include "unicode/umachine.h" + +namespace icing { +namespace lib { + +class Rfc822TokenIterator : public Tokenizer::Iterator { + public: + // Cursor is the index into the string_view, text_end_ is the length. 
+ explicit Rfc822TokenIterator(std::string_view text) + : term_(std::move(text)), + iterator_(text, 0, 0, 0), + text_end_(text.length()) {} + + struct NameInfo { + NameInfo(const char* at_sign, bool name_found) + : at_sign(at_sign), name_found(name_found) {} + const char* at_sign; + bool name_found; + }; + + bool Advance() override { + // Advance through the queue. + if (!token_queue_.empty()) { + token_queue_.pop_front(); + } + + // There is still something left. + if (!token_queue_.empty()) { + return true; + } + + // Done with the entire string_view + if (iterator_.utf8_index() >= text_end_) { + return false; + } + + AdvancePastWhitespace(); + + GetNextRfc822Token(); + + return true; + } + + // Advance until the next email delimiter, generating as many tokens as + // necessary. + void GetNextRfc822Token() { + int token_start = iterator_.utf8_index(); + const char* at_sign_in_name = nullptr; + bool address_found = false; + bool name_found = false; + // We start at unquoted and run until a ",;\n<( . + while (iterator_.utf8_index() < text_end_) { + UChar32 c = iterator_.GetCurrentChar(); + if (c == ',' || c == ';' || c == '\n') { + // End of the token, advance cursor past this then quit + token_queue_.push_back(Token( + Token::Type::RFC822_TOKEN, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + AdvanceCursor(); + break; + } + + if (c == '"') { + NameInfo quoted_result = ConsumeQuotedSection(); + if (quoted_result.at_sign != nullptr) { + at_sign_in_name = quoted_result.at_sign; + } + if (!name_found) { + name_found = quoted_result.name_found; + } + } else if (c == '(') { + ConsumeParenthesizedSection(); + } else if (c == '<') { + // Only set address_found to true if ConsumeAdress returns true. + // Otherwise, keep address_found as is to prevent setting address_found + // back to false if it is true + if (ConsumeAddress()) { + address_found = true; + } + } else { + NameInfo unquoted_result = ConsumeUnquotedSection(); + if (unquoted_result.at_sign != nullptr) { + at_sign_in_name = unquoted_result.at_sign; + } + if (!name_found) { + name_found = unquoted_result.name_found; + } + } + } + if (iterator_.utf8_index() >= text_end_) { + token_queue_.push_back( + Token(Token::Type::RFC822_TOKEN, + term_.substr(token_start, text_end_ - token_start))); + } + + // At this point the token_queue is not empty. + // If an address is found, use the tokens we have + // If an address isn't found, and a name isn't found, also use the tokens + // we have. + // If an address isn't found but a name is, convert name Tokens to email + // Tokens + if (!address_found && name_found) { + ConvertNameToEmail(at_sign_in_name); + } + } + + void ConvertNameToEmail(const char* at_sign_in_name) { + // The name tokens will be will be used as the address now + const char* address_start = nullptr; + const char* local_address_end = nullptr; + const char* address_end = term_.begin(); + + // If we need to transform name tokens into various tokens, we keep the + // order of which the name tokens appeared. Name tokens that appear before + // an @ sign in the name will become RFC822_ADDRESS_COMPONENT_LOCAL, and + // those after will become RFC822_ADDRESS_COMPONENT_HOST. We aren't able + // to determine RFC822_ADDRESS and RFC822_LOCAL_ADDRESS before checking + // the name tokens, so they will be added after the component tokens. 
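A concrete example of this conversion, which the NoBrackets test further below verifies: for the input "alex@google.com" with no bracketed address, the RFC822_NAME tokens "alex", "google" and "com" are rewritten in place and the address tokens are appended afterwards, so the emitted stream is

  RFC822_ADDRESS_COMPONENT_LOCAL  "alex"
  RFC822_ADDRESS_COMPONENT_HOST   "google"
  RFC822_ADDRESS_COMPONENT_HOST   "com"
  RFC822_TOKEN                    "alex@google.com"
  RFC822_ADDRESS                  "alex@google.com"
  RFC822_LOCAL_ADDRESS            "alex"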
+ + for (Token& token : token_queue_) { + if (token.type == Token::Type::RFC822_NAME) { + // Names need to be converted to address tokens + std::string_view text = token.text; + + // Find the ADDRESS and LOCAL_ADDRESS. + if (address_start == nullptr) { + address_start = text.begin(); + } + + if (at_sign_in_name >= text.end()) { + local_address_end = text.end(); + } + + address_end = text.end(); + + if (text.begin() < at_sign_in_name) { + token = Token(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, text); + } else if (text.begin() > at_sign_in_name) { + token = Token(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, text); + } + } + } + + token_queue_.push_back( + Token(Token::Type::RFC822_ADDRESS, + std::string_view(address_start, address_end - address_start))); + + if (local_address_end != nullptr) { + token_queue_.push_back(Token( + Token::Type::RFC822_LOCAL_ADDRESS, + std::string_view(address_start, local_address_end - address_start))); + } + } + + // Returns the location of the last at sign in the unquoted section, and if + // we have found a name. This is useful in case we do not find an address + // and have to use the name. An unquoted section may look like "Alex Sav", or + // "alex@google.com". In the absense of a bracketed email address, the + // unquoted section will be used as the email address along with the quoted + // section. + NameInfo ConsumeUnquotedSection() { + const char* at_sign_location = nullptr; + UChar32 c; + + int token_start = -1; + bool name_found = false; + + // Advance to another state or a character marking the end of token, one + // of \n,; . + while (iterator_.utf8_index() < text_end_) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + name_found = true; + + if (token_start == -1) { + // Start recording + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + if (token_start != -1) { + if (c == '@') { + // Mark the last @ sign. + at_sign_location = term_.data() + iterator_.utf8_index(); + } + + // The character is non alphabetic, save a token. + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '"' || c == '<' || c == '(' || c == '\n' || c == ';' || + c == ',') { + // Stay on the token. + break; + } + + AdvanceCursor(); + } + } + if (token_start != -1) { + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + return NameInfo(at_sign_location, name_found); + } + + // Names that are within quotes should have all characters blindly unescaped. + // When a name is made into an address, it isn't re-escaped. + + // Returns the location of the last at sign in the quoted section. This is + // useful in case we do not find an address and have to use the name. The + // quoted section may contain whitespaces + NameInfo ConsumeQuotedSection() { + // Get past the first quote. + AdvanceCursor(); + const char* at_sign_location = nullptr; + + bool end_quote_found = false; + bool name_found = false; + UChar32 c; + + int token_start = -1; + + while (!end_quote_found && (iterator_.utf8_index() < text_end_)) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + name_found = true; + + if (token_start == -1) { + // Start tracking the token. + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + // Non- alphabetic + if (c == '\\') { + // A backslash, let's look at the next character. 
+ CharacterIterator temp = iterator_; + temp.AdvanceToUtf32(iterator_.utf32_index() + 1); + UChar32 n = temp.GetCurrentChar(); + if (i18n_utils::IsAlphaNumeric(n)) { + // The next character is alphabetic, skip the slash and don't end + // the last token. For quoted sections, the only things that are + // escaped are double quotes and slashes. For example, in "a\lex", + // an l appears after the slash. We want to treat this as if it was + // just "alex". So we tokenize it as <RFC822_NAME, "a\lex">. + AdvanceCursor(); + } else { + // Not alphabetic, so save the last token if necessary. + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_NAME, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + // Skip the backslash. + AdvanceCursor(); + + if (n == '"' || n == '\\' || n == '@') { + // Skip these too if they're next. + AdvanceCursor(); + } + } + + } else { + // Not a backslash. + + if (c == '@') { + // Mark the last @ sign. + at_sign_location = term_.data() + iterator_.utf8_index(); + } + + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_NAME, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '"') { + end_quote_found = true; + } + // Advance one more time to get past the non-alphabetic character. + AdvanceCursor(); + } + } + } + if (token_start != -1) { + token_queue_.push_back(Token( + Token::Type::RFC822_NAME, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + return NameInfo(at_sign_location, name_found); + } + + // '(', ')', '\\' chars should be escaped. All other escaped chars should be + // unescaped. + void ConsumeParenthesizedSection() { + // Skip the initial ( + AdvanceCursor(); + + int paren_layer = 1; + UChar32 c; + + int token_start = -1; + + while (paren_layer > 0 && (iterator_.utf8_index() < text_end_)) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + if (token_start == -1) { + // Start tracking a token. + token_start = iterator_.utf8_index(); + } + AdvanceCursor(); + + } else { + // Non alphabetic. + if (c == '\\') { + // A backslash, let's look at the next character. + UChar32 n = i18n_utils::GetUChar32At(term_.data(), term_.length(), + iterator_.utf8_index() + 1); + if (i18n_utils::IsAlphaNumeric(n)) { + // Alphabetic, skip the slash and don't end the last token. + AdvanceCursor(); + } else { + // Not alphabetic, save the last token if necessary. + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_COMMENT, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + // Skip the backslash. + AdvanceCursor(); + + if (n == ')' || n == '(' || n == '\\') { + // Skip these too if they're next. + AdvanceCursor(); + } + } + } else { + // Not a backslash. + if (token_start != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_COMMENT, + term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + + if (c == '(') { + paren_layer++; + } else if (c == ')') { + paren_layer--; + } + AdvanceCursor(); + } + } + } + + if (token_start != -1) { + // Ran past the end of term_ without getting the last token. + + // substr returns "a view of the substring [pos, pos + // rcount), where + // rcount is the smaller of count and size() - pos" therefore the count + // argument can be any value >= this->cursor - token_start. Therefore, + // ignoring the mutation warning. 
+ token_queue_.push_back(Token( + Token::Type::RFC822_COMMENT, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + } + + // Returns true if we find an address. + bool ConsumeAddress() { + // Skip the first <. + AdvanceCursor(); + + // Save the start position. + CharacterIterator address_start_iterator = iterator_; + + int at_sign = -1; + int address_end = -1; + + UChar32 c = iterator_.GetCurrentChar(); + // Quick scan for @ and > signs. + while (c != '>' && iterator_.utf8_index() < text_end_) { + AdvanceCursor(); + c = iterator_.GetCurrentChar(); + if (c == '@') { + at_sign = iterator_.utf8_index(); + } + } + + if (iterator_.utf8_index() <= address_start_iterator.utf8_index()) { + // There is nothing between the brackets, either we have "<" or "<>" + return false; + } + + // Either we find a > or run to the end, either way this is the end of the + // address. The ending bracket will be handled by ConsumeUnquoted. + address_end = iterator_.utf8_index(); + + // Reset to the start. + iterator_ = address_start_iterator; + + int address_start = address_start_iterator.utf8_index(); + + Token::Type type = Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL; + + // Create a local address token. + if (at_sign != -1) { + token_queue_.push_back( + Token(Token::Type::RFC822_LOCAL_ADDRESS, + term_.substr(address_start, at_sign - address_start))); + } else { + // All the tokens in the address are host components. + type = Token::Type::RFC822_ADDRESS_COMPONENT_HOST; + } + + token_queue_.push_back( + Token(Token::Type::RFC822_ADDRESS, + term_.substr(address_start, address_end - address_start))); + + int token_start = -1; + + while (iterator_.utf8_index() < address_end) { + c = iterator_.GetCurrentChar(); + + if (i18n_utils::IsAlphaNumeric(c)) { + if (token_start == -1) { + token_start = iterator_.utf8_index(); + } + + } else { + // non alphabetic + if (c == '\\') { + // A backslash, let's look at the next character. + CharacterIterator temp = iterator_; + temp.AdvanceToUtf32(iterator_.utf32_index() + 1); + UChar32 n = temp.GetCurrentChar(); + if (!i18n_utils::IsAlphaNumeric(n)) { + // Not alphabetic, end the last token if necessary. + if (token_start != -1) { + token_queue_.push_back(Token( + type, term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + } + } else { + // Not backslash. + if (token_start != -1) { + token_queue_.push_back(Token( + type, term_.substr(token_start, + iterator_.utf8_index() - token_start))); + token_start = -1; + } + // Switch to host component tokens. + if (iterator_.utf8_index() == at_sign) { + type = Token::Type::RFC822_ADDRESS_COMPONENT_HOST; + } + } + } + AdvanceCursor(); + } + if (token_start != -1) { + token_queue_.push_back(Token( + type, + term_.substr(token_start, iterator_.utf8_index() - token_start))); + } + // Unquoted will handle the closing bracket > if these is one. + return true; + } + + Token GetToken() const override { + if (token_queue_.empty()) { + return Token(Token::Type::INVALID, term_); + } + return token_queue_.front(); + } + + private: + void AdvanceCursor() { + iterator_.AdvanceToUtf32(iterator_.utf32_index() + 1); + } + + void AdvancePastWhitespace() { + while (i18n_utils::IsWhitespaceAt(term_, iterator_.utf8_index())) { + AdvanceCursor(); + } + } + + std::string_view term_; + CharacterIterator iterator_; + int text_end_; + + // A temporary store of Tokens. As we advance through the provided string, we + // parse entire addresses at a time rather than one token at a time. 
However, + // since we call the tokenizer with Advance() alternating with GetToken(), we + // need to store tokens for subsequent GetToken calls if Advance generates + // multiple tokens (it usually does). A queue is used as we want the first + // token generated to be the first token returned from GetToken. + std::deque<Token> token_queue_; +}; + +libtextclassifier3::StatusOr<std::unique_ptr<Tokenizer::Iterator>> +Rfc822Tokenizer::Tokenize(std::string_view text) const { + return std::make_unique<Rfc822TokenIterator>(text); +} + +libtextclassifier3::StatusOr<std::vector<Token>> Rfc822Tokenizer::TokenizeAll( + std::string_view text) const { + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> iterator, + Tokenize(text)); + std::vector<Token> tokens; + while (iterator->Advance()) { + tokens.push_back(iterator->GetToken()); + } + return tokens; +} + +} // namespace lib +} // namespace icing diff --git a/icing/absl_ports/status_imports.h b/icing/tokenization/rfc822-tokenizer.h index 3a97fd6..09e4624 100644 --- a/icing/absl_ports/status_imports.h +++ b/icing/tokenization/rfc822-tokenizer.h @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Google LLC +// Copyright (C) 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef ICING_ABSL_PORTS_STATUS_IMPORTS_H_ -#define ICING_ABSL_PORTS_STATUS_IMPORTS_H_ +#ifndef ICING_TOKENIZATION_RFC822_TOKENIZER_H_ +#define ICING_TOKENIZATION_RFC822_TOKENIZER_H_ -#include "icing/text_classifier/lib3/utils/base/status.h" +#include <vector> + +#include "icing/tokenization/tokenizer.h" namespace icing { namespace lib { -namespace absl_ports { -// TODO(b/144458732) Delete this file once visibility on TC3 Status has been -// granted to the sample app. -using Status = libtextclassifier3::Status; +class Rfc822Tokenizer : public Tokenizer { + public: + libtextclassifier3::StatusOr<std::unique_ptr<Tokenizer::Iterator>> Tokenize( + std::string_view text) const override; + + libtextclassifier3::StatusOr<std::vector<Token>> TokenizeAll( + std::string_view text) const override; + +}; -} // namespace absl_ports } // namespace lib } // namespace icing -#endif // ICING_ABSL_PORTS_STATUS_IMPORTS_H_ +#endif // ICING_TOKENIZATION_RFC822_TOKENIZER_H_ diff --git a/icing/tokenization/rfc822-tokenizer_test.cc b/icing/tokenization/rfc822-tokenizer_test.cc new file mode 100644 index 0000000..e3c6da6 --- /dev/null +++ b/icing/tokenization/rfc822-tokenizer_test.cc @@ -0,0 +1,797 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
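The tests that follow exercise the Rfc822Tokenizer interface declared above. A minimal usage sketch of that interface first (VisitRfc822Tokens is an illustrative wrapper, not part of the change; errors are simply swallowed here):

#include <memory>
#include <string_view>

#include "icing/tokenization/rfc822-tokenizer.h"
#include "icing/tokenization/token.h"
#include "icing/tokenization/tokenizer.h"

namespace icing {
namespace lib {

// Walks every RFC822 token in `text` with the streaming interface; the batch
// TokenizeAll() call at the end collects the same tokens in a single vector.
void VisitRfc822Tokens(std::string_view text) {
  Rfc822Tokenizer tokenizer;
  auto iterator_or = tokenizer.Tokenize(text);
  if (!iterator_or.ok()) {
    return;
  }
  std::unique_ptr<Tokenizer::Iterator> it = std::move(iterator_or.ValueOrDie());
  while (it->Advance()) {
    Token token = it->GetToken();
    // token.type is one of the RFC822_* types, token.text the matched span.
    (void)token;
  }
  auto all_tokens_or = tokenizer.TokenizeAll(text);  // Batch form.
  (void)all_tokens_or;
}

}  // namespace lib
}  // namespace icing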
+ +#include "icing/tokenization/rfc822-tokenizer.h" + +#include <memory> +#include <string> +#include <string_view> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { +namespace { +using ::testing::ElementsAre; + +class Rfc822TokenizerTest : public testing::Test { + protected: + void SetUp() override { + jni_cache_ = GetTestJniCache(); + language_segmenter_factory::SegmenterOptions options(ULOC_US, + jni_cache_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + } + std::unique_ptr<const JniCache> jni_cache_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; +}; + +TEST_F(Rfc822TokenizerTest, Simple) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("<你alex@google.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "你alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<你alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, Small) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("\"a\""); + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"a\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "a")))); + + s = "\"a\", \"b\""; + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"a\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "a"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "b"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"b\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "b")))); + + s = "(a)"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre(EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, "(a)")))); +} + +TEST_F(Rfc822TokenizerTest, PB) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("peanut (comment) butter, <alex@google.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_ADDRESS, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, NoBrackets) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); 
+ + std::string_view s("alex@google.com"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex")))); +} + +TEST_F(Rfc822TokenizerTest, TwoAddresses) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("<你alex@google.com>; <alexsav@gmail.com>"); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "你alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "你alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<你alex@google.com>"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alexsav@gmail.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gmail"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alexsav@gmail.com>")))); +} + +TEST_F(Rfc822TokenizerTest, CommentB) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("(a comment) <alex@google.com>"); + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + "(a comment) <alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, NameAndComment) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string_view s("\"a name\" also a name <alex@google.com>"); + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_NAME, "also"), + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + "\"a name\" also a name <alex@google.com>")))); +} + +// Test from tokenizer_test.cc. 
+TEST_F(Rfc822TokenizerTest, Rfc822SanityCheck) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string addr1("A name (A comment) <address@domain.com>"); + std::string addr2( + "\"(Another name)\" (A different comment) " + "<bob-loblaw@foo.bar.com>"); + std::string addr3("<no.at.sign.present>"); + std::string addr4("<double@at@signs.present>"); + std::string rfc822 = addr1 + ", " + addr2 + ", " + addr3 + ", " + addr4; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(rfc822), + IsOkAndHolds(ElementsAre( + + EqualsToken(Token::Type::RFC822_NAME, "A"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "A"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address@domain.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "domain"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, addr1), + + EqualsToken(Token::Type::RFC822_NAME, "Another"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "A"), + EqualsToken(Token::Type::RFC822_COMMENT, "different"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "bob-loblaw"), + EqualsToken(Token::Type::RFC822_ADDRESS, "bob-loblaw@foo.bar.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "bob"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "loblaw"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, addr2), + + EqualsToken(Token::Type::RFC822_ADDRESS, "no.at.sign.present"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "no"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "at"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "sign"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "present"), + EqualsToken(Token::Type::RFC822_TOKEN, addr3), + + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "double@at"), + EqualsToken(Token::Type::RFC822_ADDRESS, "double@at@signs.present"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "double"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "at"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "signs"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "present"), + EqualsToken(Token::Type::RFC822_TOKEN, addr4)))); +} + +// Tests from rfc822 converter. 
+TEST_F(Rfc822TokenizerTest, SimpleRfcText) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string test_string = + "foo@google.com,bar@google.com,baz@google.com,foo+hello@google.com,baz@" + "corp.google.com"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(test_string), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "foo@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "bar@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "bar@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "bar"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "baz@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "baz@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "hello"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "foo+hello@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo+hello@google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo+hello"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "corp"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "baz@corp.google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS, "baz@corp.google.com"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "baz")))); +} + +TEST_F(Rfc822TokenizerTest, ComplicatedRfcText) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string test_string = + R"raw("Weird, But&(Also)\\Valid" Name (!With, "an" \\odd\\ cmt too¡) <Foo B(a)r,Baz@g.co> + <easy@google.com>)raw"; + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(test_string), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Weird"), + EqualsToken(Token::Type::RFC822_NAME, "But"), + EqualsToken(Token::Type::RFC822_NAME, "Also"), + EqualsToken(Token::Type::RFC822_NAME, "Valid"), + EqualsToken(Token::Type::RFC822_NAME, "Name"), + EqualsToken(Token::Type::RFC822_COMMENT, "With"), + EqualsToken(Token::Type::RFC822_COMMENT, "an"), + EqualsToken(Token::Type::RFC822_COMMENT, "odd"), + EqualsToken(Token::Type::RFC822_COMMENT, "cmt"), + EqualsToken(Token::Type::RFC822_COMMENT, "too"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "Foo B(a)r,Baz"), + EqualsToken(Token::Type::RFC822_ADDRESS, "Foo B(a)r,Baz@g.co"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "Foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "B"), + 
EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "a"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "r"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "Baz"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "g"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "co"), + EqualsToken( + Token::Type::RFC822_TOKEN, + R"raw("Weird, But&(Also)\\Valid" Name (!With, "an" \\odd\\ cmt too¡) <Foo B(a)r,Baz@g.co>)raw"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "easy"), + EqualsToken(Token::Type::RFC822_ADDRESS, "easy@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "easy"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<easy@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, FromHtmlBugs) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // This input used to cause an HTML parsing exception. We don't do HTML + // parsing anymore (b/8388100), so we are just checking that it does not + // crash and that it retains the input. + + // http://b/8988210. Puts the crashing string "&\r" x 100 into the name and + // comment fields of the rfc822 token. + + std::string s("\""); + for (int i = 0; i < 100; i++) { + s.append("&\r"); + } + s.append("\" ("); + for (int i = 0; i < 100; i++) { + s.append("&\r"); + } + s.append(") <foo@google.com>"); + + // The added noise shouldn't change the tokenizer output. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, s)))); +} + +TEST_F(Rfc822TokenizerTest, EmptyComponentsTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(""), + IsOkAndHolds(testing::IsEmpty())); + + // Name is considered the address if address is empty. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name<>"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "name<>"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + // Empty name and address means that there is no token. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("(a long comment with nothing else)"), + IsOkAndHolds( + ElementsAre(EqualsToken(Token::Type::RFC822_COMMENT, "a"), + EqualsToken(Token::Type::RFC822_COMMENT, "long"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_COMMENT, "with"), + EqualsToken(Token::Type::RFC822_COMMENT, "nothing"), + EqualsToken(Token::Type::RFC822_COMMENT, "else"), + EqualsToken(Token::Type::RFC822_TOKEN, + "(a long comment with nothing else)")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name ()"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "name ()"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(R"((comment) "")"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "(comment) \"\"")))); +} + +TEST_F(Rfc822TokenizerTest, NameTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + // Name spread between address or comment. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("peanut <address> butter"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_NAME, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut <address> butter")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("peanut (comment) butter"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, "peanut (comment) butter"), + EqualsToken(Token::Type::RFC822_ADDRESS, + "peanut (comment) butter")))); + + // Dropping quotes when they're not needed. + std::string s = R"(peanut <address> "butter")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_NAME, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, s)))); + + s = R"(peanut "butter")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(s), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "peanut"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "butter"), + EqualsToken(Token::Type::RFC822_TOKEN, s), + EqualsToken(Token::Type::RFC822_ADDRESS, "peanut \"butter")))); + // Adding quotes when they are needed. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("ple@se quote this <addr>"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "ple"), + EqualsToken(Token::Type::RFC822_NAME, "se"), + EqualsToken(Token::Type::RFC822_NAME, "quote"), + EqualsToken(Token::Type::RFC822_NAME, "this"), + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + + EqualsToken(Token::Type::RFC822_TOKEN, "ple@se quote this <addr>")))); +} + +TEST_F(Rfc822TokenizerTest, CommentEscapeTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // '(', ')', '\\' chars should be escaped. All other escaped chars should be + // unescaped. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((co\)mm\\en\(t))"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "co"), + EqualsToken(Token::Type::RFC822_COMMENT, "mm"), + EqualsToken(Token::Type::RFC822_COMMENT, "en"), + EqualsToken(Token::Type::RFC822_COMMENT, "t"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((co\)mm\\en\(t))")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((c\om\ment) name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, R"(c\om\ment)"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((c\om\ment) name)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"((co(m\))ment) name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_COMMENT, "co"), + EqualsToken(Token::Type::RFC822_COMMENT, "m"), + EqualsToken(Token::Type::RFC822_COMMENT, "ment"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, R"((co(m\))ment) name)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, QuoteEscapeTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + // All names that include non-alphanumeric chars must be quoted and have '\\' + // and '"' chars escaped. + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(n\\a\me <addr>)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "n"), + EqualsToken(Token::Type::RFC822_NAME, "a"), + EqualsToken(Token::Type::RFC822_NAME, "me"), + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + EqualsToken(Token::Type::RFC822_TOKEN, R"(n\\a\me <addr>)")))); + + // Names that are within quotes should have all characters blindly unescaped. + // When a name is made into an address, it isn't re-escaped. 
+ EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"("n\\a\m\"e")"), + // <n\am"e> + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "n"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a\\m"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "e"), + EqualsToken(Token::Type::RFC822_TOKEN, R"("n\\a\m\"e")"), + EqualsToken(Token::Type::RFC822_ADDRESS, R"(n\\a\m\"e)")))); +} + +TEST_F(Rfc822TokenizerTest, UnterminatedComponentTest) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll("name (comment"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "name (comment"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(half of "the name)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "half"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "of"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "the"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "half of \"the name"), + EqualsToken(Token::Type::RFC822_ADDRESS, "half of \"the name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"("name\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"name\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(name (comment\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, "name (comment\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(<addr> "name\)"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS, "addr"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "addr"), + EqualsToken(Token::Type::RFC822_NAME, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "<addr> \"name\\")))); + + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(R"(name (comment\))"), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_COMMENT, "comment"), + EqualsToken(Token::Type::RFC822_TOKEN, R"(name (comment\))"), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, Tokenize) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + std::string text = + R"raw("Berg" (home) <berg\@google.com>, tom\@google.com (work))raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Berg"), + EqualsToken(Token::Type::RFC822_COMMENT, "home"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "berg\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "berg\\@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "berg"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"("Berg" (home) <berg\@google.com>)"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "tom"), + 
EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_COMMENT, "work"), + EqualsToken(Token::Type::RFC822_TOKEN, "tom\\@google.com (work)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "tom\\@google.com")))); + + text = R"raw(Foo Bar (something) <foo\@google.com>, )raw" + R"raw(blah\@google.com (something))raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "Foo"), + EqualsToken(Token::Type::RFC822_NAME, "Bar"), + EqualsToken(Token::Type::RFC822_COMMENT, "something"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "foo\\"), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo\\@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + "Foo Bar (something) <foo\\@google.com>"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "blah"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_COMMENT, "something"), + EqualsToken(Token::Type::RFC822_TOKEN, + "blah\\@google.com (something)"), + EqualsToken(Token::Type::RFC822_ADDRESS, "blah\\@google.com")))); +} + +TEST_F(Rfc822TokenizerTest, EdgeCases) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + + // Text to trigger the scenario where you have a non-alphabetic followed + // by a \ followed by non alphabetic to end an in-address token. + std::string text = R"raw(<be.\&rg@google.com>)raw"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "be.\\&rg"), + EqualsToken(Token::Type::RFC822_ADDRESS, "be.\\&rg@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "be"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "rg"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"raw(<be.\&rg@google.com>)raw")))); + + // A \ followed by an alphabetic shouldn't end the token. + text = "<a\\lex@google.com>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "a\\lex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "a\\lex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "a\\lex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<a\\lex@google.com>")))); + + // \\ or \" in a quoted section. 
+ text = R"("al\\ex@goo\"<idk>gle.com")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "al"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "ex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "goo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "idk"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gle"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, + R"("al\\ex@goo\"<idk>gle.com")"), + EqualsToken(Token::Type::RFC822_ADDRESS, + R"(al\\ex@goo\"<idk>gle.com)"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "al\\\\ex")))); + + text = "<alex@google.com"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<alex@google.com")))); +} + +TEST_F(Rfc822TokenizerTest, NumberInAddress) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "<3alex@google.com>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "3alex"), + EqualsToken(Token::Type::RFC822_ADDRESS, "3alex@google.com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "3alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, "<3alex@google.com>")))); +} + +TEST_F(Rfc822TokenizerTest, DoubleQuoteDoubleSlash) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = R"("alex\"")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "alex"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "alex")))); + + text = R"("alex\\\a")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "alex"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "a"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, R"(alex\\\a)")))); +} + +TEST_F(Rfc822TokenizerTest, TwoEmails) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "tjbarron@google.com alexsav@google.com"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "tjbarron"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "com"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "alexsav"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "google"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "com"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, text), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, + "tjbarron@google.com alexsav")))); +} + +TEST_F(Rfc822TokenizerTest, BackSlashes) { + Rfc822Tokenizer rfc822_tokenizer = 
Rfc822Tokenizer(); + std::string text = R"("\name")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "name"), + EqualsToken(Token::Type::RFC822_TOKEN, "\"\\name\""), + EqualsToken(Token::Type::RFC822_ADDRESS, "name")))); + + text = R"("name@foo\@gmail")"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_LOCAL, "name"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "gmail"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "name@foo\\@gmail"), + EqualsToken(Token::Type::RFC822_LOCAL_ADDRESS, "name")))); +} + +TEST_F(Rfc822TokenizerTest, BigWhitespace) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "\"quoted\" <address>"; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_NAME, "quoted"), + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "address"), + EqualsToken(Token::Type::RFC822_TOKEN, text)))); +} + +TEST_F(Rfc822TokenizerTest, AtSignFirst) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "\"@foo\""; + EXPECT_THAT( + rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, "foo"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "foo")))); +} + +TEST_F(Rfc822TokenizerTest, SlashThenUnicode) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = R"("quoted\你cjk")"; + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, + "quoted\\你cjk"), + EqualsToken(Token::Type::RFC822_TOKEN, text), + EqualsToken(Token::Type::RFC822_ADDRESS, "quoted\\你cjk")))); +} + +TEST_F(Rfc822TokenizerTest, AddressEmptyAddress) { + Rfc822Tokenizer rfc822_tokenizer = Rfc822Tokenizer(); + std::string text = "<address> <> Name"; + EXPECT_THAT(rfc822_tokenizer.TokenizeAll(text), + IsOkAndHolds(ElementsAre( + EqualsToken(Token::Type::RFC822_ADDRESS, "address"), + EqualsToken(Token::Type::RFC822_ADDRESS_COMPONENT_HOST, + "address"), + EqualsToken(Token::Type::RFC822_NAME, "Name"), + EqualsToken(Token::Type::RFC822_TOKEN, text)))); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/tokenization/token.h b/icing/tokenization/token.h index 0c268be..24f567b 100644 --- a/icing/tokenization/token.h +++ b/icing/tokenization/token.h @@ -29,6 +29,15 @@ struct Token { VERBATIM, // A token that should be indexed and searched without any // modifications to the raw text + // An RFC822 section with the content in RFC822_TOKEN tokenizes as follows: + RFC822_NAME, // "User", "Johnsson" + RFC822_COMMENT, // "A", "comment", "here" + RFC822_LOCAL_ADDRESS, // "user.name" + RFC822_ADDRESS, // "user.name@domain.name.com" + RFC822_ADDRESS_COMPONENT_LOCAL, // "user", "name", + RFC822_ADDRESS_COMPONENT_HOST, // "domain", "name", "com" + RFC822_TOKEN, // "User Johnsson (A comment) <user.name@domain.name.com>" + // Types only used in raw query QUERY_OR, // Indicates OR logic between its left and right tokens QUERY_EXCLUSION, // Indicates exclusion operation on next token @@ -45,10 +54,10 @@ struct Token { : type(type_in), 
text(text_in) {} // The type of token - const Type type; + Type type; // The content of token - const std::string_view text; + std::string_view text; }; } // namespace lib diff --git a/icing/util/clock.h b/icing/util/clock.h index 2bb7818..9e57854 100644 --- a/icing/util/clock.h +++ b/icing/util/clock.h @@ -16,6 +16,7 @@ #define ICING_UTIL_CLOCK_H_ #include <cstdint> +#include <functional> #include <memory> namespace icing { @@ -69,6 +70,32 @@ class Clock { virtual std::unique_ptr<Timer> GetNewTimer() const; }; +// A convenient RAII timer class that receives a callback. Upon destruction, the +// callback will be called with the elapsed milliseconds or nanoseconds passed +// as a parameter, depending on which Unit was passed in the constructor. +class ScopedTimer { + public: + enum class Unit { kMillisecond, kNanosecond }; + + ScopedTimer(std::unique_ptr<Timer> timer, + std::function<void(int64_t)> callback, + Unit unit = Unit::kMillisecond) + : timer_(std::move(timer)), callback_(std::move(callback)), unit_(unit) {} + + ~ScopedTimer() { + if (unit_ == Unit::kMillisecond) { + callback_(timer_->GetElapsedMilliseconds()); + } else { + callback_(timer_->GetElapsedNanoseconds()); + } + } + + private: + std::unique_ptr<Timer> timer_; + std::function<void(int64_t)> callback_; + Unit unit_; +}; + } // namespace lib } // namespace icing diff --git a/icing/util/crc32.h b/icing/util/crc32.h index 5befe44..207a80a 100644 --- a/icing/util/crc32.h +++ b/icing/util/crc32.h @@ -35,6 +35,8 @@ class Crc32 { explicit Crc32(uint32_t init_crc) : crc_(init_crc) {} + explicit Crc32(std::string_view str) : crc_(0) { Append(str); } + inline bool operator==(const Crc32& other) const { return crc_ == other.Get(); } diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java index c690990..b55cfd1 100644 --- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java +++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java @@ -17,6 +17,7 @@ package com.google.android.icing; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import com.google.android.icing.IcingSearchEngine; import com.google.android.icing.proto.DebugInfoResultProto; import com.google.android.icing.proto.DebugInfoVerbosity; import com.google.android.icing.proto.DeleteByNamespaceResultProto; @@ -60,7 +61,6 @@ import com.google.android.icing.proto.SuggestionSpecProto.SuggestionScoringSpecP import com.google.android.icing.proto.TermMatchType; import com.google.android.icing.proto.TermMatchType.Code; import com.google.android.icing.proto.UsageReport; -import com.google.android.icing.IcingSearchEngine; import java.io.File; import java.util.HashMap; import java.util.Map; diff --git a/proto/icing/proto/optimize.proto b/proto/icing/proto/optimize.proto index 42290f3..0accb9a 100644 --- a/proto/icing/proto/optimize.proto +++ b/proto/icing/proto/optimize.proto @@ -63,7 +63,7 @@ message GetOptimizeInfoResultProto { optional int64 time_since_last_optimize_ms = 4; } -// Next tag: 10 +// Next tag: 11 message OptimizeStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -91,4 +91,15 @@ message OptimizeStatsProto { // The amount of time since the last optimize ran. 
optional int64 time_since_last_optimize_ms = 9; + + enum IndexRestorationMode { + // The index has been translated in place to match the optimized document + // store. + INDEX_TRANSLATION = 0; + // The index has been rebuilt from scratch during optimization. This can + // happen when OptimizeDocumentStore returns a DATA_LOSS error, when + // Index::Optimize fails, or when a full rebuild is expected to be faster. + FULL_INDEX_REBUILD = 1; + } + optional IndexRestorationMode index_restoration_mode = 10; } diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index f005c76..7a361d3 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -65,7 +65,7 @@ message SearchSpecProto { // Client-supplied specifications on what to include/how to format the search // results. -// Next tag: 6 +// Next tag: 7 message ResultSpecProto { // The results will be returned in pages, and num_per_page specifies the // number of documents in one page. @@ -133,6 +133,15 @@ message ResultSpecProto { // ["ns0doc0", "ns0doc1", "ns1doc0", "ns3doc0", "ns3doc1", "ns2doc1", // "ns3doc2"]. repeated ResultGrouping result_groupings = 5; + + // The threshold, in bytes, used to cut a page off based on the total size + // of its documents. + // Note that this is a cutoff rather than a hard limit: the total bytes of + // a page may be smaller than, equal to, or larger than the threshold. The + // only guarantee is that a page exceeds the threshold by less than the + // size of its final search result. + optional int32 num_total_bytes_per_page_threshold = 6 + [default = 2147483647]; // INT_MAX } // The representation of a single match within a DocumentProto property. diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 305f410..cd00254 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=455217954) +set(synced_AOSP_CL_number=466546985)
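
As a rough illustration of how the RFC822 token types added to icing/tokenization/token.h are consumed, the sketch below runs the new Rfc822Tokenizer over a header value and counts addresses. This is only a sketch: the include paths and the StatusOr accessor are assumed from the test code in this change, and CountAddresses is a hypothetical helper, not part of the commit.

#include <string_view>
#include <vector>

#include "icing/tokenization/rfc822-tokenizer.h"
#include "icing/tokenization/token.h"

namespace icing {
namespace lib {

// Hypothetical helper: counts how many addresses the tokenizer found in an
// RFC822 header value such as "User (cmt) <user@domain.com>, b@c.d".
int CountAddresses(std::string_view rfc822_field) {
  Rfc822Tokenizer tokenizer;
  auto tokens_or = tokenizer.TokenizeAll(rfc822_field);
  if (!tokens_or.ok()) {
    return 0;
  }
  int num_addresses = 0;
  for (const Token& token : tokens_or.ValueOrDie()) {
    // RFC822_ADDRESS is emitted once per parsed address; the *_COMPONENT_*
    // types carry the individual indexable terms.
    if (token.type == Token::Type::RFC822_ADDRESS) {
      ++num_addresses;
    }
  }
  return num_addresses;
}

}  // namespace lib
}  // namespace icing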
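
The ScopedTimer added to icing/util/clock.h reports elapsed time through a callback when it goes out of scope, which is how latency can be recorded even on early-error or empty-result paths. A minimal sketch of the pattern, using only the Clock and ScopedTimer APIs shown in the diff (the surrounding TimedWork function is hypothetical):

#include <cstdint>

#include "icing/util/clock.h"

namespace icing {
namespace lib {

int64_t TimedWork(const Clock& clock) {
  int64_t latency_ms = 0;
  {
    // The callback fires when scoped_timer is destroyed at the end of this
    // scope, so no explicit stop call is needed; the default unit is
    // milliseconds.
    ScopedTimer scoped_timer(clock.GetNewTimer(),
                             [&latency_ms](int64_t elapsed) {
                               latency_ms = elapsed;
                             });
    // ... do the work being measured ...
  }
  return latency_ms;
}

}  // namespace lib
}  // namespace icing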
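
The new index_restoration_mode field on OptimizeStatsProto distinguishes the cheap in-place index translation from a full rebuild. Assuming the standard generated C++ proto API and header path, a caller could check it as follows (WasFullIndexRebuild is a hypothetical helper):

#include "icing/proto/optimize.pb.h"

namespace icing {
namespace lib {

// Returns true if the last optimize run had to rebuild the index from
// scratch instead of translating it in place.
bool WasFullIndexRebuild(const OptimizeStatsProto& stats) {
  return stats.index_restoration_mode() ==
         OptimizeStatsProto::FULL_INDEX_REBUILD;
}

}  // namespace lib
}  // namespace icing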
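
The num_total_bytes_per_page_threshold field added to ResultSpecProto caps page size by bytes rather than by document count. For example, with a 16 KiB threshold a page could end up at 19 KiB if its last result is 4 KiB, since any overshoot is bounded by the size of that final result. A sketch of configuring it, assuming the generated C++ proto API and header path:

#include "icing/proto/search.pb.h"

namespace icing {
namespace lib {

ResultSpecProto MakeByteLimitedResultSpec() {
  ResultSpecProto result_spec;
  result_spec.set_num_per_page(25);
  // Cut pages off at roughly 16 KiB of document content; a page may exceed
  // this, but only by less than the size of its final result.
  result_spec.set_num_total_bytes_per_page_threshold(16 * 1024);
  return result_spec;
}

}  // namespace lib
}  // namespace icing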