author    | Grace Zhao <gracezrx@google.com> | 2022-09-08 20:26:31 +0000
committer | Grace Zhao <gracezrx@google.com> | 2022-09-08 22:53:11 +0000
commit    | b02eecda6a12241798cdbaaa7069d19f2fc5f41f (patch)
tree      | 15687379068030d4d5443c916d91e9ed364f9b39
parent    | 87267cbc5531600072a283ba0c9500c3fcac87af (diff)
download  | icing-b02eecda6a12241798cdbaaa7069d19f2fc5f41f.tar.gz
Sync from upstream.
Descriptions:
======================================================================
[FileBackedVector Consolidation][4/x] Fix potential PWrite bug in GrowIfNecessary
======================================================================
[FileBackedVector Consolidation][5/x] Create benchmark for FileBackedVector
======================================================================
[FileBackedVector Consolidation][6/x] Avoid calling GetFileSize in GrowIfNecessary
======================================================================
[PersistentHashMap][3.3/x] Implement Delete
======================================================================
Fix the PopulateMatchedTermsStats bug
======================================================================
Add JNI latency for query latency stats breakdown.
======================================================================
[ResultStateManager] Thread safety test1
======================================================================
[ResultStateManager][2/x] Thread safety test2
======================================================================
Add native lock contention latency for measuring query latency
======================================================================
Fix implementation of HasMember operator in ANTLR-based list-filter prototype.
======================================================================
Fix improper uses of std::string_view
======================================================================
Extend the scale of Icing
======================================================================
Decouple the term frequency array from DocHitInfo
======================================================================
Disable hit_term_frequency for non-relevance queries
======================================================================
[ResultStateManager][3/x] Thread safety test3
======================================================================
[PersistentHashMap][4/x] Implement iterator
======================================================================
Fix the lite index compaction bug
======================================================================
Change-Id: I0edad67affed97af107e2d7cd73770e0268c0903
71 files changed, 3357 insertions, 959 deletions
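Two of the changes above ("Add JNI latency for query latency stats breakdown" and "Add native lock contention latency for measuring query latency") follow one simple pattern, visible in the Search() and GetNextPage() hunks further down: start a timer before acquiring the engine-wide lock, then record the elapsed time immediately after acquisition. A minimal sketch of that pattern, using hypothetical names and std::mutex in place of Icing's own Timer/lock types:

#include <chrono>
#include <cstdint>
#include <mutex>

// Sketch only: measures how long the caller blocked waiting for a contended
// lock, in the spirit of the lock_acquisition_latency_ms stat this change adds.
int64_t LockAndMeasureMs(std::mutex& mu, std::unique_lock<std::mutex>& out) {
  auto start = std::chrono::steady_clock::now();
  out = std::unique_lock<std::mutex>(mu);  // blocks while the lock is held
  auto elapsed = std::chrono::steady_clock::now() - start;
  return std::chrono::duration_cast<std::chrono::milliseconds>(elapsed)
      .count();
}

In the actual diff, IcingSearchEngine uses its own clock_->GetNewTimer()/ScopedTimer machinery rather than std::chrono, and the key move is simply hoisting the timer creation above the absl_ports lock acquisition.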
diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index bcfbbdd..7916666 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -261,9 +261,20 @@ class FileBackedVector { // // Returns: // OUT_OF_RANGE_ERROR if idx < 0 or idx > kMaxIndex or file cannot be grown - // idx size + // to fit idx + 1 elements libtextclassifier3::Status Set(int32_t idx, const T& value); + // Set [idx, idx + len) to a single value. + // + // May grow the underlying file and mmapped region as needed to fit the new + // value. If it does grow, then any pointers/references to previous values + // returned from Get/GetMutable/Allocate may be invalidated. + // + // Returns: + // OUT_OF_RANGE_ERROR if idx < 0 or idx + len > kMaxNumElements or file + // cannot be grown to fit idx + len elements + libtextclassifier3::Status Set(int32_t idx, int32_t len, const T& value); + // Appends the value to the end of the vector. // // May grow the underlying file and mmapped region as needed to fit the new @@ -369,8 +380,8 @@ class FileBackedVector { // It handles SetDirty properly for the file-backed-vector when modifying // elements. // - // REQUIRES: arr is valid && arr_len >= 0 && idx + arr_len <= size(), - // otherwise the behavior is undefined. + // REQUIRES: arr is valid && arr_len >= 0 && idx >= 0 && idx + arr_len <= + // size(), otherwise the behavior is undefined. void SetArray(int32_t idx, const T* arr, int32_t arr_len) { for (int32_t i = 0; i < arr_len; ++i) { SetDirty(idx + i); @@ -433,10 +444,11 @@ class FileBackedVector { static constexpr int32_t kMaxIndex = kMaxNumElements - 1; // Can only be created through the factory ::Create function - FileBackedVector(const Filesystem& filesystem, const std::string& file_path, - std::unique_ptr<Header> header, - std::unique_ptr<MemoryMappedFile> mmapped_file, - int32_t max_file_size); + explicit FileBackedVector(const Filesystem& filesystem, + const std::string& file_path, + std::unique_ptr<Header> header, + std::unique_ptr<MemoryMappedFile> mmapped_file, + int32_t max_file_size); // Initialize a new FileBackedVector, and create the file. 
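For context, a minimal sketch of how the new ranged Set(idx, len, value) overload declared above is meant to be used, mirroring the unit tests added later in this change; `filesystem` and `file_path` are placeholder setup, as in those tests:

// Sketch only: assumes a Filesystem instance and a writable file path.
ICING_ASSERT_OK_AND_ASSIGN(
    std::unique_ptr<FileBackedVector<char>> vec,
    FileBackedVector<char>::Create(
        filesystem, file_path,
        MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC));
ICING_ASSERT_OK(vec->Set(/*idx=*/0, 'a'));             // single-element overload
ICING_ASSERT_OK(vec->Set(/*idx=*/1, /*len=*/3, 'z'));  // ranged overload
// The vector now holds "azzz". Because idx + len exceeded num_elements, the
// file was grown, so any pointers previously returned from Get/GetMutable/
// Allocate may have been invalidated.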
static libtextclassifier3::StatusOr<std::unique_ptr<FileBackedVector<T>>> @@ -765,30 +777,44 @@ FileBackedVector<T>::GetMutable(int32_t idx, int32_t len) { template <typename T> libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, const T& value) { + return Set(idx, 1, value); +} + +template <typename T> +libtextclassifier3::Status FileBackedVector<T>::Set(int32_t idx, int32_t len, + const T& value) { if (idx < 0) { return absl_ports::OutOfRangeError( IcingStringUtil::StringPrintf("Index, %d, was less than 0", idx)); } - if (idx > kMaxIndex) { - return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( - "Index, %d, was greater than max index allowed, %d", idx, kMaxIndex)); + if (len <= 0) { + return absl_ports::OutOfRangeError("Invalid set length"); } - ICING_RETURN_IF_ERROR(GrowIfNecessary(idx + 1)); - - if (idx + 1 > header_->num_elements) { - header_->num_elements = idx + 1; + if (idx > kMaxNumElements - len) { + return absl_ports::OutOfRangeError( + IcingStringUtil::StringPrintf("Length %d (with index %d), was too long " + "for max num elements allowed, %d", + len, idx, kMaxNumElements)); } - if (mutable_array()[idx] == value) { - // No need to update - return libtextclassifier3::Status::OK; + ICING_RETURN_IF_ERROR(GrowIfNecessary(idx + len)); + + if (idx + len > header_->num_elements) { + header_->num_elements = idx + len; } - SetDirty(idx); + for (int32_t i = 0; i < len; ++i) { + if (array()[idx + i] == value) { + // No need to update + continue; + } + + SetDirty(idx + i); + mutable_array()[idx + i] = value; + } - mutable_array()[idx] = value; return libtextclassifier3::Status::OK; } @@ -835,19 +861,16 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( num_elements, max_file_size_ - Header::kHeaderSize)); } - int64_t current_file_size = filesystem_->GetFileSize(file_path_.c_str()); - if (current_file_size == Filesystem::kBadFileSize) { - return absl_ports::InternalError("Unable to retrieve file size."); - } - - int32_t least_file_size_needed = - Header::kHeaderSize + num_elements * kElementTypeSize; // Won't overflow - if (least_file_size_needed <= current_file_size) { - // Our underlying file can hold the target num_elements cause we've grown + int32_t least_element_file_size_needed = + num_elements * kElementTypeSize; // Won't overflow + if (least_element_file_size_needed <= mmapped_file_->region_size()) { + // Our mmapped region can hold the target num_elements cause we've grown // before return libtextclassifier3::Status::OK; } + int32_t least_file_size_needed = + Header::kHeaderSize + least_element_file_size_needed; // Otherwise, we need to grow. Grow to kGrowElements boundary. // Note that we need to use int64_t here, since int32_t might overflow after // round up. @@ -857,6 +880,12 @@ libtextclassifier3::Status FileBackedVector<T>::GrowIfNecessary( least_file_size_needed = std::min(round_up_file_size_needed, int64_t{max_file_size_}); + // Get the actual file size here. + int64_t current_file_size = filesystem_->GetFileSize(file_path_.c_str()); + if (current_file_size == Filesystem::kBadFileSize) { + return absl_ports::InternalError("Unable to retrieve file size."); + } + // We use PWrite here rather than Grow because Grow doesn't actually allocate // an underlying disk block. 
This can lead to problems with mmap because mmap // has no effective way to signal that it was impossible to allocate the disk diff --git a/icing/file/file-backed-vector_benchmark.cc b/icing/file/file-backed-vector_benchmark.cc new file mode 100644 index 0000000..b2e660b --- /dev/null +++ b/icing/file/file-backed-vector_benchmark.cc @@ -0,0 +1,158 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <limits> +#include <memory> +#include <random> +#include <string> + +#include "testing/base/public/benchmark.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/file/destructible-directory.h" +#include "icing/file/file-backed-vector.h" +#include "icing/file/filesystem.h" +#include "icing/file/memory-mapped-file.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +class FileBackedVectorBenchmark { + public: + explicit FileBackedVectorBenchmark() + : base_dir_(GetTestTempDir() + "/file_backed_vector_benchmark"), + file_path_(base_dir_ + "/test_vector"), + ddir_(&filesystem_, base_dir_), + random_engine_(/*seed=*/12345) {} + + const Filesystem& filesystem() const { return filesystem_; } + const std::string& file_path() const { return file_path_; } + std::default_random_engine& random_engine() { return random_engine_; } + + private: + Filesystem filesystem_; + std::string base_dir_; + std::string file_path_; + DestructibleDirectory ddir_; + + std::default_random_engine random_engine_; +}; + +// Benchmark Set() (without extending vector, i.e. the index should be in range +// [0, num_elts - 1]. +void BM_Set(benchmark::State& state) { + int num_elts = state.range(0); + + FileBackedVectorBenchmark fbv_benchmark; + + fbv_benchmark.filesystem().DeleteFile(fbv_benchmark.file_path().c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> fbv, + FileBackedVector<int>::Create( + fbv_benchmark.filesystem(), fbv_benchmark.file_path(), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + // Extend to num_elts + fbv->Set(num_elts - 1, 0); + + std::uniform_int_distribution<> distrib(0, num_elts - 1); + for (auto _ : state) { + int idx = distrib(fbv_benchmark.random_engine()); + ICING_ASSERT_OK(fbv->Set(idx, idx)); + } +} +BENCHMARK(BM_Set) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); + +// Benchmark single Append(). Equivalent to Set(fbv->num_elements(), val), which +// extends the vector every round. 
+void BM_Append(benchmark::State& state) { + FileBackedVectorBenchmark fbv_benchmark; + + fbv_benchmark.filesystem().DeleteFile(fbv_benchmark.file_path().c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> fbv, + FileBackedVector<int>::Create( + fbv_benchmark.filesystem(), fbv_benchmark.file_path(), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + std::uniform_int_distribution<> distrib(0, std::numeric_limits<int>::max()); + for (auto _ : state) { + ICING_ASSERT_OK(fbv->Append(distrib(fbv_benchmark.random_engine()))); + } +} +BENCHMARK(BM_Append); + +// Benchmark appending many elements. +void BM_AppendMany(benchmark::State& state) { + int num = state.range(0); + + FileBackedVectorBenchmark fbv_benchmark; + + for (auto _ : state) { + state.PauseTiming(); + fbv_benchmark.filesystem().DeleteFile(fbv_benchmark.file_path().c_str()); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> fbv, + FileBackedVector<int>::Create( + fbv_benchmark.filesystem(), fbv_benchmark.file_path(), + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + state.ResumeTiming(); + + for (int i = 0; i < num; ++i) { + ICING_ASSERT_OK(fbv->Append(i)); + } + + // Since destructor calls PersistToDisk, to avoid calling it twice, we reset + // the unique pointer to invoke destructor instead of calling PersistToDisk + // explicitly, so in this case PersistToDisk will be called only once. + fbv.reset(); + } +} +BENCHMARK(BM_AppendMany) + ->Arg(1 << 5) + ->Arg(1 << 6) + ->Arg(1 << 7) + ->Arg(1 << 8) + ->Arg(1 << 9) + ->Arg(1 << 10) + ->Arg(1 << 11) + ->Arg(1 << 12) + ->Arg(1 << 13) + ->Arg(1 << 14) + ->Arg(1 << 15) + ->Arg(1 << 16) + ->Arg(1 << 17) + ->Arg(1 << 18) + ->Arg(1 << 19) + ->Arg(1 << 20); + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc index 60ed887..74e4132 100644 --- a/icing/file/file-backed-vector_test.cc +++ b/icing/file/file-backed-vector_test.cc @@ -36,6 +36,7 @@ #include "icing/util/crc32.h" #include "icing/util/logging.h" +using ::testing::DoDefault; using ::testing::ElementsAre; using ::testing::Eq; using ::testing::IsTrue; @@ -284,6 +285,65 @@ TEST_F(FileBackedVectorTest, Get) { StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); } +TEST_F(FileBackedVectorTest, SetWithoutGrowing) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); + + std::string original = "abcde"; + Insert(vector.get(), /*idx=*/0, original); + ASSERT_THAT(vector->num_elements(), Eq(original.length())); + ASSERT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/5), Eq(original)); + + ICING_EXPECT_OK(vector->Set(/*idx=*/1, /*len=*/3, 'z')); + EXPECT_THAT(vector->num_elements(), Eq(5)); + EXPECT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/5), Eq("azzze")); +} + +TEST_F(FileBackedVectorTest, SetWithGrowing) { + // Create a vector and add some data. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + EXPECT_THAT(vector->ComputeChecksum(), IsOkAndHolds(Crc32(0))); + + std::string original = "abcde"; + Insert(vector.get(), /*idx=*/0, original); + ASSERT_THAT(vector->num_elements(), Eq(original.length())); + ASSERT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/5), Eq(original)); + + ICING_EXPECT_OK(vector->Set(/*idx=*/3, /*len=*/4, 'z')); + EXPECT_THAT(vector->num_elements(), Eq(7)); + EXPECT_THAT(Get(vector.get(), /*idx=*/0, /*expected_len=*/7), Eq("abczzzz")); +} + +TEST_F(FileBackedVectorTest, SetInvalidArguments) { + // Create a vector and add some data. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<char>> vector, + FileBackedVector<char>::Create( + filesystem_, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + EXPECT_THAT(vector->Set(/*idx=*/0, /*len=*/-1, 'z'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Set(/*idx=*/0, /*len=*/0, 'z'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Set(/*idx=*/-1, /*len=*/2, 'z'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + EXPECT_THAT(vector->Set(/*idx=*/100, + /*len=*/std::numeric_limits<int32_t>::max(), 'z'), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); +} + TEST_F(FileBackedVectorTest, MutableView) { // Create a vector and add some data. ICING_ASSERT_OK_AND_ASSIGN( @@ -1225,6 +1285,40 @@ TEST_F(FileBackedVectorTest, BadFileSizeDuringGrowReturnsError) { StatusIs(libtextclassifier3::StatusCode::INTERNAL)); } +TEST_F(FileBackedVectorTest, PWriteFailsInTheSecondRound) { + auto mock_filesystem = std::make_unique<MockFilesystem>(); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<FileBackedVector<int>> vector, + FileBackedVector<int>::Create( + *mock_filesystem, file_path_, + MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC)); + + // At first, the vector is empty and has no mapping established. The first Set + // call will cause a Grow. + // During Grow, we call PWrite for several rounds to grow the file. Mock + // PWrite to succeed in the first round, fail in the second round, and succeed + // in the rest rounds. + + // This unit test checks if we check file size and Remap properly. If the + // first PWrite succeeds but the second PWrite fails, then the file size has + // been grown, but there will be no Remap for the MemoryMappedFile. Then, + // the next several Append() won't require file growth since the file size has + // been grown, but it causes memory error because we haven't remapped. 
+ EXPECT_CALL(*mock_filesystem, + PWrite(A<int>(), A<off_t>(), A<const void*>(), A<size_t>())) + .WillOnce(DoDefault()) + .WillOnce(Return(false)) + .WillRepeatedly(DoDefault()); + + EXPECT_THAT(vector->Append(7), + StatusIs(libtextclassifier3::StatusCode::INTERNAL)); + EXPECT_THAT(vector->Get(/*idx=*/0), + StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); + + EXPECT_THAT(vector->Append(7), IsOk()); + EXPECT_THAT(vector->Get(/*idx=*/0), IsOkAndHolds(Pointee(Eq(7)))); +} + } // namespace } // namespace lib diff --git a/icing/file/persistent-hash-map.cc b/icing/file/persistent-hash-map.cc index d20285a..43530dd 100644 --- a/icing/file/persistent-hash-map.cc +++ b/icing/file/persistent-hash-map.cc @@ -14,6 +14,7 @@ #include "icing/file/persistent-hash-map.h" +#include <cstdint> #include <cstring> #include <memory> #include <string> @@ -213,16 +214,16 @@ libtextclassifier3::Status PersistentHashMap::Put(std::string_view key, int32_t bucket_idx, HashKeyToBucketIndex(key, bucket_storage_->num_elements())); - ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + ICING_ASSIGN_OR_RETURN(EntryIndexPair idx_pair, FindEntryIndexByKey(bucket_idx, key)); - if (target_entry_idx == Entry::kInvalidIndex) { + if (idx_pair.target_entry_index == Entry::kInvalidIndex) { // If not found, then insert new key value pair. return Insert(bucket_idx, key, value); } // Otherwise, overwrite the value. ICING_ASSIGN_OR_RETURN(const Entry* entry, - entry_storage_->Get(target_entry_idx)); + entry_storage_->Get(idx_pair.target_entry_index)); int32_t kv_len = key.length() + 1 + info()->value_type_size; int32_t value_offset = key.length() + 1; @@ -244,15 +245,15 @@ libtextclassifier3::Status PersistentHashMap::GetOrPut(std::string_view key, int32_t bucket_idx, HashKeyToBucketIndex(key, bucket_storage_->num_elements())); - ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + ICING_ASSIGN_OR_RETURN(EntryIndexPair idx_pair, FindEntryIndexByKey(bucket_idx, key)); - if (target_entry_idx == Entry::kInvalidIndex) { + if (idx_pair.target_entry_index == Entry::kInvalidIndex) { // If not found, then insert new key value pair. return Insert(bucket_idx, key, next_value); } // Otherwise, copy the hash map value into next_value. 
- return CopyEntryValue(target_entry_idx, next_value); + return CopyEntryValue(idx_pair.target_entry_index, next_value); } libtextclassifier3::Status PersistentHashMap::Get(std::string_view key, @@ -262,14 +263,76 @@ libtextclassifier3::Status PersistentHashMap::Get(std::string_view key, int32_t bucket_idx, HashKeyToBucketIndex(key, bucket_storage_->num_elements())); - ICING_ASSIGN_OR_RETURN(int32_t target_entry_idx, + ICING_ASSIGN_OR_RETURN(EntryIndexPair idx_pair, FindEntryIndexByKey(bucket_idx, key)); - if (target_entry_idx == Entry::kInvalidIndex) { + if (idx_pair.target_entry_index == Entry::kInvalidIndex) { return absl_ports::NotFoundError( absl_ports::StrCat("Key not found in PersistentHashMap ", base_dir_)); } - return CopyEntryValue(target_entry_idx, value); + return CopyEntryValue(idx_pair.target_entry_index, value); +} + +libtextclassifier3::Status PersistentHashMap::Delete(std::string_view key) { + ICING_RETURN_IF_ERROR(ValidateKey(key)); + ICING_ASSIGN_OR_RETURN( + int32_t bucket_idx, + HashKeyToBucketIndex(key, bucket_storage_->num_elements())); + + ICING_ASSIGN_OR_RETURN(EntryIndexPair idx_pair, + FindEntryIndexByKey(bucket_idx, key)); + if (idx_pair.target_entry_index == Entry::kInvalidIndex) { + return absl_ports::NotFoundError( + absl_ports::StrCat("Key not found in PersistentHashMap ", base_dir_)); + } + + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<Entry>::MutableView mutable_target_entry, + entry_storage_->GetMutable(idx_pair.target_entry_index)); + if (idx_pair.prev_entry_index == Entry::kInvalidIndex) { + // If prev_entry_idx is Entry::kInvalidIndex, then target_entry must be the + // head element of the entry linked list, and we have to update + // bucket->head_entry_index_. + // + // Before: target_entry (head) -> next_entry -> ... + // After: next_entry (head) -> ... + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<Bucket>::MutableView mutable_bucket, + bucket_storage_->GetMutable(bucket_idx)); + if (mutable_bucket.Get().head_entry_index() != + idx_pair.target_entry_index) { + return absl_ports::InternalError( + "Bucket head entry index is inconsistent with the actual entry linked" + "list head. This shouldn't happen"); + } + mutable_bucket.Get().set_head_entry_index( + mutable_target_entry.Get().next_entry_index()); + } else { + // Otherwise, connect prev_entry and next_entry, to remove target_entry from + // the entry linked list. + // + // Before: ... -> prev_entry -> target_entry -> next_entry -> ... + // After: ... -> prev_entry -> next_entry -> ... + ICING_ASSIGN_OR_RETURN( + typename FileBackedVector<Entry>::MutableView mutable_prev_entry, + entry_storage_->GetMutable(idx_pair.prev_entry_index)); + mutable_prev_entry.Get().set_next_entry_index( + mutable_target_entry.Get().next_entry_index()); + } + + // Zero out the key value bytes. It is necessary for iterator to iterate + // through kv_storage and handle deleted keys properly. + int32_t kv_len = key.length() + 1 + info()->value_type_size; + ICING_RETURN_IF_ERROR(kv_storage_->Set( + mutable_target_entry.Get().key_value_index(), kv_len, '\0')); + + // Invalidate target_entry + mutable_target_entry.Get().set_key_value_index(kInvalidKVIndex); + mutable_target_entry.Get().set_next_entry_index(Entry::kInvalidIndex); + + ++(info()->num_deleted_entries); + + return libtextclassifier3::Status::OK; } libtextclassifier3::Status PersistentHashMap::PersistToDisk() { @@ -440,7 +503,9 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, // Allow max_load_factor_percent_ change. 
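The unlink step in Delete() above is standard removal from a singly linked list whose nodes live in an array and are addressed by index rather than by pointer, which is why FindEntryIndexByKey now also returns the predecessor index. A self-contained sketch of the same idea (hypothetical Node type, not Icing code):

#include <cstdint>
#include <vector>

constexpr int32_t kInvalidIndex = -1;

struct Node {
  int32_t next = kInvalidIndex;
};

// Unlinks `target` from the list starting at `*head`, given the index of its
// predecessor (kInvalidIndex when `target` is the head). This mirrors the
// EntryIndexPair bookkeeping above: the head case rewrites the bucket's
// head_entry_index, the general case bypasses the target entry.
void Unlink(std::vector<Node>& nodes, int32_t* head, int32_t target,
            int32_t prev) {
  if (prev == kInvalidIndex) {
    *head = nodes[target].next;             // target was the head
  } else {
    nodes[prev].next = nodes[target].next;  // bypass target
  }
  nodes[target].next = kInvalidIndex;       // invalidate the removed node
}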
if (max_load_factor_percent != info_ptr->max_load_factor_percent) { - ICING_VLOG(2) << "Changing max_load_factor_percent from " << info_ptr->max_load_factor_percent << " to " << max_load_factor_percent; + ICING_VLOG(2) << "Changing max_load_factor_percent from " + << info_ptr->max_load_factor_percent << " to " + << max_load_factor_percent; info_ptr->max_load_factor_percent = max_load_factor_percent; crcs_ptr->component_crcs.info_crc = info_ptr->ComputeChecksum().Get(); @@ -455,12 +520,15 @@ PersistentHashMap::InitializeExistingFiles(const Filesystem& filesystem, std::move(kv_storage))); } -libtextclassifier3::StatusOr<int32_t> PersistentHashMap::FindEntryIndexByKey( - int32_t bucket_idx, std::string_view key) const { +libtextclassifier3::StatusOr<PersistentHashMap::EntryIndexPair> +PersistentHashMap::FindEntryIndexByKey(int32_t bucket_idx, + std::string_view key) const { // Iterate all entries in the bucket, compare with key, and return the entry // index if exists. ICING_ASSIGN_OR_RETURN(const Bucket* bucket, bucket_storage_->Get(bucket_idx)); + + int32_t prev_entry_idx = Entry::kInvalidIndex; int32_t curr_entry_idx = bucket->head_entry_index(); while (curr_entry_idx != Entry::kInvalidIndex) { ICING_ASSIGN_OR_RETURN(const Entry* entry, @@ -473,13 +541,14 @@ libtextclassifier3::StatusOr<int32_t> PersistentHashMap::FindEntryIndexByKey( ICING_ASSIGN_OR_RETURN(const char* kv_arr, kv_storage_->Get(entry->key_value_index())); if (key.compare(kv_arr) == 0) { - return curr_entry_idx; + return EntryIndexPair(curr_entry_idx, prev_entry_idx); } + prev_entry_idx = curr_entry_idx; curr_entry_idx = entry->next_entry_index(); } - return curr_entry_idx; + return EntryIndexPair(curr_entry_idx, prev_entry_idx); } libtextclassifier3::Status PersistentHashMap::CopyEntryValue( @@ -530,5 +599,27 @@ libtextclassifier3::Status PersistentHashMap::Insert(int32_t bucket_idx, return libtextclassifier3::Status::OK; } +bool PersistentHashMap::Iterator::Advance() { + // Jump over the current key value pair before advancing to the next valid + // key value pair. In the first round (after construction), curr_key_len_ + // is 0, so don't jump over anything. + if (curr_key_len_ != 0) { + curr_kv_idx_ += curr_key_len_ + 1 + map_->info()->value_type_size; + curr_key_len_ = 0; + } + + // By skipping null chars, we will be automatically handling deleted entries + // (which are zeroed out during deletion). + for (const char* curr_kv_ptr = map_->kv_storage_->array() + curr_kv_idx_; + curr_kv_idx_ < map_->kv_storage_->num_elements(); + ++curr_kv_ptr, ++curr_kv_idx_) { + if (*curr_kv_ptr != '\0') { + curr_key_len_ = strlen(curr_kv_ptr); + return true; + } + } + return false; +} + } // namespace lib } // namespace icing diff --git a/icing/file/persistent-hash-map.h b/icing/file/persistent-hash-map.h index 24a47ea..a1ca25d 100644 --- a/icing/file/persistent-hash-map.h +++ b/icing/file/persistent-hash-map.h @@ -36,6 +36,48 @@ namespace lib { // should not contain termination character '\0'. class PersistentHashMap { public: + // For iterating through persistent hash map. The order is not guaranteed. + // + // Not thread-safe. + // + // Change in underlying persistent hash map invalidates iterator. + class Iterator { + public: + // Advance to the next entry. + // + // Returns: + // True on success, otherwise false. + bool Advance(); + + // Get the key. + // + // REQUIRES: The preceding call for Advance() is true. 
+ std::string_view GetKey() const { + return std::string_view(map_->kv_storage_->array() + curr_kv_idx_, + curr_key_len_); + } + + // Get the memory mapped address of the value. + // + // REQUIRES: The preceding call for Advance() is true. + const void* GetValue() const { + return static_cast<const void*>(map_->kv_storage_->array() + + curr_kv_idx_ + curr_key_len_ + 1); + } + + private: + explicit Iterator(const PersistentHashMap* map) + : map_(map), curr_kv_idx_(0), curr_key_len_(0) {} + + // Does not own + const PersistentHashMap* map_; + + int32_t curr_kv_idx_; + int32_t curr_key_len_; + + friend class PersistentHashMap; + }; + // Crcs and Info will be written into the metadata file. // File layout: <Crcs><Info> // Crcs @@ -257,6 +299,19 @@ class PersistentHashMap { // Any FileBackedVector errors libtextclassifier3::Status Get(std::string_view key, void* value) const; + // Delete the key value pair from the storage. If key doesn't exist, then do + // nothing and return NOT_FOUND_ERROR. + // + // Returns: + // OK on success + // NOT_FOUND_ERROR if the key doesn't exist + // INVALID_ARGUMENT_ERROR if the key is invalid (i.e. contains '\0') + // INTERNAL_ERROR on I/O error or any data inconsistency + // Any FileBackedVector errors + libtextclassifier3::Status Delete(std::string_view key); + + Iterator GetIterator() const { return Iterator(this); } + // Flushes content to underlying files. // // Returns: @@ -296,7 +351,19 @@ class PersistentHashMap { bool empty() const { return size() == 0; } + int32_t num_buckets() const { return bucket_storage_->num_elements(); } + private: + struct EntryIndexPair { + int32_t target_entry_index; + int32_t prev_entry_index; + + explicit EntryIndexPair(int32_t target_entry_index_in, + int32_t prev_entry_index_in) + : target_entry_index(target_entry_index_in), + prev_entry_index(prev_entry_index_in) {} + }; + explicit PersistentHashMap( const Filesystem& filesystem, std::string_view base_dir, std::unique_ptr<MemoryMappedFile> metadata_mmapped_file, @@ -319,15 +386,18 @@ class PersistentHashMap { std::string_view base_dir, int32_t value_type_size, int32_t max_load_factor_percent); - // Find the index of the key entry from a bucket (specified by bucket index). - // The caller should specify the desired bucket index. + // Find the index of the target entry (that contains the key) from a bucket + // (specified by bucket index). Also return the previous entry index, since + // Delete() needs it to update the linked list and head entry index. The + // caller should specify the desired bucket index. // // Returns: - // int32_t: on success, the index of the entry, or Entry::kInvalidIndex if - // not found + // std::pair<int32_t, int32_t>: target entry index and previous entry index + // on success. If not found, then target entry + // index will be Entry::kInvalidIndex // INTERNAL_ERROR if any content inconsistency // Any FileBackedVector errors - libtextclassifier3::StatusOr<int32_t> FindEntryIndexByKey( + libtextclassifier3::StatusOr<EntryIndexPair> FindEntryIndexByKey( int32_t bucket_idx, std::string_view key) const; // Copy the hash map value of the entry into value buffer. 
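Putting the new Iterator pieces together, a typical walk over the map looks roughly like this (sketch; assumes int values, as the tests below use):

// Sketch only: visit all live key-value pairs. Iteration order is
// unspecified, and deleted entries are skipped automatically because
// Delete() zeroes out their key-value bytes in kv_storage.
PersistentHashMap::Iterator iter = persistent_hash_map->GetIterator();
while (iter.Advance()) {
  std::string_view key = iter.GetKey();
  int value;
  memcpy(&value, iter.GetValue(), sizeof(value));  // value bytes are raw
  // ... use key / value ...
}

Note that GetKey()/GetValue() are only valid after a preceding Advance() returned true, and that any mutation of the underlying map invalidates the iterator.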
diff --git a/icing/file/persistent-hash-map_test.cc b/icing/file/persistent-hash-map_test.cc index fb15175..138320c 100644 --- a/icing/file/persistent-hash-map_test.cc +++ b/icing/file/persistent-hash-map_test.cc @@ -15,6 +15,9 @@ #include "icing/file/persistent-hash-map.h" #include <cstring> +#include <string_view> +#include <unordered_map> +#include <unordered_set> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" @@ -34,12 +37,16 @@ namespace { static constexpr int32_t kCorruptedValueOffset = 3; +using ::testing::Contains; using ::testing::Eq; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Key; using ::testing::Not; +using ::testing::Pair; using ::testing::Pointee; using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; using Bucket = PersistentHashMap::Bucket; using Crcs = PersistentHashMap::Crcs; @@ -69,6 +76,18 @@ class PersistentHashMapTest : public ::testing::Test { return val; } + std::unordered_map<std::string, int> GetAllKeyValuePairs( + PersistentHashMap::Iterator&& iter) { + std::unordered_map<std::string, int> kvps; + + while (iter.Advance()) { + int val; + memcpy(&val, iter.GetValue(), sizeof(val)); + kvps.emplace(iter.GetKey(), val); + } + return kvps; + } + Filesystem filesystem_; std::string base_dir_; }; @@ -148,7 +167,10 @@ TEST_F(PersistentHashMapTest, // Put some key value pairs. ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); - // TODO(b/193919210): call Delete() to change PersistentHashMap header + ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data())); + // Call Delete() to change PersistentHashMap metadata info + // (num_deleted_entries) + ICING_ASSERT_OK(persistent_hash_map->Delete("c")); ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); @@ -178,7 +200,10 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsWithPersistToDisk) { // Put some key value pairs. 
ICING_ASSERT_OK(persistent_hash_map1->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map1->Put("b", Serialize(2).data())); - // TODO(b/193919210): call Delete() to change PersistentHashMap header + ICING_ASSERT_OK(persistent_hash_map1->Put("c", Serialize(3).data())); + // Call Delete() to change PersistentHashMap metadata info + // (num_deleted_entries) + ICING_ASSERT_OK(persistent_hash_map1->Delete("c")); ASSERT_THAT(persistent_hash_map1, Pointee(SizeIs(2))); ASSERT_THAT(GetValueByKey(persistent_hash_map1.get(), "a"), IsOkAndHolds(1)); @@ -214,7 +239,10 @@ TEST_F(PersistentHashMapTest, TestInitializationSucceedsAfterDestruction) { /*max_load_factor_percent=*/1000)); ICING_ASSERT_OK(persistent_hash_map->Put("a", Serialize(1).data())); ICING_ASSERT_OK(persistent_hash_map->Put("b", Serialize(2).data())); - // TODO(b/193919210): call Delete() to change PersistentHashMap header + ICING_ASSERT_OK(persistent_hash_map->Put("c", Serialize(3).data())); + // Call Delete() to change PersistentHashMap metadata info + // (num_deleted_entries) + ICING_ASSERT_OK(persistent_hash_map->Delete("c")); ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), "a"), IsOkAndHolds(1)); @@ -637,6 +665,218 @@ TEST_F(PersistentHashMapTest, GetOrPutShouldGetIfKeyExists) { IsOkAndHolds(1)); } +TEST_F(PersistentHashMapTest, Delete) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + // Delete a non-existing key should get NOT_FOUND error + EXPECT_THAT(persistent_hash_map->Delete("default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-youtube.com", Serialize(50).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + + // Delete an existing key should succeed + ICING_EXPECT_OK(persistent_hash_map->Delete("default-google.com")); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + // The deleted key should not be found. + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // Other key should remain unchanged and available + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + IsOkAndHolds(50)); + + // Insert back the deleted key. 
Should get new value + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(200).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(2))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + IsOkAndHolds(200)); + + // Delete again + ICING_EXPECT_OK(persistent_hash_map->Delete("default-google.com")); + EXPECT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + // Other keys should remain unchanged and available + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-youtube.com"), + IsOkAndHolds(50)); +} + +TEST_F(PersistentHashMapTest, DeleteMultiple) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + std::unordered_map<std::string, int> existing_keys; + std::unordered_set<std::string> deleted_keys; + // Insert 100 key value pairs + for (int i = 0; i < 100; ++i) { + std::string key = "default-google.com-" + std::to_string(i); + ICING_ASSERT_OK(persistent_hash_map->Put(key, &i)); + existing_keys[key] = i; + } + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(existing_keys.size()))); + + // Delete several keys. + // Simulate with std::unordered_map and verify. + std::vector<int> delete_target_ids{3, 4, 6, 9, 13, 18, 24, 31, 39, 48, 58}; + for (const int delete_target_id : delete_target_ids) { + std::string key = "default-google.com-" + std::to_string(delete_target_id); + ASSERT_THAT(existing_keys, Contains(Key(key))); + ASSERT_THAT(GetValueByKey(persistent_hash_map.get(), key), + IsOkAndHolds(existing_keys[key])); + ICING_EXPECT_OK(persistent_hash_map->Delete(key)); + + existing_keys.erase(key); + deleted_keys.insert(key); + } + + // Deleted keys should not be found. + for (const std::string& deleted_key : deleted_keys) { + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), deleted_key), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + } + // Other keys should remain unchanged and available + for (const auto& [existing_key, existing_value] : existing_keys) { + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), existing_key), + IsOkAndHolds(existing_value)); + } + // Verify by iterator as well + EXPECT_THAT(GetAllKeyValuePairs(persistent_hash_map->GetIterator()), + Eq(existing_keys)); +} + +TEST_F(PersistentHashMapTest, DeleteBucketHeadElement) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing. + // Preventing rehashing makes it much easier to test collisions. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + ASSERT_THAT(persistent_hash_map->num_buckets(), Eq(1)); + + // Delete the head element of the bucket. Note that in our implementation, the + // last added element will become the head element of the bucket. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-2")); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-0"), + IsOkAndHolds(0)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-1"), + IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-2"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + +TEST_F(PersistentHashMapTest, DeleteBucketIntermediateElement) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing. + // Preventing rehashing makes it much easier to test collisions. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + ASSERT_THAT(persistent_hash_map->num_buckets(), Eq(1)); + + // Delete any intermediate element of the bucket. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-1")); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-0"), + IsOkAndHolds(0)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-1"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-2"), + IsOkAndHolds(2)); +} + +TEST_F(PersistentHashMapTest, DeleteBucketTailElement) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing. + // Preventing rehashing makes it much easier to test collisions. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + ASSERT_THAT(persistent_hash_map->num_buckets(), Eq(1)); + + // Delete the last element of the bucket. Note that in our implementation, the + // first added element will become the tail element of the bucket. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-0")); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-0"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-1"), + IsOkAndHolds(1)); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com-2"), + IsOkAndHolds(2)); +} + +TEST_F(PersistentHashMapTest, DeleteBucketOnlySingleElement) { + // Create new persistent hash map + // Set max_load_factor_percent as 1000. Load factor percent is calculated as + // 100 * num_keys / num_buckets. Therefore, with 1 bucket (the initial # of + // buckets in an empty PersistentHashMap) and a max_load_factor_percent of + // 1000, we would allow the insertion of up to 10 keys before rehashing. + // Preventing rehashing makes it much easier to test collisions. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int), + /*max_load_factor_percent=*/1000)); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com", Serialize(100).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(1))); + + // Delete the only single element of the bucket. 
+ ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com")); + ASSERT_THAT(persistent_hash_map, Pointee(IsEmpty())); + EXPECT_THAT(GetValueByKey(persistent_hash_map.get(), "default-google.com"), + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); +} + TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) { // Create new persistent hash map ICING_ASSERT_OK_AND_ASSIGN( @@ -654,6 +894,127 @@ TEST_F(PersistentHashMapTest, ShouldFailIfKeyContainsTerminationCharacter) { StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); EXPECT_THAT(persistent_hash_map->Get(invalid_key_view, &val), StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); + EXPECT_THAT(persistent_hash_map->Delete(invalid_key_view), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(PersistentHashMapTest, EmptyHashMapIterator) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + EXPECT_FALSE(persistent_hash_map->GetIterator().Advance()); +} + +TEST_F(PersistentHashMapTest, Iterator) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + std::unordered_map<std::string, int> kvps; + // Insert 100 key value pairs + for (int i = 0; i < 100; ++i) { + std::string key = "default-google.com-" + std::to_string(i); + ICING_ASSERT_OK(persistent_hash_map->Put(key, &i)); + kvps.emplace(key, i); + } + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(kvps.size()))); + + EXPECT_THAT(GetAllKeyValuePairs(persistent_hash_map->GetIterator()), + Eq(kvps)); +} + +TEST_F(PersistentHashMapTest, IteratorAfterDeletingFirstKeyValuePair) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + + // Delete the first key value pair. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-0")); + EXPECT_THAT(GetAllKeyValuePairs(persistent_hash_map->GetIterator()), + UnorderedElementsAre(Pair("default-google.com-1", 1), + Pair("default-google.com-2", 2))); +} + +TEST_F(PersistentHashMapTest, IteratorAfterDeletingIntermediateKeyValuePair) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + + // Delete any intermediate key value pair. 
+ ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-1")); + EXPECT_THAT(GetAllKeyValuePairs(persistent_hash_map->GetIterator()), + UnorderedElementsAre(Pair("default-google.com-0", 0), + Pair("default-google.com-2", 2))); +} + +TEST_F(PersistentHashMapTest, IteratorAfterDeletingLastKeyValuePair) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + + // Delete the last key value pair. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-2")); + EXPECT_THAT(GetAllKeyValuePairs(persistent_hash_map->GetIterator()), + UnorderedElementsAre(Pair("default-google.com-0", 0), + Pair("default-google.com-1", 1))); +} + +TEST_F(PersistentHashMapTest, IteratorAfterDeletingAllKeyValuePairs) { + // Create new persistent hash map + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<PersistentHashMap> persistent_hash_map, + PersistentHashMap::Create(filesystem_, base_dir_, + /*value_type_size=*/sizeof(int))); + + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-0", Serialize(0).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-1", Serialize(1).data())); + ICING_ASSERT_OK( + persistent_hash_map->Put("default-google.com-2", Serialize(2).data())); + ASSERT_THAT(persistent_hash_map, Pointee(SizeIs(3))); + + // Delete all key value pairs. + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-0")); + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-1")); + ICING_ASSERT_OK(persistent_hash_map->Delete("default-google.com-2")); + ASSERT_THAT(persistent_hash_map, Pointee(IsEmpty())); + EXPECT_FALSE(persistent_hash_map->GetIterator().Advance()); } } // namespace diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 4089ec9..4bb7d55 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -1113,7 +1113,8 @@ DeleteByQueryResultProto IcingSearchEngine::DeleteByQuery( std::unique_ptr<QueryProcessor> query_processor = std::move(query_processor_or).ValueOrDie(); - auto query_results_or = query_processor->ParseSearch(search_spec); + auto query_results_or = query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::NONE); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); delete_stats->set_parse_query_latency_ms( @@ -1259,7 +1260,8 @@ OptimizeResultProto IcingSearchEngine::Optimize() { optimize_stats->set_index_restoration_mode( OptimizeStatsProto::INDEX_TRANSLATION); libtextclassifier3::Status index_optimize_status = - index_->Optimize(document_id_old_to_new_or.ValueOrDie()); + index_->Optimize(document_id_old_to_new_or.ValueOrDie(), + document_store_->last_added_document_id()); if (!index_optimize_status.ok()) { ICING_LOG(WARNING) << "Failed to optimize index. 
Error: " << index_optimize_status.error_message(); @@ -1487,20 +1489,22 @@ SearchResultProto IcingSearchEngine::Search( const ResultSpecProto& result_spec) { SearchResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); + + QueryStatsProto* query_stats = result_proto.mutable_query_stats(); + query_stats->set_query_length(search_spec.query().length()); + ScopedTimer overall_timer(clock_->GetNewTimer(), [query_stats](int64_t t) { + query_stats->set_latency_ms(t); + }); // TODO(b/146008613) Explore ideas to make this function read-only. absl_ports::unique_lock l(&mutex_); + query_stats->set_lock_acquisition_latency_ms( + overall_timer.timer().GetElapsedMilliseconds()); if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); return result_proto; } - QueryStatsProto* query_stats = result_proto.mutable_query_stats(); - query_stats->set_query_length(search_spec.query().length()); - ScopedTimer overall_timer(clock_->GetNewTimer(), [query_stats](int64_t t) { - query_stats->set_latency_ms(t); - }); - libtextclassifier3::Status status = ValidateResultSpec(result_spec); if (!status.ok()) { TransformStatus(status, result_status); @@ -1534,7 +1538,8 @@ SearchResultProto IcingSearchEngine::Search( std::unique_ptr<QueryProcessor> query_processor = std::move(query_processor_or).ValueOrDie(); - auto query_results_or = query_processor->ParseSearch(search_spec); + auto query_results_or = + query_processor->ParseSearch(search_spec, scoring_spec.rank_by()); if (!query_results_or.ok()) { TransformStatus(query_results_or.status(), result_status); query_stats->set_parse_query_latency_ms( @@ -1643,19 +1648,20 @@ SearchResultProto IcingSearchEngine::GetNextPage(uint64_t next_page_token) { SearchResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); + QueryStatsProto* query_stats = result_proto.mutable_query_stats(); + query_stats->set_is_first_page(false); + std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); // ResultStateManager has its own writer lock, so here we only need a reader // lock for other components. absl_ports::shared_lock l(&mutex_); + query_stats->set_lock_acquisition_latency_ms( + overall_timer->GetElapsedMilliseconds()); if (!initialized_) { result_status->set_code(StatusProto::FAILED_PRECONDITION); result_status->set_message("IcingSearchEngine has not been initialized!"); return result_proto; } - QueryStatsProto* query_stats = result_proto.mutable_query_stats(); - query_stats->set_is_first_page(false); - - std::unique_ptr<Timer> overall_timer = clock_->GetNewTimer(); auto result_retriever_or = ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), language_segmenter_.get(), normalizer_.get()); diff --git a/icing/icing-search-engine_backwards_compatibility_test.cc b/icing/icing-search-engine_backwards_compatibility_test.cc new file mode 100644 index 0000000..2574313 --- /dev/null +++ b/icing/icing-search-engine_backwards_compatibility_test.cc @@ -0,0 +1,395 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstdint> +#include <limits> +#include <memory> +#include <string> +#include <utility> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/absl_ports/str_cat.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/icing-search-engine.h" +#include "icing/portable/endian.h" +#include "icing/portable/equals-proto.h" +#include "icing/portable/platform.h" +#include "icing/schema-builder.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { + +using ::icing::lib::portable_equals_proto::EqualsProto; +using ::testing::Eq; + +constexpr TermMatchType::Code MATCH_EXACT = TermMatchType::EXACT_ONLY; +constexpr PropertyConfigProto::Cardinality::Code CARDINALITY_OPTIONAL = + PropertyConfigProto::Cardinality::OPTIONAL; +constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_PLAIN = + StringIndexingConfig::TokenizerType::PLAIN; + +// For mocking purpose, we allow tests to provide a custom Filesystem. +class TestIcingSearchEngine : public IcingSearchEngine { + public: + TestIcingSearchEngine(const IcingSearchEngineOptions& options, + std::unique_ptr<const Filesystem> filesystem, + std::unique_ptr<const IcingFilesystem> icing_filesystem, + std::unique_ptr<Clock> clock, + std::unique_ptr<JniCache> jni_cache) + : IcingSearchEngine(options, std::move(filesystem), + std::move(icing_filesystem), std::move(clock), + std::move(jni_cache)) {} +}; + +std::string GetTestBaseDir() { return GetTestTempDir() + "/icing"; } + +class IcingSearchEngineBackwardsCompatibilityTest : public testing::Test { + protected: + void SetUp() override { + filesystem_.CreateDirectoryRecursively(GetTestBaseDir().c_str()); + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(GetTestBaseDir().c_str()); + } + + const Filesystem* filesystem() const { return &filesystem_; } + + private: + Filesystem filesystem_; +}; + +ScoringSpecProto GetDefaultScoringSpec() { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + return scoring_spec; +} + +std::string GetTestDataDir(std::string_view test_subdir) { + if (IsAndroidX86()) { + return GetTestFilePath( + absl_ports::StrCat("icing/testdata/", test_subdir, + "/icing_search_engine_android_x86")); + } else if (IsAndroidArm()) { + return GetTestFilePath( + absl_ports::StrCat("icing/testdata/", test_subdir, + "/icing_search_engine_android_arm")); + } else if (IsIosPlatform()) { + return GetTestFilePath(absl_ports::StrCat("icing/testdata/", + test_subdir, + "/icing_search_engine_ios")); + } else { + return GetTestFilePath(absl_ports::StrCat("icing/testdata/", + test_subdir, + "/icing_search_engine_linux")); + } +} + +TEST_F(IcingSearchEngineBackwardsCompatibilityTest, + MigrateToPortableFileBackedProtoLog) { + // Copy the testdata files into our IcingSearchEngine directory + std::string dir_without_portable_log = GetTestDataDir("not_portable_log"); + + // Create dst 
directory that we'll initialize the IcingSearchEngine over. + std::string base_dir = GetTestBaseDir() + "_migrate"; + ASSERT_THAT(filesystem()->DeleteDirectoryRecursively(base_dir.c_str()), true); + ASSERT_THAT(filesystem()->CreateDirectoryRecursively(base_dir.c_str()), true); + + ASSERT_TRUE(filesystem()->CopyDirectory(dir_without_portable_log.c_str(), + base_dir.c_str(), + /*recursive=*/true)); + + IcingSearchEngineOptions icing_options; + icing_options.set_base_dir(base_dir); + + IcingSearchEngine icing(icing_options, GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + EXPECT_THAT(init_result.status(), ProtoIsOk()); + EXPECT_THAT(init_result.initialize_stats().document_store_data_status(), + Eq(InitializeStatsProto::NO_DATA_LOSS)); + EXPECT_THAT(init_result.initialize_stats().document_store_recovery_cause(), + Eq(InitializeStatsProto::LEGACY_DOCUMENT_LOG_FORMAT)); + EXPECT_THAT(init_result.initialize_stats().schema_store_recovery_cause(), + Eq(InitializeStatsProto::NONE)); + // The main and lite indexes are in legacy formats and therefore will need to + // be rebuilt from scratch. + EXPECT_THAT(init_result.initialize_stats().index_restoration_cause(), + Eq(InitializeStatsProto::IO_ERROR)); + + // Set up schema, this is the one used to validate documents in the testdata + // files. Do not change unless you're also updating the testdata files. + SchemaProto schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Make sure our schema is still the same as we expect. If not, there's + // definitely no way we're getting the documents back that we expect. + GetSchemaResultProto expected_get_schema_result_proto; + expected_get_schema_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_schema_result_proto.mutable_schema() = schema; + ASSERT_THAT(icing.GetSchema(), EqualsProto(expected_get_schema_result_proto)); + + // These are the documents that are stored in the testdata files. Do not + // change unless you're also updating the testdata files. + DocumentProto document1 = DocumentBuilder() + .SetKey("namespace1", "uri1") + .SetSchema("email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "foo") + .AddStringProperty("body", "bar") + .Build(); + + DocumentProto document2 = DocumentBuilder() + .SetKey("namespace1", "uri2") + .SetSchema("email") + .SetCreationTimestampMs(20) + .SetScore(321) + .AddStringProperty("body", "baz bat") + .Build(); + + DocumentProto document3 = DocumentBuilder() + .SetKey("namespace2", "uri1") + .SetSchema("email") + .SetCreationTimestampMs(30) + .SetScore(123) + .AddStringProperty("subject", "phoo") + .Build(); + + // Document 1 and 3 were put normally, and document 2 was deleted in our + // testdata files. 
+ EXPECT_THAT(icing + .Get(document1.namespace_(), document1.uri(), + GetResultSpecProto::default_instance()) + .document(), + EqualsProto(document1)); + EXPECT_THAT(icing + .Get(document2.namespace_(), document2.uri(), + GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + EXPECT_THAT(icing + .Get(document3.namespace_(), document3.uri(), + GetResultSpecProto::default_instance()) + .document(), + EqualsProto(document3)); + + // Searching for "foo" should get us document1. + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query("foo"); + + SearchResultProto expected_document1; + expected_document1.mutable_status()->set_code(StatusProto::OK); + *expected_document1.mutable_results()->Add()->mutable_document() = document1; + + SearchResultProto actual_results = + icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, + EqualsSearchResultIgnoreStatsAndScores(expected_document1)); + + // Searching for "baz" would've gotten us document2, except it got deleted. + // Make sure that it's cleared from our index too. + search_spec.set_query("baz"); + + SearchResultProto expected_no_documents; + expected_no_documents.mutable_status()->set_code(StatusProto::OK); + + actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, + EqualsSearchResultIgnoreStatsAndScores(expected_no_documents)); + + // Searching for "phoo" should get us document3. + search_spec.set_query("phoo"); + + SearchResultProto expected_document3; + expected_document3.mutable_status()->set_code(StatusProto::OK); + *expected_document3.mutable_results()->Add()->mutable_document() = document3; + + actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, + EqualsSearchResultIgnoreStatsAndScores(expected_document3)); +} + +TEST_F(IcingSearchEngineBackwardsCompatibilityTest, MigrateToLargerScale) { + // Copy the testdata files into our IcingSearchEngine directory + std::string test_data_dir = GetTestDataDir("icing_scale_migration"); + + // Create dst directory that we'll initialize the IcingSearchEngine over. + std::string base_dir = GetTestBaseDir() + "_migrate"; + ASSERT_THAT(filesystem()->DeleteDirectoryRecursively(base_dir.c_str()), true); + ASSERT_THAT(filesystem()->CreateDirectoryRecursively(base_dir.c_str()), true); + + ASSERT_TRUE(filesystem()->CopyDirectory(test_data_dir.c_str(), + base_dir.c_str(), + /*recursive=*/true)); + + IcingSearchEngineOptions icing_options; + icing_options.set_base_dir(base_dir); + + IcingSearchEngine icing(icing_options, GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + EXPECT_THAT(init_result.status(), ProtoIsOk()); + EXPECT_THAT(init_result.initialize_stats().document_store_data_status(), + Eq(InitializeStatsProto::NO_DATA_LOSS)); + // No recovery is required for the document store. + EXPECT_THAT(init_result.initialize_stats().document_store_recovery_cause(), + Eq(InitializeStatsProto::NONE)); + EXPECT_THAT(init_result.initialize_stats().schema_store_recovery_cause(), + Eq(InitializeStatsProto::NONE)); + // The main and lite indexes are in legacy formats and therefore will need to + // be rebuilt from scratch. 
+ EXPECT_THAT(init_result.initialize_stats().index_restoration_cause(), + Eq(InitializeStatsProto::IO_ERROR)); + + // Verify that the schema stored in the index matches the one that we expect. + // Do not change unless you're also updating the testdata files. + SchemaProto expected_schema = + SchemaBuilder() + .AddType(SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .Build(); + + // Make sure our schema is still the same as we expect. If not, there's + // definitely no way we're getting the documents back that we expect. + GetSchemaResultProto expected_get_schema_result_proto; + expected_get_schema_result_proto.mutable_status()->set_code(StatusProto::OK); + *expected_get_schema_result_proto.mutable_schema() = expected_schema; + ASSERT_THAT(icing.GetSchema(), EqualsProto(expected_get_schema_result_proto)); + + // These are the documents that are stored in the testdata files. Do not + // change unless you're also updating the testdata files. + DocumentProto expected_document1 = DocumentBuilder() + .SetKey("namespace1", "uri1") + .SetSchema("email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "foo") + .AddStringProperty("body", "bar") + .Build(); + + DocumentProto expected_deleted_document2 = + DocumentBuilder() + .SetKey("namespace1", "uri2") + .SetSchema("email") + .SetCreationTimestampMs(20) + .SetScore(321) + .AddStringProperty("body", "baz bat") + .Build(); + + DocumentProto expected_document3 = DocumentBuilder() + .SetKey("namespace2", "uri1") + .SetSchema("email") + .SetCreationTimestampMs(30) + .SetScore(123) + .AddStringProperty("subject", "phoo") + .Build(); + + // Documents 1 and 3 were put normally, and document 2 was deleted in our + // testdata files. + EXPECT_THAT( + icing + .Get(expected_document1.namespace_(), expected_document1.uri(), + GetResultSpecProto::default_instance()) + .document(), + EqualsProto(expected_document1)); + EXPECT_THAT(icing + .Get(expected_deleted_document2.namespace_(), + expected_deleted_document2.uri(), + GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + EXPECT_THAT( + icing + .Get(expected_document3.namespace_(), expected_document3.uri(), + GetResultSpecProto::default_instance()) + .document(), + EqualsProto(expected_document3)); + + // Searching for "foo" should get us document1. + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query("foo"); + + SearchResultProto expected_document1_search; + expected_document1_search.mutable_status()->set_code(StatusProto::OK); + *expected_document1_search.mutable_results()->Add()->mutable_document() = + expected_document1; + + SearchResultProto actual_results = + icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, EqualsSearchResultIgnoreStatsAndScores( + expected_document1_search)); + + // Searching for "baz" would've gotten us document2, except it got deleted. + // Make sure that it's cleared from our index too.
+ search_spec.set_query("baz"); + + SearchResultProto expected_no_documents; + expected_no_documents.mutable_status()->set_code(StatusProto::OK); + + actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, + EqualsSearchResultIgnoreStatsAndScores(expected_no_documents)); + + // Searching for "phoo" should get us document3. + search_spec.set_query("phoo"); + + SearchResultProto expected_document3_search; + expected_document3_search.mutable_status()->set_code(StatusProto::OK); + *expected_document3_search.mutable_results()->Add()->mutable_document() = + expected_document3; + + actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, EqualsSearchResultIgnoreStatsAndScores( + expected_document3_search)); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index 2ac456e..699e573 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -102,10 +102,6 @@ constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_PLAIN = constexpr StringIndexingConfig::TokenizerType::Code TOKENIZER_NONE = StringIndexingConfig::TokenizerType::NONE; -#ifndef ICING_JNI_TEST -constexpr TermMatchType::Code MATCH_EXACT = TermMatchType::EXACT_ONLY; -#endif // !ICING_JNI_TEST - constexpr TermMatchType::Code MATCH_PREFIX = TermMatchType::PREFIX; constexpr TermMatchType::Code MATCH_NONE = TermMatchType::UNKNOWN; @@ -2312,6 +2308,8 @@ TEST_F(IcingSearchEngineTest, SearchReturnsOneResult) { EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(1000)); EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), Eq(1000)); + EXPECT_THAT(search_result_proto.query_stats().lock_acquisition_latency_ms(), + Eq(1000)); // The token is a random number so we don't verify it. 
expected_search_result_proto.set_next_page_token( @@ -2470,6 +2468,8 @@ TEST_F(IcingSearchEngineTest, SearchShouldReturnEmpty) { EXPECT_THAT(search_result_proto.query_stats().ranking_latency_ms(), Eq(0)); EXPECT_THAT(search_result_proto.query_stats().document_retrieval_latency_ms(), Eq(0)); + EXPECT_THAT(search_result_proto.query_stats().lock_acquisition_latency_ms(), + Eq(1000)); EXPECT_THAT(search_result_proto, EqualsSearchResultIgnoreStatsAndScores( expected_search_result_proto)); @@ -3003,6 +3003,54 @@ TEST_F(IcingSearchEngineTest, GetAndPutShouldWorkAfterOptimization) { EXPECT_THAT(icing.Put(document5).status(), ProtoIsOk()); } +TEST_F(IcingSearchEngineTest, + GetAndPutShouldWorkAfterOptimizationWithEmptyDocuments) { + DocumentProto empty_document1 = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto empty_document2 = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + DocumentProto empty_document3 = + DocumentBuilder() + .SetKey("namespace", "uri3") + .SetSchema("Message") + .AddStringProperty("body", "") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + GetResultProto expected_get_result_proto; + expected_get_result_proto.mutable_status()->set_code(StatusProto::OK); + + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + ASSERT_THAT(icing.Put(empty_document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(empty_document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Delete("namespace", "uri2").status(), ProtoIsOk()); + ASSERT_THAT(icing.Optimize().status(), ProtoIsOk()); + + // Validates that Get() and Put() are good right after Optimize() + *expected_get_result_proto.mutable_document() = empty_document1; + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()), + EqualsProto(expected_get_result_proto)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status() + .code(), + Eq(StatusProto::NOT_FOUND)); + EXPECT_THAT(icing.Put(empty_document3).status(), ProtoIsOk()); +} + TEST_F(IcingSearchEngineTest, DeleteShouldWorkAfterOptimization) { DocumentProto document1 = CreateMessageDocument("namespace", "uri1"); DocumentProto document2 = CreateMessageDocument("namespace", "uri2"); @@ -6316,30 +6364,64 @@ TEST_F(IcingSearchEngineTest, SnippetSectionRestrict) { .Build(); ASSERT_THAT(icing.Put(document_one).status(), ProtoIsOk()); - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::PREFIX); - search_spec.set_query("body:Zür"); + DocumentProto document_two = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Email") + .AddStringProperty("subject", "MDI zurich trip") + .AddStringProperty("body", "Let's travel to zurich") + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); + ASSERT_THAT(icing.Put(document_two).status(), ProtoIsOk()); - ResultSpecProto result_spec; - result_spec.mutable_snippet_spec()->set_max_window_utf32_length(64); - result_spec.mutable_snippet_spec()->set_num_matches_per_property(10); - result_spec.mutable_snippet_spec()->set_num_to_snippet(10); + auto search_spec = std::make_unique<SearchSpecProto>(); + 
search_spec->set_term_match_type(TermMatchType::PREFIX); + search_spec->set_query("body:Zür"); + + auto result_spec = std::make_unique<ResultSpecProto>(); + result_spec->set_num_per_page(1); + result_spec->mutable_snippet_spec()->set_max_window_utf32_length(64); + result_spec->mutable_snippet_spec()->set_num_matches_per_property(10); + result_spec->mutable_snippet_spec()->set_num_to_snippet(10); + + auto scoring_spec = std::make_unique<ScoringSpecProto>(); + *scoring_spec = GetDefaultScoringSpec(); SearchResultProto results = - icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + icing.Search(*search_spec, *scoring_spec, *result_spec); EXPECT_THAT(results.status(), ProtoIsOk()); ASSERT_THAT(results.results(), SizeIs(1)); - const DocumentProto& result_document = results.results(0).document(); - const SnippetProto& result_snippet = results.results(0).snippet(); - EXPECT_THAT(result_document, EqualsProto(document_one)); - EXPECT_THAT(result_snippet.entries(), SizeIs(1)); - EXPECT_THAT(result_snippet.entries(0).property_name(), Eq("body")); - std::string_view content = - GetString(&result_document, result_snippet.entries(0).property_name()); - EXPECT_THAT(GetWindows(content, result_snippet.entries(0)), + const DocumentProto& result_document_two = results.results(0).document(); + const SnippetProto& result_snippet_two = results.results(0).snippet(); + EXPECT_THAT(result_document_two, EqualsProto(document_two)); + EXPECT_THAT(result_snippet_two.entries(), SizeIs(1)); + EXPECT_THAT(result_snippet_two.entries(0).property_name(), Eq("body")); + std::string_view content = GetString( + &result_document_two, result_snippet_two.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_two.entries(0)), + ElementsAre("Let's travel to zurich")); + EXPECT_THAT(GetMatches(content, result_snippet_two.entries(0)), + ElementsAre("zurich")); + + search_spec.reset(); + scoring_spec.reset(); + result_spec.reset(); + + results = icing.GetNextPage(results.next_page_token()); + EXPECT_THAT(results.status(), ProtoIsOk()); + ASSERT_THAT(results.results(), SizeIs(1)); + + const DocumentProto& result_document_one = results.results(0).document(); + const SnippetProto& result_snippet_one = results.results(0).snippet(); + EXPECT_THAT(result_document_one, EqualsProto(document_one)); + EXPECT_THAT(result_snippet_one.entries(), SizeIs(1)); + EXPECT_THAT(result_snippet_one.entries(0).property_name(), Eq("body")); + content = GetString(&result_document_one, + result_snippet_one.entries(0).property_name()); + EXPECT_THAT(GetWindows(content, result_snippet_one.entries(0)), ElementsAre("MDI zurich Team Meeting")); - EXPECT_THAT(GetMatches(content, result_snippet.entries(0)), + EXPECT_THAT(GetMatches(content, result_snippet_one.entries(0)), ElementsAre("zurich")); } @@ -7763,25 +7845,30 @@ TEST_F(IcingSearchEngineTest, SearchWithProjectionMultipleFieldPaths) { // 2. Issue a query that will match those documents and request only // 'sender.name' and 'subject' properties. - SearchSpecProto search_spec; - search_spec.set_term_match_type(TermMatchType::PREFIX); - search_spec.set_query("hello"); - - ResultSpecProto result_spec; + // Create all of search_spec, result_spec and scoring_spec as objects with + // scope that will end before the call to GetNextPage to ensure that the + // implementation isn't relying on references to any of them. 
+ auto search_spec = std::make_unique<SearchSpecProto>(); + search_spec->set_term_match_type(TermMatchType::PREFIX); + search_spec->set_query("hello"); + + auto result_spec = std::make_unique<ResultSpecProto>(); // Retrieve only one result at a time to make sure that projection works when // retrieving all pages. - result_spec.set_num_per_page(1); - TypePropertyMask* email_field_mask = result_spec.add_type_property_masks(); + result_spec->set_num_per_page(1); + TypePropertyMask* email_field_mask = result_spec->add_type_property_masks(); email_field_mask->set_schema_type("Email"); email_field_mask->add_paths("sender.name"); email_field_mask->add_paths("subject"); + auto scoring_spec = std::make_unique<ScoringSpecProto>(); + *scoring_spec = GetDefaultScoringSpec(); SearchResultProto results = - icing.Search(search_spec, GetDefaultScoringSpec(), result_spec); + icing.Search(*search_spec, *scoring_spec, *result_spec); EXPECT_THAT(results.status(), ProtoIsOk()); EXPECT_THAT(results.results(), SizeIs(1)); - // 3. Verify that the returned results only contain the 'sender.name' + // 3. Verify that the first returned result only contains the 'sender.name' // property. DocumentProto projected_document_two = DocumentBuilder() @@ -7799,6 +7886,14 @@ TEST_F(IcingSearchEngineTest, SearchWithProjectionMultipleFieldPaths) { EXPECT_THAT(results.results(0).document(), EqualsProto(projected_document_two)); + // 4. Now, delete all of the specs used in the search. GetNextPage should have + // no problem because it shouldn't be keeping any references to them. + search_spec.reset(); + result_spec.reset(); + scoring_spec.reset(); + + // 5. Verify that the second returned result only contains the 'sender.name' + // property. results = icing.GetNextPage(results.next_page_token()); EXPECT_THAT(results.status(), ProtoIsOk()); EXPECT_THAT(results.results(), SizeIs(1)); @@ -7882,6 +7977,7 @@ TEST_F(IcingSearchEngineTest, QueryStatsProtoTest) { exp_stats.set_scoring_latency_ms(5); exp_stats.set_ranking_latency_ms(5); exp_stats.set_document_retrieval_latency_ms(5); + exp_stats.set_lock_acquisition_latency_ms(5); EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats)); // Second page, 2 result with 1 snippet @@ -7897,6 +7993,7 @@ TEST_F(IcingSearchEngineTest, QueryStatsProtoTest) { exp_stats.set_num_results_with_snippets(1); exp_stats.set_latency_ms(5); exp_stats.set_document_retrieval_latency_ms(5); + exp_stats.set_lock_acquisition_latency_ms(5); EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats)); // Third page, 1 result with 0 snippets @@ -7912,6 +8009,7 @@ TEST_F(IcingSearchEngineTest, QueryStatsProtoTest) { exp_stats.set_num_results_with_snippets(0); exp_stats.set_latency_ms(5); exp_stats.set_document_retrieval_latency_ms(5); + exp_stats.set_lock_acquisition_latency_ms(5); EXPECT_THAT(search_result.query_stats(), EqualsProto(exp_stats)); } @@ -8860,170 +8958,219 @@ TEST_F(IcingSearchEngineTest, GetDebugInfoWithSchemaNoDocumentsSucceeds) { ASSERT_THAT(result.status(), ProtoIsOk()); } -#ifndef ICING_JNI_TEST -// We skip this test case when we're running in a jni_test since the data files -// will be stored in the android-instrumented storage location, rather than the -// normal cc_library runfiles directory. To get that storage location, it's -// recommended to use the TestStorage APIs which handles different API -// levels/absolute vs relative/etc differences. 
Since that's only accessible on -// the java-side, and I haven't figured out a way to pass that directory path to -// this native side yet, we're just going to disable this. The functionality is -// already well-tested across 4 different emulated OS's so we're not losing much -// test coverage here. -TEST_F(IcingSearchEngineTest, MigrateToPortableFileBackedProtoLog) { - // Copy the testdata files into our IcingSearchEngine directory - std::string dir_without_portable_log; - if (IsAndroidX86()) { - dir_without_portable_log = GetTestFilePath( - "icing/testdata/not_portable_log/" - "icing_search_engine_android_x86"); - } else if (IsAndroidArm()) { - dir_without_portable_log = GetTestFilePath( - "icing/testdata/not_portable_log/" - "icing_search_engine_android_arm"); - } else if (IsIosPlatform()) { - dir_without_portable_log = GetTestFilePath( - "icing/testdata/not_portable_log/" - "icing_search_engine_ios"); - } else { - dir_without_portable_log = GetTestFilePath( - "icing/testdata/not_portable_log/" - "icing_search_engine_linux"); - } - - // Create dst directory that we'll initialize the IcingSearchEngine over. - std::string base_dir = GetTestBaseDir() + "_migrate"; - ASSERT_THAT(filesystem()->DeleteDirectoryRecursively(base_dir.c_str()), true); - ASSERT_THAT(filesystem()->CreateDirectoryRecursively(base_dir.c_str()), true); - - ASSERT_TRUE(filesystem()->CopyDirectory(dir_without_portable_log.c_str(), - base_dir.c_str(), - /*recursive=*/true)); - - IcingSearchEngineOptions icing_options; - icing_options.set_base_dir(base_dir); - - IcingSearchEngine icing(icing_options, GetTestJniCache()); - InitializeResultProto init_result = icing.Initialize(); - EXPECT_THAT(init_result.status(), ProtoIsOk()); - EXPECT_THAT(init_result.initialize_stats().document_store_data_status(), - Eq(InitializeStatsProto::NO_DATA_LOSS)); - EXPECT_THAT(init_result.initialize_stats().document_store_recovery_cause(), - Eq(InitializeStatsProto::LEGACY_DOCUMENT_LOG_FORMAT)); - EXPECT_THAT(init_result.initialize_stats().schema_store_recovery_cause(), - Eq(InitializeStatsProto::NONE)); - EXPECT_THAT(init_result.initialize_stats().index_restoration_cause(), - Eq(InitializeStatsProto::NONE)); - - // Set up schema, this is the one used to validate documents in the testdata - // files. Do not change unless you're also updating the testdata files. +TEST_F(IcingSearchEngineTest, IcingShouldWorkFor64Sections) { + // Create a schema with 64 sections SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder() - .SetType("email") + // Person has 4 sections. + .SetType("Person") .AddProperty( PropertyConfigBuilder() - .SetName("subject") - .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetName("firstName") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("lastName") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) .SetCardinality(CARDINALITY_OPTIONAL)) .AddProperty( PropertyConfigBuilder() + .SetName("emailAddress") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("phoneNumber") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + // Email has 16 sections. 
+ .SetType("Email") + .AddProperty( + PropertyConfigBuilder() .SetName("body") - .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("date") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("time") + .SetDataTypeString(MATCH_PREFIX, TOKENIZER_PLAIN) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("receiver") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("cc") + .SetDataTypeDocument( + "Person", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_REPEATED))) + .AddType(SchemaTypeConfigBuilder() + // EmailCollection has 64 sections. + .SetType("EmailCollection") + .AddProperty( + PropertyConfigBuilder() + .SetName("email1") + .SetDataTypeDocument( + "Email", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("email2") + .SetDataTypeDocument( + "Email", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("email3") + .SetDataTypeDocument( + "Email", /*index_nested_properties=*/true) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("email4") + .SetDataTypeDocument( + "Email", /*index_nested_properties=*/true) .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); - // Make sure our schema is still the same as we expect. If not, there's - // definitely no way we're getting the documents back that we expect. - GetSchemaResultProto expected_get_schema_result_proto; - expected_get_schema_result_proto.mutable_status()->set_code(StatusProto::OK); - *expected_get_schema_result_proto.mutable_schema() = schema; - ASSERT_THAT(icing.GetSchema(), EqualsProto(expected_get_schema_result_proto)); - - // These are the documents that are stored in the testdata files. Do not - // change unless you're also updating the testdata files. 
- DocumentProto document1 = DocumentBuilder() - .SetKey("namespace1", "uri1") - .SetSchema("email") - .SetCreationTimestampMs(10) - .AddStringProperty("subject", "foo") - .AddStringProperty("body", "bar") - .Build(); - - DocumentProto document2 = DocumentBuilder() - .SetKey("namespace1", "uri2") - .SetSchema("email") - .SetCreationTimestampMs(20) - .SetScore(321) - .AddStringProperty("body", "baz bat") - .Build(); - - DocumentProto document3 = DocumentBuilder() - .SetKey("namespace2", "uri1") - .SetSchema("email") - .SetCreationTimestampMs(30) - .SetScore(123) - .AddStringProperty("subject", "phoo") - .Build(); + DocumentProto person1 = + DocumentBuilder() + .SetKey("namespace", "person1") + .SetSchema("Person") + .AddStringProperty("firstName", "first1") + .AddStringProperty("lastName", "last1") + .AddStringProperty("emailAddress", "email1@gmail.com") + .AddStringProperty("phoneNumber", "000-000-001") + .Build(); + DocumentProto person2 = + DocumentBuilder() + .SetKey("namespace", "person2") + .SetSchema("Person") + .AddStringProperty("firstName", "first2") + .AddStringProperty("lastName", "last2") + .AddStringProperty("emailAddress", "email2@gmail.com") + .AddStringProperty("phoneNumber", "000-000-002") + .Build(); + DocumentProto person3 = + DocumentBuilder() + .SetKey("namespace", "person3") + .SetSchema("Person") + .AddStringProperty("firstName", "first3") + .AddStringProperty("lastName", "last3") + .AddStringProperty("emailAddress", "email3@gmail.com") + .AddStringProperty("phoneNumber", "000-000-003") + .Build(); + DocumentProto email1 = DocumentBuilder() + .SetKey("namespace", "email1") + .SetSchema("Email") + .AddStringProperty("body", "test body") + .AddStringProperty("subject", "test subject") + .AddStringProperty("date", "2022-08-01") + .AddStringProperty("time", "1:00 PM") + .AddDocumentProperty("sender", person1) + .AddDocumentProperty("receiver", person2) + .AddDocumentProperty("cc", person3) + .Build(); + DocumentProto email2 = DocumentBuilder() + .SetKey("namespace", "email2") + .SetSchema("Email") + .AddStringProperty("body", "test body") + .AddStringProperty("subject", "test subject") + .AddStringProperty("date", "2022-08-02") + .AddStringProperty("time", "2:00 PM") + .AddDocumentProperty("sender", person2) + .AddDocumentProperty("receiver", person1) + .AddDocumentProperty("cc", person3) + .Build(); + DocumentProto email3 = DocumentBuilder() + .SetKey("namespace", "email3") + .SetSchema("Email") + .AddStringProperty("body", "test body") + .AddStringProperty("subject", "test subject") + .AddStringProperty("date", "2022-08-03") + .AddStringProperty("time", "3:00 PM") + .AddDocumentProperty("sender", person3) + .AddDocumentProperty("receiver", person1) + .AddDocumentProperty("cc", person2) + .Build(); + DocumentProto email4 = DocumentBuilder() + .SetKey("namespace", "email4") + .SetSchema("Email") + .AddStringProperty("body", "test body") + .AddStringProperty("subject", "test subject") + .AddStringProperty("date", "2022-08-04") + .AddStringProperty("time", "4:00 PM") + .AddDocumentProperty("sender", person3) + .AddDocumentProperty("receiver", person2) + .AddDocumentProperty("cc", person1) + .Build(); + DocumentProto email_collection = + DocumentBuilder() + .SetKey("namespace", "email_collection") + .SetSchema("EmailCollection") + .AddDocumentProperty("email1", email1) + .AddDocumentProperty("email2", email2) + .AddDocumentProperty("email3", email3) + .AddDocumentProperty("email4", email4) + .SetCreationTimestampMs(kDefaultCreationTimestampMs) + .Build(); - // Document 1 and 
3 were put normally, and document 2 was deleted in our - // testdata files. - EXPECT_THAT(icing - .Get(document1.namespace_(), document1.uri(), - GetResultSpecProto::default_instance()) - .document(), - EqualsProto(document1)); - EXPECT_THAT(icing - .Get(document2.namespace_(), document2.uri(), - GetResultSpecProto::default_instance()) - .status(), - ProtoStatusIs(StatusProto::NOT_FOUND)); - EXPECT_THAT(icing - .Get(document3.namespace_(), document3.uri(), - GetResultSpecProto::default_instance()) - .document(), - EqualsProto(document3)); + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(schema).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email_collection).status(), ProtoIsOk()); + + const std::vector<std::string> query_terms = { + "first1", "last2", "email3@gmail.com", "000-000-001", + "body", "subject", "2022-08-02", "3\\:00"}; + SearchResultProto expected_document; + expected_document.mutable_status()->set_code(StatusProto::OK); + *expected_document.mutable_results()->Add()->mutable_document() = + email_collection; + for (const std::string& query_term : query_terms) { + SearchSpecProto search_spec; + search_spec.set_term_match_type(TermMatchType::PREFIX); + search_spec.set_query(query_term); + SearchResultProto actual_results = + icing.Search(search_spec, GetDefaultScoringSpec(), + ResultSpecProto::default_instance()); + EXPECT_THAT(actual_results, + EqualsSearchResultIgnoreStatsAndScores(expected_document)); + } - // Searching for "foo" should get us document1. SearchSpecProto search_spec; search_spec.set_term_match_type(TermMatchType::PREFIX); search_spec.set_query("foo"); - - SearchResultProto expected_document1; - expected_document1.mutable_status()->set_code(StatusProto::OK); - *expected_document1.mutable_results()->Add()->mutable_document() = document1; - + SearchResultProto expected_no_documents; + expected_no_documents.mutable_status()->set_code(StatusProto::OK); SearchResultProto actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), ResultSpecProto::default_instance()); EXPECT_THAT(actual_results, - EqualsSearchResultIgnoreStatsAndScores(expected_document1)); - - // Searching for "baz" would've gotten us document2, except it got deleted. - // Make sure that it's cleared from our index too. - search_spec.set_query("baz"); - - SearchResultProto expected_no_documents; - expected_no_documents.mutable_status()->set_code(StatusProto::OK); - - actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), - ResultSpecProto::default_instance()); - EXPECT_THAT(actual_results, EqualsSearchResultIgnoreStatsAndScores(expected_no_documents)); - - // Searching for "phoo" should get us document3. 
- search_spec.set_query("phoo"); - - SearchResultProto expected_document3; - expected_document3.mutable_status()->set_code(StatusProto::OK); - *expected_document3.mutable_results()->Add()->mutable_document() = document3; - - actual_results = icing.Search(search_spec, GetDefaultScoringSpec(), - ResultSpecProto::default_instance()); - EXPECT_THAT(actual_results, - EqualsSearchResultIgnoreStatsAndScores(expected_document3)); } -#endif // !ICING_JNI_TEST } // namespace } // namespace lib diff --git a/icing/index/hit/doc-hit-info.cc b/icing/index/hit/doc-hit-info.cc deleted file mode 100644 index 8e418c8..0000000 --- a/icing/index/hit/doc-hit-info.cc +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "icing/index/hit/doc-hit-info.h" - -#include "icing/legacy/core/icing-string-util.h" - -namespace icing { -namespace lib { - -bool DocHitInfo::operator<(const DocHitInfo& other) const { - if (document_id() != other.document_id()) { - // Sort by document_id descending. This mirrors how the individual hits that - // are collapsed into this DocHitInfo would sort with other hits - - // document_ids are inverted when encoded in hits. Hits are encoded this way - // because they are appended to posting lists and the most recent value - // appended to a posting list must have the smallest encoded value of any - // hit on the posting list. - return document_id() > other.document_id(); - } - if (hit_section_ids_mask() != other.hit_section_ids_mask()) { - return hit_section_ids_mask() < other.hit_section_ids_mask(); - } - // Doesn't matter which way we compare this array, as long as - // DocHitInfo is unequal when it is unequal. - return memcmp(hit_term_frequency_, other.hit_term_frequency_, - sizeof(hit_term_frequency_)) < 0; -} - -void DocHitInfo::UpdateSection(SectionId section_id, - Hit::TermFrequency hit_term_frequency) { - SectionIdMask section_id_mask = (1u << section_id); - if ((hit_section_ids_mask() & section_id_mask)) { - // If the sectionId is already embedded in the hit_section_ids_mask, - // then the term frequencies should always match. So there is no - // need to update anything. - return; - } - hit_term_frequency_[section_id] = hit_term_frequency; - hit_section_ids_mask_ |= section_id_mask; -} - -void DocHitInfo::MergeSectionsFrom(const DocHitInfo& other) { - SectionIdMask other_mask = other.hit_section_ids_mask(); - while (other_mask) { - SectionId section_id = __builtin_ctz(other_mask); - UpdateSection(section_id, other.hit_term_frequency(section_id)); - other_mask &= ~(1u << section_id); - } -} - -} // namespace lib -} // namespace icing diff --git a/icing/index/hit/doc-hit-info.h b/icing/index/hit/doc-hit-info.h index 0be87d6..2770de2 100644 --- a/icing/index/hit/doc-hit-info.h +++ b/icing/index/hit/doc-hit-info.h @@ -26,19 +26,15 @@ namespace icing { namespace lib { // DocHitInfo provides a collapsed view of all hits for a specific doc. -// Hits contain a document_id, section_id and a term frequency. 
The -// information in multiple hits is collapse into a DocHitInfo by providing a -// SectionIdMask of all sections that contained a hit for this term as well as -// the highest term frequency of any hit for each section. +// Hits contain a document_id and section_id. The information in multiple hits +// is collapsed into a DocHitInfo by providing a SectionIdMask of all sections +// that contained a hit for this term. class DocHitInfo { public: explicit DocHitInfo(DocumentId document_id_in = kInvalidDocumentId, SectionIdMask hit_section_ids_mask = kSectionIdMaskNone) : document_id_(document_id_in), - hit_section_ids_mask_(hit_section_ids_mask) { - memset(hit_term_frequency_, Hit::kNoTermFrequency, - sizeof(hit_term_frequency_)); - } + hit_section_ids_mask_(hit_section_ids_mask) {} DocumentId document_id() const { return document_id_; } @@ -50,41 +46,44 @@ class DocHitInfo { hit_section_ids_mask_ = section_id_mask; } - Hit::TermFrequency hit_term_frequency(SectionId section_id) const { - return hit_term_frequency_[section_id]; + bool operator<(const DocHitInfo& other) const { + if (document_id() != other.document_id()) { + // Sort by document_id descending. This mirrors how the individual hits + // that are collapsed into this DocHitInfo would sort with other hits - + // document_ids are inverted when encoded in hits. Hits are encoded this + // way because they are appended to posting lists and the most recent + // value appended to a posting list must have the smallest encoded value + // of any hit on the posting list. + return document_id() > other.document_id(); + } + return hit_section_ids_mask() < other.hit_section_ids_mask(); } - - bool operator<(const DocHitInfo& other) const; bool operator==(const DocHitInfo& other) const { - return (*this < other) == (other < *this); + return document_id_ == other.document_id_ && + hit_section_ids_mask_ == other.hit_section_ids_mask_; } - // Updates the hit_section_ids_mask and hit_term_frequency for the - // section, if necessary. - void UpdateSection(SectionId section_id, - Hit::TermFrequency hit_term_frequency); + // Updates the hit_section_ids_mask for the section, if necessary. + void UpdateSection(SectionId section_id) { + hit_section_ids_mask_ |= (UINT64_C(1) << section_id); + } - // Merges the sections of other into this. The hit_section_ids_masks are or'd; - // if this.hit_term_frequency_[sectionId] has already been defined, - // other.hit_term_frequency_[sectionId] value is ignored. + // Merges the sections of other into this. The hit_section_ids_masks are or'd. // // This does not affect the DocumentId of this or other. If callers care about // only merging sections for DocHitInfos with the same DocumentId, callers // should check this themselves. - void MergeSectionsFrom(const DocHitInfo& other); + void MergeSectionsFrom(const SectionIdMask& other_hit_section_ids_mask) { + hit_section_ids_mask_ |= other_hit_section_ids_mask; + } private: DocumentId document_id_; SectionIdMask hit_section_ids_mask_; - Hit::TermFrequency hit_term_frequency_[kMaxSectionId + 1]; } __attribute__((packed)); -static_assert(sizeof(DocHitInfo) == 22, ""); +static_assert(sizeof(DocHitInfo) == 12, "");
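// (Size arithmetic: with the 16-byte term-frequency array removed and
// SectionIdMask widened from 16 to 64 bits, the packed layout is a 4-byte
// DocumentId plus an 8-byte SectionIdMask, i.e. 12 bytes rather than the
// previous 4 + 2 + 16 = 22.)
 // TODO(b/138991332) decide how to remove/replace all is_packed_pod assertions.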
static_assert(icing_is_packed_pod<DocHitInfo>::value, "go/icing-ubsan"); -static_assert( - sizeof(Hit::TermFrequency) == 1, - "Change how hit_term_frequency_ is initialized if changing the type " - "of Hit::TermFrequency"); } // namespace lib } // namespace icing diff --git a/icing/index/hit/doc-hit-info_test.cc b/icing/index/hit/doc-hit-info_test.cc index 36c1a06..13eca9a 100644 --- a/icing/index/hit/doc-hit-info_test.cc +++ b/icing/index/hit/doc-hit-info_test.cc @@ -14,142 +14,29 @@ #include "icing/index/hit/doc-hit-info.h" -#include "icing/index/hit/hit.h" -#include "icing/schema/section.h" -#include "icing/store/document-id.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/schema/section.h" +#include "icing/store/document-id.h" namespace icing { namespace lib { using ::testing::ElementsAre; -using ::testing::Eq; -using ::testing::IsTrue; using ::testing::Ne; -constexpr DocumentId kSomeDocumentId = 12; -constexpr DocumentId kSomeOtherDocumentId = 54; - -TEST(DocHitInfoTest, InitialMaxHitTermFrequencies) { - DocHitInfo info(kSomeDocumentId); - for (SectionId i = 0; i <= kMaxSectionId; ++i) { - EXPECT_THAT(info.hit_term_frequency(i), Eq(Hit::kNoTermFrequency)); - } -} - -TEST(DocHitInfoTest, UpdateHitTermFrequenciesForTheFirstTime) { - DocHitInfo info(kSomeDocumentId); - ASSERT_THAT(info.hit_term_frequency(3), Eq(Hit::kNoTermFrequency)); - - // Updating a section for the first time, should change its hit - // term_frequency - info.UpdateSection(3, 16); - EXPECT_THAT(info.hit_term_frequency(3), Eq(16)); -} - -TEST(DocHitInfoTest, UpdateSectionLowerHitTermFrequencyHasNoEffect) { - DocHitInfo info(kSomeDocumentId); - info.UpdateSection(3, 16); - ASSERT_THAT(info.hit_term_frequency(3), Eq(16)); - - // Updating a section with a term frequency lower than the previously set - // one should have no effect. - info.UpdateSection(3, 15); - EXPECT_THAT(info.hit_term_frequency(3), Eq(16)); -} - -TEST(DocHitInfoTest, UpdateSectionHigherHitTermFrequencyHasNoEffect) { - DocHitInfo info(kSomeDocumentId); - info.UpdateSection(3, 16); - ASSERT_THAT(info.hit_term_frequency(3), Eq(16)); - - // Updating a section with a term frequency higher than the previously set - // one should have no effect. - info.UpdateSection(3, 17); - EXPECT_THAT(info.hit_term_frequency(3), Eq(16)); -} - -TEST(DocHitInfoTest, UpdateSectionIdMask) { - DocHitInfo info(kSomeDocumentId); - EXPECT_THAT(info.hit_section_ids_mask(), Eq(kSectionIdMaskNone)); - - info.UpdateSection(3, 16); - EXPECT_THAT(info.hit_section_ids_mask() & 1U << 3, IsTrue()); - - // Calling update again shouldn't do anything - info.UpdateSection(3, 15); - EXPECT_THAT(info.hit_section_ids_mask() & 1U << 3, IsTrue()); - - // Updating another section shouldn't do anything - info.UpdateSection(2, 77); - EXPECT_THAT(info.hit_section_ids_mask() & 1U << 3, IsTrue()); -} - -TEST(DocHitInfoTest, MergeSectionsFromDifferentDocumentId) { - // Merging infos with different document_ids works. - DocHitInfo info1(kSomeDocumentId); - DocHitInfo info2(kSomeOtherDocumentId); - info2.UpdateSection(7, 12); - info1.MergeSectionsFrom(info2); - EXPECT_THAT(info1.hit_term_frequency(7), Eq(12)); - EXPECT_THAT(info1.document_id(), Eq(kSomeDocumentId)); -} - -TEST(DocHitInfoTest, MergeSectionsFromKeepsOldSection) { - // Merging shouldn't override sections that are present info1, but not present - // in info2. 
- DocHitInfo info1(kSomeDocumentId); - info1.UpdateSection(3, 16); - DocHitInfo info2(kSomeDocumentId); - info1.MergeSectionsFrom(info2); - EXPECT_THAT(info1.hit_term_frequency(3), Eq(16)); -} - -TEST(DocHitInfoTest, MergeSectionsFromAddsNewSection) { - // Merging should add sections that were not present in info1, but are present - // in info2. - DocHitInfo info1(kSomeDocumentId); - DocHitInfo info2(kSomeDocumentId); - info2.UpdateSection(7, 12); - info1.MergeSectionsFrom(info2); - EXPECT_THAT(info1.hit_term_frequency(7), Eq(12)); -} - -TEST(DocHitInfoTest, MergeSectionsFromHigherHitTermFrequencyHasNoEffect) { - // Merging should not override the value of a section in info1 if the same - // section is present in info2. - DocHitInfo info1(kSomeDocumentId); - info1.UpdateSection(2, 77); - DocHitInfo info2(kSomeDocumentId); - info2.UpdateSection(2, 89); - info1.MergeSectionsFrom(info2); - EXPECT_THAT(info1.hit_term_frequency(2), Eq(77)); -} - -TEST(DocHitInfoTest, MergeSectionsFromLowerHitScoreHasNoEffect) { - // Merging should not override the hit score of a section in info1 if the same - // section is present in info2. - DocHitInfo info1(kSomeDocumentId); - info1.UpdateSection(5, 108); - DocHitInfo info2(kSomeDocumentId); - info2.UpdateSection(5, 13); - info1.MergeSectionsFrom(info2); - EXPECT_THAT(info1.hit_term_frequency(5), Eq(108)); -} - TEST(DocHitInfoTest, Comparison) { constexpr DocumentId kDocumentId = 1; DocHitInfo info(kDocumentId); - info.UpdateSection(1, 12); + info.UpdateSection(1); constexpr DocumentId kHighDocumentId = 15; DocHitInfo high_document_id_info(kHighDocumentId); - high_document_id_info.UpdateSection(1, 12); + high_document_id_info.UpdateSection(1); DocHitInfo high_section_id_info(kDocumentId); - high_section_id_info.UpdateSection(1, 12); - high_section_id_info.UpdateSection(6, Hit::kDefaultTermFrequency); + high_section_id_info.UpdateSection(1); + high_section_id_info.UpdateSection(6); std::vector<DocHitInfo> infos{info, high_document_id_info, high_section_id_info}; @@ -160,7 +47,7 @@ TEST(DocHitInfoTest, Comparison) { // There are no requirements for how DocHitInfos with the same DocumentIds and // hit masks will compare, but they must not be equal. DocHitInfo different_term_frequency_info(kDocumentId); - different_term_frequency_info.UpdateSection(1, 76); + different_term_frequency_info.UpdateSection(2); EXPECT_THAT(info < different_term_frequency_info, Ne(different_term_frequency_info < info)); } diff --git a/icing/index/hit/hit.h b/icing/index/hit/hit.h index f8cbd78..35c9238 100644 --- a/icing/index/hit/hit.h +++ b/icing/index/hit/hit.h @@ -15,6 +15,7 @@ #ifndef ICING_INDEX_HIT_HIT_H_ #define ICING_INDEX_HIT_HIT_H_ +#include <array> #include <cstdint> #include <limits> @@ -54,6 +55,7 @@ class Hit { // The Term Frequency of a Hit. using TermFrequency = uint8_t; + using TermFrequencyArray = std::array<Hit::TermFrequency, kTotalNumSections>; // Max TermFrequency is 255. 
static constexpr TermFrequency kMaxTermFrequency = std::numeric_limits<TermFrequency>::max(); diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc index 7746688..92ced61 100644 --- a/icing/index/index-processor_test.cc +++ b/icing/index/index-processor_test.cc @@ -32,6 +32,7 @@ #include "icing/file/filesystem.h" #include "icing/index/hit/doc-hit-info.h" #include "icing/index/index.h" +#include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/index/term-property-id.h" #include "icing/legacy/index/icing-filesystem.h" @@ -246,6 +247,20 @@ std::vector<DocHitInfo> GetHits(std::unique_ptr<DocHitInfoIterator> iterator) { return infos; } +std::vector<DocHitInfoTermFrequencyPair> GetHitsWithTermFrequency( + std::unique_ptr<DocHitInfoIterator> iterator) { + std::vector<DocHitInfoTermFrequencyPair> infos; + while (iterator->Advance().ok()) { + std::vector<TermMatchInfo> matched_terms_stats; + iterator->PopulateMatchedTermsStats(&matched_terms_stats); + for (const TermMatchInfo& term_match_info : matched_terms_stats) { + infos.push_back(DocHitInfoTermFrequencyPair( + iterator->doc_hit_info(), term_match_info.term_frequencies)); + } + } + return infos; +} + TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) { EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(), &fake_clock_), @@ -308,7 +323,8 @@ TEST_F(IndexProcessorTest, OneDoc) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("hello", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - std::vector<DocHitInfo> hits = GetHits(std::move(itr)); + std::vector<DocHitInfoTermFrequencyPair> hits = + GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ {kExactSectionId, 1}}; EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency( @@ -360,7 +376,8 @@ TEST_F(IndexProcessorTest, MultipleDocs) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("world", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - std::vector<DocHitInfo> hits = GetHits(std::move(itr)); + std::vector<DocHitInfoTermFrequencyPair> hits = + GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap1{ {kPrefixedSectionId, 2}}; std::unordered_map<SectionId, Hit::TermFrequency> expectedMap2{ @@ -373,7 +390,7 @@ TEST_F(IndexProcessorTest, MultipleDocs) { ICING_ASSERT_OK_AND_ASSIGN( itr, index_->GetIterator("world", 1U << kPrefixedSectionId, TermMatchType::EXACT_ONLY)); - hits = GetHits(std::move(itr)); + hits = GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ {kPrefixedSectionId, 2}}; EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency( @@ -382,7 +399,7 @@ TEST_F(IndexProcessorTest, MultipleDocs) { ICING_ASSERT_OK_AND_ASSIGN(itr, index_->GetIterator("coffee", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - hits = GetHits(std::move(itr)); + hits = GetHitsWithTermFrequency(std::move(itr)); expectedMap = {{kExactSectionId, Hit::kMaxTermFrequency}}; EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfoWithTermFrequency( kDocumentId1, expectedMap))); @@ -838,7 +855,8 @@ TEST_F(IndexProcessorTest, ExactVerbatimProperty) { std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("Hello, world!", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - std::vector<DocHitInfo> hits = GetHits(std::move(itr)); + 
std::vector<DocHitInfoTermFrequencyPair> hits = + GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ {kExactVerbatimSectionId, 1}}; @@ -869,7 +887,8 @@ TEST_F(IndexProcessorTest, PrefixVerbatimProperty) { ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("Hello, w", kSectionIdMaskAll, TermMatchType::PREFIX)); - std::vector<DocHitInfo> hits = GetHits(std::move(itr)); + std::vector<DocHitInfoTermFrequencyPair> hits = + GetHitsWithTermFrequency(std::move(itr)); std::unordered_map<SectionId, Hit::TermFrequency> expectedMap{ {kPrefixedVerbatimSectionId, 1}}; diff --git a/icing/index/index.cc b/icing/index/index.cc index 1d863cc..5306520 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -183,21 +183,24 @@ libtextclassifier3::Status Index::TruncateTo(DocumentId document_id) { libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> Index::GetIterator(const std::string& term, SectionIdMask section_id_mask, - TermMatchType::Code term_match_type) { + TermMatchType::Code term_match_type, + bool need_hit_term_frequency) { std::unique_ptr<DocHitInfoIterator> lite_itr; std::unique_ptr<DocHitInfoIterator> main_itr; switch (term_match_type) { case TermMatchType::EXACT_ONLY: lite_itr = std::make_unique<DocHitInfoIteratorTermLiteExact>( - term_id_codec_.get(), lite_index_.get(), term, section_id_mask); + term_id_codec_.get(), lite_index_.get(), term, section_id_mask, + need_hit_term_frequency); main_itr = std::make_unique<DocHitInfoIteratorTermMainExact>( - main_index_.get(), term, section_id_mask); + main_index_.get(), term, section_id_mask, need_hit_term_frequency); break; case TermMatchType::PREFIX: lite_itr = std::make_unique<DocHitInfoIteratorTermLitePrefix>( - term_id_codec_.get(), lite_index_.get(), term, section_id_mask); + term_id_codec_.get(), lite_index_.get(), term, section_id_mask, + need_hit_term_frequency); main_itr = std::make_unique<DocHitInfoIteratorTermMainPrefix>( - main_index_.get(), term, section_id_mask); + main_index_.get(), term, section_id_mask, need_hit_term_frequency); break; default: return absl_ports::InvalidArgumentError( @@ -265,11 +268,13 @@ IndexStorageInfoProto Index::GetStorageInfo() const { } libtextclassifier3::Status Index::Optimize( - const std::vector<DocumentId>& document_id_old_to_new) { + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id) { if (main_index_->last_added_document_id() != kInvalidDocumentId) { ICING_RETURN_IF_ERROR(main_index_->Optimize(document_id_old_to_new)); } - return lite_index_->Optimize(document_id_old_to_new, term_id_codec_.get()); + return lite_index_->Optimize(document_id_old_to_new, term_id_codec_.get(), + new_last_added_document_id); } libtextclassifier3::Status Index::Editor::BufferTerm(const char* term) { diff --git a/icing/index/index.h b/icing/index/index.h index 748acb0..9d4e5ac 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -185,7 +185,7 @@ class Index { // INVALID_ARGUMENT if given an invalid term_match_type libtextclassifier3::StatusOr<std::unique_ptr<DocHitInfoIterator>> GetIterator( const std::string& term, SectionIdMask section_id_mask, - TermMatchType::Code term_match_type); + TermMatchType::Code term_match_type, bool need_hit_term_frequency = true); // Finds terms with the given prefix in the given namespaces. If // 'namespace_ids' is empty, returns results from all the namespaces. 
Results @@ -264,13 +264,16 @@ class Index { } // Reduces internal file sizes by reclaiming space of deleted documents. + // new_last_added_document_id will be used to update the last added document + // id in the lite index. // // Returns: // OK on success // INTERNAL_ERROR on IO error, this indicates that the index may be in an // invalid state and should be cleared. libtextclassifier3::Status Optimize( - const std::vector<DocumentId>& document_id_old_to_new); + const std::vector<DocumentId>& document_id_old_to_new, + DocumentId new_last_added_document_id); private: Index(const Options& options, std::unique_ptr<TermIdCodec> term_id_codec, diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 7323603..995f501 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -21,6 +21,7 @@ #include <random> #include <string> #include <string_view> +#include <unordered_map> #include <utility> #include <vector> @@ -120,7 +121,7 @@ MATCHER_P2(EqualsDocHitInfo, document_id, sections, "") { const DocHitInfo& actual = arg; SectionIdMask section_mask = kSectionIdMaskNone; for (SectionId section : sections) { - section_mask |= 1U << section; + section_mask |= UINT64_C(1) << section; } *result_listener << "actual is {document_id=" << actual.document_id() << ", section_mask=" << actual.hit_section_ids_mask() @@ -267,7 +268,8 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterOptimize) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); index_->set_last_added_document_id(kDocumentId2); - ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre(EqualsDocHitInfo( kDocumentId2, std::vector<SectionId>{kSectionId2})))); @@ -275,7 +277,8 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterOptimize) { // Mapping to a different docid will translate the hit ICING_ASSERT_OK(index_->Optimize( - /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1})); + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1}, + /*new_last_added_document_id=*/1)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre(EqualsDocHitInfo( kDocumentId1, std::vector<SectionId>{kSectionId2})))); @@ -283,10 +286,11 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterOptimize) { // Mapping to kInvalidDocumentId will remove the hit. 
ICING_ASSERT_OK( - index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId})); + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(IsEmpty())); - EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); + EXPECT_EQ(index_->last_added_document_id(), kDocumentId0); } TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { @@ -298,7 +302,8 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { ICING_ASSERT_OK(index_->Merge()); - ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre(EqualsDocHitInfo( kDocumentId2, std::vector<SectionId>{kSectionId2})))); @@ -306,7 +311,8 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { // Mapping to a different docid will translate the hit ICING_ASSERT_OK(index_->Optimize( - /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1})); + /*document_id_old_to_new=*/{0, kInvalidDocumentId, kDocumentId1}, + /*new_last_added_document_id=*/1)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre(EqualsDocHitInfo( kDocumentId1, std::vector<SectionId>{kSectionId2})))); @@ -314,10 +320,11 @@ TEST_F(IndexTest, SingleHitSingleTermIndexAfterMergeAndOptimize) { // Mapping to kInvalidDocumentId will remove the hit. ICING_ASSERT_OK( - index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId})); + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId}, + /*new_last_added_document_id=*/0)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(IsEmpty())); - EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); + EXPECT_EQ(index_->last_added_document_id(), 0); } TEST_F(IndexTest, SingleHitMultiTermIndex) { @@ -369,7 +376,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterOptimize) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); index_->set_last_added_document_id(kDocumentId2); - ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); EXPECT_THAT( GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre( @@ -383,7 +391,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterOptimize) { // Delete document id 1, and document id 2 is translated to 1. ICING_ASSERT_OK( - index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1})); + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); EXPECT_THAT( GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre( @@ -396,7 +405,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterOptimize) { // Delete all the rest documents. 
ICING_ASSERT_OK(index_->Optimize( - /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId})); + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId}, + /*new_last_added_document_id=*/kInvalidDocumentId)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(IsEmpty())); EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), @@ -423,7 +433,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterMergeAndOptimize) { ICING_ASSERT_OK(index_->Merge()); - ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2})); + ICING_ASSERT_OK(index_->Optimize(/*document_id_old_to_new=*/{0, 1, 2}, + /*new_last_added_document_id=*/2)); EXPECT_THAT( GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre( @@ -437,7 +448,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterMergeAndOptimize) { // Delete document id 1, and document id 2 is translated to 1. ICING_ASSERT_OK( - index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1})); + index_->Optimize(/*document_id_old_to_new=*/{0, kInvalidDocumentId, 1}, + /*new_last_added_document_id=*/1)); EXPECT_THAT( GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(ElementsAre( @@ -450,7 +462,8 @@ TEST_F(IndexTest, MultiHitMultiTermIndexAfterMergeAndOptimize) { // Delete all the rest documents. ICING_ASSERT_OK(index_->Optimize( - /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId})); + /*document_id_old_to_new=*/{kInvalidDocumentId, kInvalidDocumentId}, + /*new_last_added_document_id=*/kInvalidDocumentId)); EXPECT_THAT(GetHits("foo", TermMatchType::EXACT_ONLY), IsOkAndHolds(IsEmpty())); EXPECT_THAT(GetHits("bar", TermMatchType::EXACT_ONLY), @@ -756,18 +769,24 @@ TEST_F(IndexTest, PrefixToString) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("foo", id_mask, TermMatchType::PREFIX)); - EXPECT_THAT(itr->ToString(), - Eq("(0000000000001100:foo* OR 0000000000001100:foo*)")); + EXPECT_THAT(itr->ToString(), Eq("(0000000000000000000000000000000000000000000" + "000000000000000001100:foo* OR " + "00000000000000000000000000000000000000000000" + "00000000000000001100:foo*)")); ICING_ASSERT_OK_AND_ASSIGN(itr, index_->GetIterator("foo", kSectionIdMaskAll, TermMatchType::PREFIX)); - EXPECT_THAT(itr->ToString(), - Eq("(1111111111111111:foo* OR 1111111111111111:foo*)")); + EXPECT_THAT(itr->ToString(), Eq("(1111111111111111111111111111111111111111111" + "111111111111111111111:foo* OR " + "11111111111111111111111111111111111111111111" + "11111111111111111111:foo*)")); ICING_ASSERT_OK_AND_ASSIGN(itr, index_->GetIterator("foo", kSectionIdMaskNone, TermMatchType::PREFIX)); - EXPECT_THAT(itr->ToString(), - Eq("(0000000000000000:foo* OR 0000000000000000:foo*)")); + EXPECT_THAT(itr->ToString(), Eq("(0000000000000000000000000000000000000000000" + "000000000000000000000:foo* OR " + "00000000000000000000000000000000000000000000" + "00000000000000000000:foo*)")); } TEST_F(IndexTest, ExactToString) { @@ -775,20 +794,26 @@ TEST_F(IndexTest, ExactToString) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<DocHitInfoIterator> itr, index_->GetIterator("foo", id_mask, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(itr->ToString(), - Eq("(0000000000001100:foo OR 0000000000001100:foo)")); + EXPECT_THAT(itr->ToString(), Eq("(0000000000000000000000000000000000000000000" + "000000000000000001100:foo OR " + "00000000000000000000000000000000000000000000" + "00000000000000001100:foo)")); ICING_ASSERT_OK_AND_ASSIGN( itr, index_->GetIterator("foo", kSectionIdMaskAll, 
TermMatchType::EXACT_ONLY)); - EXPECT_THAT(itr->ToString(), - Eq("(1111111111111111:foo OR 1111111111111111:foo)")); + EXPECT_THAT(itr->ToString(), Eq("(1111111111111111111111111111111111111111111" "111111111111111111111:foo OR " "11111111111111111111111111111111111111111111" "11111111111111111111:foo)")); ICING_ASSERT_OK_AND_ASSIGN(itr, index_->GetIterator("foo", kSectionIdMaskNone, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(itr->ToString(), - Eq("(0000000000000000:foo OR 0000000000000000:foo)")); + EXPECT_THAT(itr->ToString(), Eq("(0000000000000000000000000000000000000000000" "000000000000000000000:foo OR " "00000000000000000000000000000000000000000000" "00000000000000000000:foo)")); } TEST_F(IndexTest, NonAsciiTerms) { @@ -986,7 +1011,10 @@ TEST_F(IndexTest, FullIndexMerge) { TEST_F(IndexTest, OptimizeShouldWorkForEmptyIndex) { // Optimizing an empty index should succeed, but have no effect. - ICING_ASSERT_OK(index_->Optimize(std::vector<DocumentId>())); + ICING_ASSERT_OK( + index_->Optimize(std::vector<DocumentId>(), + /*new_last_added_document_id=*/kInvalidDocumentId)); + EXPECT_EQ(index_->last_added_document_id(), kInvalidDocumentId); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<DocHitInfoIterator> itr, @@ -998,6 +1026,120 @@ TEST_F(IndexTest, OptimizeShouldWorkForEmptyIndex) { EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); } +TEST_F(IndexTest, IndexShouldWorkAtSectionLimit) { + std::string prefix = "prefix"; + std::default_random_engine random; + std::vector<std::string> query_terms; + // Add the first 1024 hits to the main index (merged below at i == 1024) and + // the remaining 3072 hits to the lite index. + for (int i = 0; i < 4096; ++i) { + if (i == 1024) { + ICING_ASSERT_OK(index_->Merge()); + } + // Generate a unique term for document i. + query_terms.push_back(prefix + RandomString("abcdefg", 5, &random) + + std::to_string(i)); + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + SectionId section_id = i % 64; + if (section_id == 2) { + // Make section 2 an exact section. + term_match_type = TermMatchType::EXACT_ONLY; + } + Index::Editor edit = index_->Edit(/*document_id=*/i, section_id, + term_match_type, /*namespace_id=*/0); + ICING_ASSERT_OK(edit.BufferTerm(query_terms.at(i).c_str())); + ICING_ASSERT_OK(edit.IndexAllBufferedTerms()); + } + + std::vector<DocHitInfo> exp_prefix_hits; + for (int i = 0; i < 4096; ++i) { + if (i % 64 == 2) { + // Section 2 is an exact section, so we should not see any hits in + // prefix search. + continue; + } + exp_prefix_hits.push_back(DocHitInfo(i)); + exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 64); + } + std::reverse(exp_prefix_hits.begin(), exp_prefix_hits.end()); + + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<DocHitInfo> hits, + GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + + // Check exact search. + for (int i = 0; i < 4096; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(query_terms[i], TermMatchType::EXACT_ONLY)); + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + i, std::vector<SectionId>{(SectionId)(i % 64)}))); + } +} + +// Skip this test on Android because of timeout.
+#if !defined(__ANDROID__) +TEST_F(IndexTest, IndexShouldWorkAtDocumentLimit) { + std::string prefix = "pre"; + std::default_random_engine random; + const int max_lite_index_size = 1024 * 1024 / 8; + int lite_index_size = 0; + for (int i = 0; i <= kMaxDocumentId; ++i) { + if (i % max_lite_index_size == 0 && i != 0) { + ICING_ASSERT_OK(index_->Merge()); + lite_index_size = 0; + } + std::string term; + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + SectionId section_id = i % 64; + if (section_id == 2) { + // Make section 2 an exact section. + term_match_type = TermMatchType::EXACT_ONLY; + term = std::to_string(i); + } else { + term = prefix + RandomString("abcd", 5, &random); + } + Index::Editor edit = index_->Edit(/*document_id=*/i, section_id, + term_match_type, /*namespace_id=*/0); + ICING_ASSERT_OK(edit.BufferTerm(term.c_str())); + ICING_ASSERT_OK(edit.IndexAllBufferedTerms()); + ++lite_index_size; + index_->set_last_added_document_id(i); + } + // Ensure that the lite index still contains some data to better test both + // indexes. + ASSERT_THAT(lite_index_size, Eq(max_lite_index_size - 1)); + EXPECT_EQ(index_->last_added_document_id(), kMaxDocumentId); + + std::vector<DocHitInfo> exp_prefix_hits; + for (int i = 0; i <= kMaxDocumentId; ++i) { + if (i % 64 == 2) { + // Section 2 is an exact section, so we should not see any hits in + // prefix search. + continue; + } + exp_prefix_hits.push_back(DocHitInfo(i)); + exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 64); + } + std::reverse(exp_prefix_hits.begin(), exp_prefix_hits.end()); + + // Check prefix search. + ICING_ASSERT_OK_AND_ASSIGN(std::vector<DocHitInfo> hits, + GetHits(prefix, TermMatchType::PREFIX)); + EXPECT_THAT(hits, ContainerEq(exp_prefix_hits)); + + // Check exact search. + for (int i = 0; i <= kMaxDocumentId; ++i) { + if (i % 64 == 2) { + // Only section 2 is an exact section + ICING_ASSERT_OK_AND_ASSIGN( + hits, GetHits(std::to_string(i), TermMatchType::EXACT_ONLY)); + EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( + i, std::vector<SectionId>{(SectionId)(2)}))); + } + } +} +#endif // if !defined(__ANDROID__) + TEST_F(IndexTest, IndexOptimize) { std::string prefix = "prefix"; std::default_random_engine random; @@ -1011,7 +1153,7 @@ TEST_F(IndexTest, IndexOptimize) { query_terms.push_back(prefix + RandomString("abcdefg", 5, &random) + std::to_string(i)); TermMatchType::Code term_match_type = TermMatchType::PREFIX; - SectionId section_id = i % 5; + SectionId section_id = i % 64; if (section_id == 2) { // Make section 2 an exact section. term_match_type = TermMatchType::EXACT_ONLY; @@ -1041,19 +1183,19 @@ TEST_F(IndexTest, IndexOptimize) { if (document_id_old_to_new[i] == kInvalidDocumentId) { continue; } - if (i % 5 == 2) { + if (i % 64 == 2) { // Section 2 is an exact section, so we should not see any hits in // prefix search. continue; } exp_prefix_hits.push_back(DocHitInfo(document_id_old_to_new[i])); - exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 5, - /*hit_term_frequency=*/1); + exp_prefix_hits.back().UpdateSection(/*section_id=*/i % 64); } std::reverse(exp_prefix_hits.begin(), exp_prefix_hits.end()); // Check that optimize is correct - ICING_ASSERT_OK(index_->Optimize(document_id_old_to_new)); + ICING_ASSERT_OK( + index_->Optimize(document_id_old_to_new, new_last_added_document_id)); EXPECT_EQ(index_->last_added_document_id(), new_last_added_document_id); // Check prefix search. 
ICING_ASSERT_OK_AND_ASSIGN(std::vector<DocHitInfo> hits, @@ -1068,7 +1210,7 @@ TEST_F(IndexTest, IndexOptimize) { } else { EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( document_id_old_to_new[i], - std::vector<SectionId>{(SectionId)(i % 5)}))); + std::vector<SectionId>{(SectionId)(i % 64)}))); } } @@ -1087,7 +1229,7 @@ TEST_F(IndexTest, IndexOptimize) { } else { EXPECT_THAT(hits, ElementsAre(EqualsDocHitInfo( document_id_old_to_new[i], - std::vector<SectionId>{(SectionId)(i % 5)}))); + std::vector<SectionId>{(SectionId)(i % 64)}))); } } } @@ -1338,6 +1480,29 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCorrectHitCount) { EqualsTermMetadata("foo", 1)))); } +TEST_F(IndexTest, FindTermByPrefixMultipleHitBatch) { + AlwaysTrueNamespaceCheckerImpl impl; + // Create multiple hit batches. + for (int i = 0; i < 4000; i++) { + Index::Editor edit = index_->Edit(i, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit.BufferTerm("fool"), IsOk()); + EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); + } + + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", + /*num_to_return=*/10, + TermMatchType::PREFIX, &impl), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 4000)))); + + ICING_ASSERT_OK(index_->Merge()); + + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", + /*num_to_return=*/10, + TermMatchType::PREFIX, &impl), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 4000)))); +} + TEST_F(IndexTest, FindTermByPrefixShouldReturnInOrder) { // Push 6 term-six, 5 term-five, 4 term-four, 3 term-three, 2 term-two and one // term-one into lite index. diff --git a/icing/index/iterator/doc-hit-info-iterator-and.cc b/icing/index/iterator/doc-hit-info-iterator-and.cc index 543e9ef..6bde8e6 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and.cc @@ -104,7 +104,7 @@ libtextclassifier3::Status DocHitInfoIteratorAnd::Advance() { // Guaranteed that short_doc_id and long_doc_id match now doc_hit_info_ = short_->doc_hit_info(); - doc_hit_info_.MergeSectionsFrom(long_->doc_hit_info()); + doc_hit_info_.MergeSectionsFrom(long_->doc_hit_info().hit_section_ids_mask()); hit_intersect_section_ids_mask_ = short_->hit_intersect_section_ids_mask() & long_->hit_intersect_section_ids_mask(); return libtextclassifier3::Status::OK; @@ -186,7 +186,8 @@ libtextclassifier3::Status DocHitInfoIteratorAndNary::Advance() { iterators_.at(0)->hit_intersect_section_ids_mask(); for (size_t i = 1; i < iterators_.size(); i++) { - doc_hit_info_.MergeSectionsFrom(iterators_.at(i)->doc_hit_info()); + doc_hit_info_.MergeSectionsFrom( + iterators_.at(i)->doc_hit_info().hit_section_ids_mask()); hit_intersect_section_ids_mask_ &= iterators_.at(i)->hit_intersect_section_ids_mask(); } diff --git a/icing/index/iterator/doc-hit-info-iterator-and_test.cc b/icing/index/iterator/doc-hit-info-iterator-and_test.cc index 783e937..e4730fe 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and_test.cc @@ -203,24 +203,24 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. 
SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); - DocHitInfo doc_hit_info2 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -256,15 +256,15 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. 
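A note on these term-frequency array literals, here and in the test files that follow: std::array is an aggregate, so a brace list shorter than the array value-initializes the omitted tail to zero. That is why the literals can keep listing only 16 values even though the array type grows to kTotalNumSections elements (this diff never shows kTotalNumSections' value, but the 64-character masks in the ToString expectations above suggest 64). A self-contained sketch, with uint8_t and 64 as stand-ins for the real Hit::TermFrequency and kTotalNumSections:

#include <array>
#include <cstdint>

// 7 values listed; the remaining 57 elements are value-initialized to 0,
// mirroring the term_frequencies literals in these tests.
constexpr std::array<uint8_t, 64> kTermFrequencies{1, 0, 2, 0, 3, 0, 4};
static_assert(kTermFrequencies[6] == 4, "listed element is kept");
static_assert(kTermFrequencies[63] == 0, "trailing elements are zero-filled");

int main() { return kTermFrequencies[63]; }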
SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ - 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ + 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -295,15 +295,15 @@ TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats) { } TEST(DocHitInfoIteratorAndTest, PopulateMatchedTermsStats_NoMatchingDocument) { - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info2 = DocHitInfo(5); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(5); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -471,46 +471,47 @@ TEST(DocHitInfoIteratorAndNaryTest, PopulateMatchedTermsStats) { // DocHitInfoIterators. 
// For term "hi", document 10 and 8 SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hi{ - 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_hi = DocHitInfo(10); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{ + 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10); doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); - DocHitInfo doc_hit_info2_hi = DocHitInfo(8); + DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8); doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); // For term "hello", document 10 and 9 SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hello{ - 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_hello = DocHitInfo(10); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{ + 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10); doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); - DocHitInfo doc_hit_info2_hello = DocHitInfo(9); + DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9); doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); // For term "ciao", document 10 and 9 SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_ciao{ - 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_ciao = DocHitInfo(10); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{ + 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(10); doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); - DocHitInfo doc_hit_info2_ciao = DocHitInfo(9); + DocHitInfoTermFrequencyPair doc_hit_info2_ciao = DocHitInfo(9); doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2); - std::vector<DocHitInfo> first_vector = {doc_hit_info1_hi, doc_hit_info2_hi}; - std::vector<DocHitInfo> second_vector = {doc_hit_info1_hello, - doc_hit_info2_hello}; - std::vector<DocHitInfo> third_vector = {doc_hit_info1_ciao, - doc_hit_info2_ciao}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1_hi, + doc_hit_info2_hi}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = { + doc_hit_info1_hello, doc_hit_info2_hello}; + std::vector<DocHitInfoTermFrequencyPair> third_vector = {doc_hit_info1_ciao, + doc_hit_info2_ciao}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); diff --git a/icing/index/iterator/doc-hit-info-iterator-or.cc b/icing/index/iterator/doc-hit-info-iterator-or.cc index b4234e0..655cafc 
100644 --- a/icing/index/iterator/doc-hit-info-iterator-or.cc +++ b/icing/index/iterator/doc-hit-info-iterator-or.cc @@ -115,7 +115,8 @@ libtextclassifier3::Status DocHitInfoIteratorOr::Advance() { // If equal, combine. if (left_document_id_ == right_document_id_) { - doc_hit_info_.MergeSectionsFrom(right_->doc_hit_info()); + doc_hit_info_.MergeSectionsFrom( + right_->doc_hit_info().hit_section_ids_mask()); hit_intersect_section_ids_mask_ &= right_->hit_intersect_section_ids_mask(); } @@ -195,7 +196,8 @@ libtextclassifier3::Status DocHitInfoIteratorOrNary::Advance() { hit_intersect_section_ids_mask_ = iterator->hit_intersect_section_ids_mask(); } else { - doc_hit_info_.MergeSectionsFrom(iterator->doc_hit_info()); + doc_hit_info_.MergeSectionsFrom( + iterator->doc_hit_info().hit_section_ids_mask()); hit_intersect_section_ids_mask_ &= iterator->hit_intersect_section_ids_mask(); } diff --git a/icing/index/iterator/doc-hit-info-iterator-or_test.cc b/icing/index/iterator/doc-hit-info-iterator-or_test.cc index 3f00a39..6e6872c 100644 --- a/icing/index/iterator/doc-hit-info-iterator-or_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-or_test.cc @@ -183,24 +183,24 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); - DocHitInfo doc_hit_info2 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(4); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -235,15 +235,15 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. 
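An aside on the MergeSectionsFrom rewrites in this file and in doc-hit-info-iterator-and.cc above: with term frequencies decoupled from DocHitInfo, the iterators now merge only the other hit's section bitmask. The net effect, sketched with an illustrative mini type that is not the real DocHitInfo beyond the two members this diff shows:

#include <cstdint>

using SectionIdMask = uint64_t;  // assumption: one bit per section

struct MiniDocHitInfo {
  int32_t document_id = -1;
  SectionIdMask hit_section_ids_mask = 0;

  // Merging sections is now just a bitwise OR of the masks; per-section term
  // frequencies are tracked separately by the iterators that need them.
  void MergeSectionsFrom(SectionIdMask other_mask) {
    hit_section_ids_mask |= other_mask;
  }
};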
SectionIdMask section_id_mask1 = 0b00000101; // hits in sections 0, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ - 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ + 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info1}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -274,24 +274,24 @@ TEST(DocHitInfoIteratorOrTest, PopulateMatchedTermsStats) { // Arbitrary section ids for the documents in the DocHitInfoIterators. // Created to test correct section_id_mask behavior. SectionIdMask section_id_mask1 = 0b01010101; // hits in sections 0, 2, 4, 6 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1{ - 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1{ + 1, 0, 2, 0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; SectionIdMask section_id_mask2 = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1 = DocHitInfo(4); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(4); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/3); doc_hit_info1.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); - DocHitInfo doc_hit_info2 = DocHitInfo(5); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(5); doc_hit_info2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); - std::vector<DocHitInfo> first_vector = {doc_hit_info1}; - std::vector<DocHitInfo> second_vector = {doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = {doc_hit_info2}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); @@ -477,55 +477,56 @@ TEST(DocHitInfoIteratorOrNaryTest, PopulateMatchedTermsStats) { // DocHitInfoIterators. 
// For term "hi", document 10 and 8 SectionIdMask section_id_mask1_hi = 0b01000101; // hits in sections 0, 2, 6 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hi{ - 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_hi = DocHitInfo(10); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hi{ + 1, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_hi = DocHitInfo(10); doc_hit_info1_hi.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); doc_hit_info1_hi.UpdateSection(/*section_id=*/6, /*hit_term_frequency=*/4); SectionIdMask section_id_mask2_hi = 0b00000110; // hits in sections 1, 2 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2_hi{ - 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info2_hi = DocHitInfo(8); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hi{ + 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info2_hi = DocHitInfo(8); doc_hit_info2_hi.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/2); doc_hit_info2_hi.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/6); // For term "hello", document 10 and 9 SectionIdMask section_id_mask1_hello = 0b00001001; // hits in sections 0, 3 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_hello{ - 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_hello = DocHitInfo(10); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_hello{ + 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_hello = DocHitInfo(10); doc_hit_info1_hello.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); SectionIdMask section_id_mask2_hello = 0b00001100; // hits in sections 2, 3 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2_hello{ - 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info2_hello = DocHitInfo(9); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_hello{ + 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info2_hello = DocHitInfo(9); doc_hit_info2_hello.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/3); doc_hit_info2_hello.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/2); // For term "ciao", document 9 and 8 SectionIdMask section_id_mask1_ciao = 0b00000011; // hits in sections 0, 1 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies1_ciao{ - 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info1_ciao = DocHitInfo(9); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies1_ciao{ + 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair doc_hit_info1_ciao = DocHitInfo(9); doc_hit_info1_ciao.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/2); doc_hit_info1_ciao.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/3); SectionIdMask section_id_mask2_ciao = 0b00011000; // hits in sections 3, 4 - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies2_ciao{ - 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; - DocHitInfo doc_hit_info2_ciao = DocHitInfo(8); + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies2_ciao{ + 0, 0, 0, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + DocHitInfoTermFrequencyPair 
doc_hit_info2_ciao = DocHitInfo(8); doc_hit_info2_ciao.UpdateSection(/*section_id=*/3, /*hit_term_frequency=*/3); doc_hit_info2_ciao.UpdateSection(/*section_id=*/4, /*hit_term_frequency=*/2); - std::vector<DocHitInfo> first_vector = {doc_hit_info1_hi, doc_hit_info2_hi}; - std::vector<DocHitInfo> second_vector = {doc_hit_info1_hello, - doc_hit_info2_hello}; - std::vector<DocHitInfo> third_vector = {doc_hit_info1_ciao, - doc_hit_info2_ciao}; + std::vector<DocHitInfoTermFrequencyPair> first_vector = {doc_hit_info1_hi, + doc_hit_info2_hi}; + std::vector<DocHitInfoTermFrequencyPair> second_vector = { + doc_hit_info1_hello, doc_hit_info2_hello}; + std::vector<DocHitInfoTermFrequencyPair> third_vector = {doc_hit_info1_ciao, + doc_hit_info2_ciao}; auto first_iter = std::make_unique<DocHitInfoIteratorDummy>(first_vector, "hi"); diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc index 9d33e2c..0871436 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.cc @@ -38,11 +38,11 @@ namespace lib { DocHitInfoIteratorSectionRestrict::DocHitInfoIteratorSectionRestrict( std::unique_ptr<DocHitInfoIterator> delegate, const DocumentStore* document_store, const SchemaStore* schema_store, - std::string_view target_section) + std::string target_section) : delegate_(std::move(delegate)), document_store_(*document_store), schema_store_(*schema_store), - target_section_(target_section) {} + target_section_(std::move(target_section)) {} libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { while (delegate_->Advance().ok()) { @@ -65,7 +65,7 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { // one of the confirmed section ids match the name of the target section while (section_id_mask != 0) { // There was a hit in this section id - SectionId section_id = __builtin_ctz(section_id_mask); + SectionId section_id = __builtin_ctzll(section_id_mask); auto section_metadata_or = schema_store_.GetSectionMetadata(schema_type_id, section_id); @@ -77,13 +77,13 @@ libtextclassifier3::Status DocHitInfoIteratorSectionRestrict::Advance() { if (section_metadata->path == target_section_) { // The hit was in the target section name, return OK/found doc_hit_info_ = delegate_->doc_hit_info(); - hit_intersect_section_ids_mask_ = 1u << section_id; + hit_intersect_section_ids_mask_ = UINT64_C(1) << section_id; return libtextclassifier3::Status::OK; } } // Mark this section as checked - section_id_mask &= ~(1U << section_id); + section_id_mask &= ~(UINT64_C(1) << section_id); } // Didn't find a matching section name for this hit. Continue. 
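The 1U to UINT64_C(1) and __builtin_ctz to __builtin_ctzll substitutions in the hunk above (and throughout this change) are correctness fixes once SectionIdMask is 64 bits wide: 1U is a 32-bit unsigned int, so 1U << section_id is undefined behavior for section ids of 32 and above, and __builtin_ctz takes an unsigned int, so it would only ever see the low 32 sections. A minimal sketch of the fixed idiom, assuming a 64-bit SectionIdMask:

#include <cstdint>

using SectionIdMask = uint64_t;  // assumption: mask is 64 bits, one bit per section

// Visits every section id set in the mask, lowest bit first.
template <typename Visitor>
void ForEachSection(SectionIdMask mask, Visitor visit) {
  while (mask != 0) {
    // __builtin_ctzll operates on unsigned long long; the 32-bit
    // __builtin_ctz would ignore any bits for sections 32..63.
    int section_id = __builtin_ctzll(mask);
    visit(section_id);
    // UINT64_C(1) keeps the shift in 64-bit arithmetic; a plain 1U shifts a
    // 32-bit operand, which is undefined for section_id >= 32.
    mask &= ~(UINT64_C(1) << section_id);
  }
}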
diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict.h b/icing/index/iterator/doc-hit-info-iterator-section-restrict.h index 52b243a..2639e67 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict.h +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict.h @@ -42,7 +42,7 @@ class DocHitInfoIteratorSectionRestrict : public DocHitInfoIterator { explicit DocHitInfoIteratorSectionRestrict( std::unique_ptr<DocHitInfoIterator> delegate, const DocumentStore* document_store, const SchemaStore* schema_store, - std::string_view target_section); + std::string target_section); libtextclassifier3::Status Advance() override; @@ -75,7 +75,7 @@ class DocHitInfoIteratorSectionRestrict : public DocHitInfoIterator { const SchemaStore& schema_store_; // Owning a std::string (rather than a std::string_view) ensures there is no // underlying string value for this to outlive. - std::string_view target_section_; + std::string target_section_; }; } // namespace lib diff --git a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc index 7c6d924..485f85b 100644 --- a/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc +++ b/icing/index/iterator/doc-hit-info-iterator-section-restrict_test.cc @@ -121,12 +121,12 @@ TEST_F(DocHitInfoIteratorSectionRestrictTest, // Created to test correct section_id_mask behavior. SectionIdMask original_section_id_mask = 0b00000101; // hits in sections 0, 2 - DocHitInfo doc_hit_info1 = DocHitInfo(document_id); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id); doc_hit_info1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); doc_hit_info1.UpdateSection(/*section_id=*/2, /*hit_term_frequency=*/2); // Create a hit that was found in the indexed section - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1}; auto original_iterator = std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "hi"); @@ -152,8 +152,8 @@ TEST_F(DocHitInfoIteratorSectionRestrictTest, section_restrict_iterator.PopulateMatchedTermsStats(&matched_terms_stats); EXPECT_EQ(matched_terms_stats.at(0).term, "hi"); - std::array<Hit::TermFrequency, kMaxSectionId> expected_term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> expected_term_frequencies{ + 1}; EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, ElementsAreArray(expected_term_frequencies)); EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, diff --git a/icing/index/iterator/doc-hit-info-iterator-test-util.h b/icing/index/iterator/doc-hit-info-iterator-test-util.h index 45acc8f..ed6db23 100644 --- a/icing/index/iterator/doc-hit-info-iterator-test-util.h +++ b/icing/index/iterator/doc-hit-info-iterator-test-util.h @@ -31,6 +31,40 @@ namespace icing { namespace lib { +class DocHitInfoTermFrequencyPair { + public: + DocHitInfoTermFrequencyPair( + const DocHitInfo& doc_hit_info, + const Hit::TermFrequencyArray& hit_term_frequency = {}) + : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {} + + void UpdateSection(SectionId section_id, + Hit::TermFrequency hit_term_frequency) { + doc_hit_info_.UpdateSection(section_id); + hit_term_frequency_[section_id] = hit_term_frequency; + } + + void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) { + SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask(); + doc_hit_info_.MergeSectionsFrom(other_mask); + while (other_mask)
{ + SectionId section_id = __builtin_ctzll(other_mask); + hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id]; + other_mask &= ~(UINT64_C(1) << section_id); + } + } + + DocHitInfo doc_hit_info() const { return doc_hit_info_; } + + Hit::TermFrequency hit_term_frequency(SectionId section_id) const { + return hit_term_frequency_[section_id]; + } + + private: + DocHitInfo doc_hit_info_; + Hit::TermFrequencyArray hit_term_frequency_; +}; + // Dummy class to help with testing. It starts with a kInvalidDocumentId doc // hit info until an Advance is called (like normal DocHitInfoIterators). It // will then proceed to return the doc_hit_infos in order as Advances are @@ -39,14 +73,23 @@ namespace lib { class DocHitInfoIteratorDummy : public DocHitInfoIterator { public: DocHitInfoIteratorDummy() = default; - explicit DocHitInfoIteratorDummy(std::vector<DocHitInfo> doc_hit_infos, - std::string term = "") + explicit DocHitInfoIteratorDummy( + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos, + std::string term = "") : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {} + explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos, + std::string term = "") + : term_(std::move(term)) { + for (auto& doc_hit_info : doc_hit_infos) { + doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info)); + } + } + libtextclassifier3::Status Advance() override { + ++index_; if (index_ < doc_hit_infos_.size()) { - doc_hit_info_ = doc_hit_infos_.at(index_); - index_++; + doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info(); return libtextclassifier3::Status::OK; } @@ -58,20 +101,20 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator { void PopulateMatchedTermsStats( std::vector<TermMatchInfo>* matched_terms_stats, SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { - if (doc_hit_info_.document_id() == kInvalidDocumentId) { + if (index_ == -1 || index_ >= doc_hit_infos_.size()) { // Current hit isn't valid, return.
return; } SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask() & filtering_section_mask; SectionIdMask section_mask_copy = section_mask; - std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { - Hit::kNoTermFrequency}; + std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies = + {Hit::kNoTermFrequency}; while (section_mask_copy) { - SectionId section_id = __builtin_ctz(section_mask_copy); + SectionId section_id = __builtin_ctzll(section_mask_copy); section_term_frequencies.at(section_id) = - doc_hit_info_.hit_term_frequency(section_id); - section_mask_copy &= ~(1u << section_id); + doc_hit_infos_.at(index_).hit_term_frequency(section_id); + section_mask_copy &= ~(UINT64_C(1) << section_id); } TermMatchInfo term_stats(term_, section_mask, std::move(section_term_frequencies)); @@ -109,20 +152,22 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator { std::string ToString() const override { std::string ret = "<"; - for (auto& doc_hit_info : doc_hit_infos_) { - absl_ports::StrAppend(&ret, IcingStringUtil::StringPrintf( - "[%d,%d]", doc_hit_info.document_id(), - doc_hit_info.hit_section_ids_mask())); + for (auto& doc_hit_info_pair : doc_hit_infos_) { + absl_ports::StrAppend( + &ret, IcingStringUtil::StringPrintf( + "[%d,%" PRIu64 "]", + doc_hit_info_pair.doc_hit_info().document_id(), + doc_hit_info_pair.doc_hit_info().hit_section_ids_mask())); } absl_ports::StrAppend(&ret, ">"); return ret; } private: - int32_t index_ = 0; + int32_t index_ = -1; int32_t num_blocks_inspected_ = 0; int32_t num_leaf_advance_calls_ = 0; - std::vector<DocHitInfo> doc_hit_infos_; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_; std::string term_; }; diff --git a/icing/index/iterator/doc-hit-info-iterator.h b/icing/index/iterator/doc-hit-info-iterator.h index bf90202..b73b264 100644 --- a/icing/index/iterator/doc-hit-info-iterator.h +++ b/icing/index/iterator/doc-hit-info-iterator.h @@ -40,11 +40,11 @@ struct TermMatchInfo { SectionIdMask section_ids_mask; // Array with fixed size kTotalNumSections. For every section id, i.e. // vector index, it stores the term frequency of the term.
- std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies; explicit TermMatchInfo( std::string_view term, SectionIdMask section_ids_mask, - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies) + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies) : term(term), section_ids_mask(section_ids_mask), term_frequencies(std::move(term_frequencies)) {} diff --git a/icing/index/lite/doc-hit-info-iterator-term-lite.cc b/icing/index/lite/doc-hit-info-iterator-term-lite.cc index f215d63..597f5b5 100644 --- a/icing/index/lite/doc-hit-info-iterator-term-lite.cc +++ b/icing/index/lite/doc-hit-info-iterator-term-lite.cc @@ -14,7 +14,9 @@ #include "icing/index/lite/doc-hit-info-iterator-term-lite.h" +#include <array> #include <cstdint> +#include <numeric> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/absl_ports/canonical_errors.h" @@ -30,9 +32,9 @@ namespace lib { namespace { std::string SectionIdMaskToString(SectionIdMask section_id_mask) { - std::string mask(kMaxSectionId + 1, '0'); + std::string mask(kTotalNumSections, '0'); for (SectionId i = kMaxSectionId; i >= 0; --i) { - if (section_id_mask & (1U << i)) { + if (section_id_mask & (UINT64_C(1) << i)) { mask[kMaxSectionId - i] = '1'; } } @@ -76,9 +78,11 @@ libtextclassifier3::Status DocHitInfoIteratorTermLiteExact::RetrieveMoreHits() { ICING_ASSIGN_OR_RETURN(uint32_t tvi, lite_index_->GetTermId(term_)); ICING_ASSIGN_OR_RETURN(uint32_t term_id, term_id_codec_->EncodeTvi(tvi, TviType::LITE)); - lite_index_->AppendHits(term_id, section_restrict_mask_, - /*only_from_prefix_sections=*/false, - /*namespace_checker=*/nullptr, &cached_hits_); + lite_index_->AppendHits( + term_id, section_restrict_mask_, + /*only_from_prefix_sections=*/false, + /*namespace_checker=*/nullptr, &cached_hits_, + need_hit_term_frequency_ ? &cached_hit_term_frequency_ : nullptr); cached_hits_idx_ = 0; return libtextclassifier3::Status::OK; } @@ -99,9 +103,11 @@ DocHitInfoIteratorTermLitePrefix::RetrieveMoreHits() { ICING_ASSIGN_OR_RETURN( uint32_t term_id, term_id_codec_->EncodeTvi(it.GetValueIndex(), TviType::LITE)); - lite_index_->AppendHits(term_id, section_restrict_mask_, - /*only_from_prefix_sections=*/!exact_match, - /*namespace_checker=*/nullptr, &cached_hits_); + lite_index_->AppendHits( + term_id, section_restrict_mask_, + /*only_from_prefix_sections=*/!exact_match, + /*namespace_checker=*/nullptr, &cached_hits_, + need_hit_term_frequency_ ? &cached_hit_term_frequency_ : nullptr); ++terms_matched; } if (terms_matched > 1) { @@ -111,23 +117,83 @@ DocHitInfoIteratorTermLitePrefix::RetrieveMoreHits() { return libtextclassifier3::Status::OK; } -void DocHitInfoIteratorTermLitePrefix::SortAndDedupeDocumentIds() { +void DocHitInfoIteratorTermLitePrefix::SortDocumentIds() { // Re-sort cached document_ids and merge sections. - sort(cached_hits_.begin(), cached_hits_.end()); + if (!need_hit_term_frequency_) { + // If we don't need to also sort cached_hit_term_frequency_ along with + // cached_hits_, then just simply sort cached_hits_. + sort(cached_hits_.begin(), cached_hits_.end()); + } else { + // Sort cached_hit_term_frequency_ along with cached_hits_. + std::vector<int> indices(cached_hits_.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [this](int i, int j) { + return cached_hits_[i] < cached_hits_[j]; + }); + // Now indices is a map from sorted index to current index. 
In other words, + // the sorted cached_hits_[i] should be the current cached_hits_[indices[i]] + // for every valid i. + std::vector<bool> done(indices.size()); + // Apply permutation + for (int i = 0; i < indices.size(); ++i) { + if (done[i]) { + continue; + } + done[i] = true; + int curr = i; + int next = indices[i]; + // Since every finite permutation is formed by disjoint cycles, we can + // start with the current element, at index i, and swap the element at + // this position with whatever element *should* be here. Then, + // continue to swap the original element, at its updated positions, with + // the element that should be occupying that position until the original + // element has reached *its* correct position. This completes applying the + // single cycle in the permutation. + while (next != i) { + std::swap(cached_hits_[curr], cached_hits_[next]); + std::swap(cached_hit_term_frequency_[curr], + cached_hit_term_frequency_[next]); + done[next] = true; + curr = next; + next = indices[next]; + } + } + } +} +void DocHitInfoIteratorTermLitePrefix::SortAndDedupeDocumentIds() { + SortDocumentIds(); int idx = 0; for (int i = 1; i < cached_hits_.size(); ++i) { - const DocHitInfo& hit_info = cached_hits_.at(i); - DocHitInfo& collapsed_hit_info = cached_hits_.at(idx); + const DocHitInfo& hit_info = cached_hits_[i]; + DocHitInfo& collapsed_hit_info = cached_hits_[idx]; if (collapsed_hit_info.document_id() == hit_info.document_id()) { - collapsed_hit_info.MergeSectionsFrom(hit_info); + SectionIdMask curr_mask = hit_info.hit_section_ids_mask(); + collapsed_hit_info.MergeSectionsFrom(curr_mask); + if (need_hit_term_frequency_) { + Hit::TermFrequencyArray& collapsed_term_frequency = + cached_hit_term_frequency_[idx]; + while (curr_mask) { + SectionId section_id = __builtin_ctzll(curr_mask); + collapsed_term_frequency[section_id] = + cached_hit_term_frequency_[i][section_id]; + curr_mask &= ~(UINT64_C(1) << section_id); + } + } } else { // New document_id. - cached_hits_.at(++idx) = hit_info; + ++idx; + cached_hits_[idx] = hit_info; + if (need_hit_term_frequency_) { + cached_hit_term_frequency_[idx] = cached_hit_term_frequency_[i]; + } } } // idx points to last doc hit info. cached_hits_.resize(idx + 1); + if (need_hit_term_frequency_) { + cached_hit_term_frequency_.resize(idx + 1); + } } std::string DocHitInfoIteratorTermLitePrefix::ToString() const { diff --git a/icing/index/lite/doc-hit-info-iterator-term-lite.h b/icing/index/lite/doc-hit-info-iterator-term-lite.h index 179fc93..bd8a6ee 100644 --- a/icing/index/lite/doc-hit-info-iterator-term-lite.h +++ b/icing/index/lite/doc-hit-info-iterator-term-lite.h @@ -33,39 +33,40 @@ class DocHitInfoIteratorTermLite : public DocHitInfoIterator { explicit DocHitInfoIteratorTermLite(const TermIdCodec* term_id_codec, LiteIndex* lite_index, const std::string& term, - SectionIdMask section_restrict_mask) + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) : term_(term), lite_index_(lite_index), cached_hits_idx_(-1), term_id_codec_(term_id_codec), num_advance_calls_(0), - section_restrict_mask_(section_restrict_mask) {} + section_restrict_mask_(section_restrict_mask), + need_hit_term_frequency_(need_hit_term_frequency) {} libtextclassifier3::Status Advance() override; - int32_t GetNumBlocksInspected() const override { - // TODO(b/137862424): Implement this once the main index is added.
- return 0; - } + int32_t GetNumBlocksInspected() const override { return 0; } int32_t GetNumLeafAdvanceCalls() const override { return num_advance_calls_; } void PopulateMatchedTermsStats( std::vector<TermMatchInfo>* matched_terms_stats, SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { - if (doc_hit_info_.document_id() == kInvalidDocumentId) { + if (cached_hits_idx_ == -1 || cached_hits_idx_ >= cached_hits_.size()) { // Current hit isn't valid, return. return; } SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask() & filtering_section_mask; SectionIdMask section_mask_copy = section_mask; - std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { - Hit::kNoTermFrequency}; + std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies = + {Hit::kNoTermFrequency}; while (section_mask_copy) { - SectionId section_id = __builtin_ctz(section_mask_copy); - section_term_frequencies.at(section_id) = - doc_hit_info_.hit_term_frequency(section_id); - section_mask_copy &= ~(1u << section_id); + SectionId section_id = __builtin_ctzll(section_mask_copy); + if (need_hit_term_frequency_) { + section_term_frequencies.at(section_id) = + cached_hit_term_frequency_.at(cached_hits_idx_)[section_id]; + } + section_mask_copy &= ~(UINT64_C(1) << section_id); } TermMatchInfo term_stats(term_, section_mask, std::move(section_term_frequencies)); @@ -95,12 +96,14 @@ class DocHitInfoIteratorTermLite : public DocHitInfoIterator { // that are present in the index. Current value pointed to by the Iterator is // tracked by cached_hits_idx_. std::vector<DocHitInfo> cached_hits_; + std::vector<Hit::TermFrequencyArray> cached_hit_term_frequency_; int cached_hits_idx_; const TermIdCodec* term_id_codec_; int num_advance_calls_; // Mask indicating which sections hits should be considered for. // Ex. 0000 0000 0000 0010 means that only hits from section 1 are desired. const SectionIdMask section_restrict_mask_; + const bool need_hit_term_frequency_; }; class DocHitInfoIteratorTermLiteExact : public DocHitInfoIteratorTermLite { @@ -108,9 +111,10 @@ class DocHitInfoIteratorTermLiteExact : public DocHitInfoIteratorTermLite { explicit DocHitInfoIteratorTermLiteExact(const TermIdCodec* term_id_codec, LiteIndex* lite_index, const std::string& term, - SectionIdMask section_id_mask) + SectionIdMask section_id_mask, + bool need_hit_term_frequency) : DocHitInfoIteratorTermLite(term_id_codec, lite_index, term, - section_id_mask) {} + section_id_mask, need_hit_term_frequency) {} std::string ToString() const override; @@ -123,9 +127,10 @@ class DocHitInfoIteratorTermLitePrefix : public DocHitInfoIteratorTermLite { explicit DocHitInfoIteratorTermLitePrefix(const TermIdCodec* term_id_codec, LiteIndex* lite_index, const std::string& term, - SectionIdMask section_id_mask) + SectionIdMask section_id_mask, + bool need_hit_term_frequency) : DocHitInfoIteratorTermLite(term_id_codec, lite_index, term, - section_id_mask) {} + section_id_mask, need_hit_term_frequency) {} std::string ToString() const override; @@ -136,6 +141,7 @@ class DocHitInfoIteratorTermLitePrefix : public DocHitInfoIteratorTermLite { // After retrieving DocHitInfos from the index, there may be a DocHitInfo for // docid 1 from "foo" and a DocHitInfo for docid 1 from "fool". These // DocHitInfos should be merged after sorting; the in-place reorder that // SortDocumentIds uses is sketched below.
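The cycle-walking loop in SortDocumentIds (the .cc hunk above) generalizes to any set of parallel arrays that must be reordered together, which is exactly why it is used here: cached_hits_ and cached_hit_term_frequency_ must stay index-aligned, so a plain std::sort on one array is not enough. A self-contained sketch of the same in-place permutation, with illustrative names:

#include <cstddef>
#include <utility>
#include <vector>

// Reorders values in place so that the element originally at position
// indices[i] ends up at position i, following each disjoint cycle of the
// permutation exactly once. A second parallel array would be swapped in
// lockstep, as SortDocumentIds does with cached_hit_term_frequency_.
void ApplyPermutation(const std::vector<size_t>& indices,
                      std::vector<int>& values) {
  std::vector<bool> done(indices.size(), false);
  for (size_t i = 0; i < indices.size(); ++i) {
    if (done[i]) continue;
    done[i] = true;
    size_t curr = i;
    size_t next = indices[i];
    // Swap along the cycle until the element that started at position i
    // arrives at its destination.
    while (next != i) {
      std::swap(values[curr], values[next]);
      done[next] = true;
      curr = next;
      next = indices[next];
    }
  }
}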
+ void SortDocumentIds(); void SortAndDedupeDocumentIds(); }; diff --git a/icing/index/lite/lite-index-header.h b/icing/index/lite/lite-index-header.h index dd6a0a8..58379d6 100644 --- a/icing/index/lite/lite-index-header.h +++ b/icing/index/lite/lite-index-header.h @@ -50,7 +50,7 @@ class LiteIndex_Header { class LiteIndex_HeaderImpl : public LiteIndex_Header { public: struct HeaderData { - static const uint32_t kMagic = 0x6dfba6a0; + static const uint32_t kMagic = 0xb4fb8792; uint32_t lite_index_crc; uint32_t magic; diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index 9622ff4..b10add9 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -230,7 +230,8 @@ Crc32 LiteIndex::ComputeChecksum() { Crc32 all_crc(header_->CalculateHeaderCrc()); all_crc.Append(std::string_view(reinterpret_cast<const char*>(dependent_crcs), sizeof(dependent_crcs))); - ICING_VLOG(2) << "Lite index crc computed in " << timer.Elapsed() * 1000 << "ms"; + ICING_VLOG(2) << "Lite index crc computed in " << timer.Elapsed() * 1000 + << "ms"; return all_crc; } @@ -332,10 +333,11 @@ libtextclassifier3::StatusOr<uint32_t> LiteIndex::GetTermId( return tvi; } -int LiteIndex::AppendHits(uint32_t term_id, SectionIdMask section_id_mask, - bool only_from_prefix_sections, - const NamespaceChecker* namespace_checker, - std::vector<DocHitInfo>* hits_out) { +int LiteIndex::AppendHits( + uint32_t term_id, SectionIdMask section_id_mask, + bool only_from_prefix_sections, const NamespaceChecker* namespace_checker, + std::vector<DocHitInfo>* hits_out, + std::vector<Hit::TermFrequencyArray>* term_frequency_out) { int count = 0; DocumentId last_document_id = kInvalidDocumentId; // Record whether the last document belongs to the given namespaces. @@ -347,7 +349,7 @@ int LiteIndex::AppendHits(uint32_t term_id, SectionIdMask section_id_mask, const Hit& hit = term_id_hit_pair.hit(); // Check sections. - if (((1u << hit.section_id()) & section_id_mask) == 0) { + if (((UINT64_C(1) << hit.section_id()) & section_id_mask) == 0) { continue; } // Check prefix section only. @@ -368,10 +370,16 @@ int LiteIndex::AppendHits(uint32_t term_id, SectionIdMask section_id_mask, ++count; if (hits_out != nullptr) { hits_out->push_back(DocHitInfo(document_id)); + if (term_frequency_out != nullptr) { + term_frequency_out->push_back(Hit::TermFrequencyArray()); + } } } if (hits_out != nullptr && last_document_in_namespace) { - hits_out->back().UpdateSection(hit.section_id(), hit.term_frequency()); + hits_out->back().UpdateSection(hit.section_id()); + if (term_frequency_out != nullptr) { + term_frequency_out->back()[hit.section_id()] = hit.term_frequency(); + } } } return count; @@ -458,7 +466,8 @@ void LiteIndex::SortHits() { array_start + header_->cur_size()); } ICING_VLOG(2) << "Lite index sort and merge " << sort_len << " into " - << header_->searchable_end() << " in " << timer.Elapsed() * 1000 << "ms"; + << header_->searchable_end() << " in " << timer.Elapsed() * 1000 + << "ms"; // Now the entire array is sorted. 
header_->set_searchable_end(header_->cur_size()); @@ -484,7 +493,8 @@ uint32_t LiteIndex::Seek(uint32_t term_id) { libtextclassifier3::Status LiteIndex::Optimize( const std::vector<DocumentId>& document_id_old_to_new, - const TermIdCodec* term_id_codec) { + const TermIdCodec* term_id_codec, DocumentId new_last_added_document_id) { + header_->set_last_added_docid(new_last_added_document_id); if (header_->cur_size() == 0) { return libtextclassifier3::Status::OK; } @@ -492,8 +502,6 @@ libtextclassifier3::Status LiteIndex::Optimize( // which helps later to determine which terms will be unused after compaction. SortHits(); uint32_t new_size = 0; - // The largest document id after translating hits. - DocumentId largest_document_id = kInvalidDocumentId; uint32_t curr_term_id = 0; uint32_t curr_tvi = 0; std::unordered_set<uint32_t> tvi_to_delete; @@ -518,10 +526,6 @@ libtextclassifier3::Status LiteIndex::Optimize( if (new_document_id == kInvalidDocumentId) { continue; } - if (largest_document_id == kInvalidDocumentId || - new_document_id > largest_document_id) { - largest_document_id = new_document_id; - } if (term_id_hit_pair.hit().is_in_prefix_section()) { lexicon_.SetProperty(curr_tvi, GetHasHitsInPrefixSectionPropertyId()); } @@ -539,7 +543,6 @@ libtextclassifier3::Status LiteIndex::Optimize( } header_->set_cur_size(new_size); header_->set_searchable_end(new_size); - header_->set_last_added_docid(largest_document_id); // Delete unused terms. std::unordered_set<std::string> terms_to_delete; diff --git a/icing/index/lite/lite-index.h b/icing/index/lite/lite-index.h index 64b5881..592e956 100644 --- a/icing/index/lite/lite-index.h +++ b/icing/index/lite/lite-index.h @@ -141,16 +141,19 @@ class LiteIndex { // Add all hits with term_id from the sections specified in section_id_mask, // skipping hits in non-prefix sections if only_from_prefix_sections is true, - // to hits_out. If hits_out is nullptr, no hits will be added. + // to hits_out. If hits_out is nullptr, no hits will be added. The + // corresponding hit term frequencies will also be added if term_frequency_out + // is not nullptr. // // Only those hits which belong to the given namespaces will be counted and // appended. A nullptr namespace checker will disable this check. // // Returns the number of hits that would be added to hits_out. - int AppendHits(uint32_t term_id, SectionIdMask section_id_mask, - bool only_from_prefix_sections, - const NamespaceChecker* namespace_checker, - std::vector<DocHitInfo>* hits_out); + int AppendHits( + uint32_t term_id, SectionIdMask section_id_mask, + bool only_from_prefix_sections, const NamespaceChecker* namespace_checker, + std::vector<DocHitInfo>* hits_out, + std::vector<Hit::TermFrequencyArray>* term_frequency_out = nullptr); // Returns the hit count of the term. // Only those hits which belong to the given namespaces will be counted. @@ -263,13 +266,16 @@ class LiteIndex { // Reduces internal file sizes by reclaiming space of deleted documents. // + // This method also sets the last_added_docid of the index to + // new_last_added_document_id. + // // Returns: // OK on success // INTERNAL_ERROR on IO error, this indicates that the index may be in an // invalid state and should be cleared.
libtextclassifier3::Status Optimize( const std::vector<DocumentId>& document_id_old_to_new, - const TermIdCodec* term_id_codec); + const TermIdCodec* term_id_codec, DocumentId new_last_added_document_id); private: static IcingDynamicTrie::RuntimeOptions MakeTrieRuntimeOptions(); diff --git a/icing/index/lite/lite-index_test.cc b/icing/index/lite/lite-index_test.cc index 825f830..12d5f42 100644 --- a/icing/index/lite/lite-index_test.cc +++ b/icing/index/lite/lite-index_test.cc @@ -18,6 +18,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "icing/index/lite/doc-hit-info-iterator-term-lite.h" #include "icing/index/term-id-codec.h" #include "icing/legacy/index/icing-mock-filesystem.h" #include "icing/schema/section.h" @@ -30,6 +31,7 @@ namespace lib { namespace { +using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::IsEmpty; using ::testing::SizeIs; @@ -105,6 +107,56 @@ TEST_F(LiteIndexTest, LiteIndexAppendHits) { EXPECT_THAT(hits2, IsEmpty()); } +TEST_F(LiteIndexTest, LiteIndexIterator) { + const std::string term = "foo"; + ICING_ASSERT_OK_AND_ASSIGN( + uint32_t tvi, + lite_index_->InsertTerm(term, TermMatchType::PREFIX, kNamespace0)); + ICING_ASSERT_OK_AND_ASSIGN(uint32_t foo_term_id, + term_id_codec_->EncodeTvi(tvi, TviType::LITE)); + Hit doc_hit0(/*section_id=*/0, /*document_id=*/0, 3, + /*is_in_prefix_section=*/false); + Hit doc_hit1(/*section_id=*/1, /*document_id=*/0, 5, + /*is_in_prefix_section=*/false); + Hit::TermFrequencyArray doc0_term_frequencies{3, 5}; + Hit doc_hit2(/*section_id=*/1, /*document_id=*/1, 7, + /*is_in_prefix_section=*/false); + Hit doc_hit3(/*section_id=*/2, /*document_id=*/1, 11, + /*is_in_prefix_section=*/false); + Hit::TermFrequencyArray doc1_term_frequencies{0, 7, 11}; + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit0)); + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit1)); + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit2)); + ICING_ASSERT_OK(lite_index_->AddHit(foo_term_id, doc_hit3)); + + std::unique_ptr<DocHitInfoIteratorTermLiteExact> iter = + std::make_unique<DocHitInfoIteratorTermLiteExact>( + term_id_codec_.get(), lite_index_.get(), term, kSectionIdMaskAll, + /*need_hit_term_frequency=*/true); + + ASSERT_THAT(iter->Advance(), IsOk()); + EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(1)); + EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b110)); + std::vector<TermMatchInfo> matched_terms_stats; + iter->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); + EXPECT_EQ(matched_terms_stats.at(0).term, term); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b110); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(doc1_term_frequencies)); + + ASSERT_THAT(iter->Advance(), IsOk()); + EXPECT_THAT(iter->doc_hit_info().document_id(), Eq(0)); + EXPECT_THAT(iter->doc_hit_info().hit_section_ids_mask(), Eq(0b11)); + matched_terms_stats.clear(); + iter->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); + EXPECT_EQ(matched_terms_stats.at(0).term, term); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, 0b11); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(doc0_term_frequencies)); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/index/main/doc-hit-info-iterator-term-main.cc b/icing/index/main/doc-hit-info-iterator-term-main.cc index 98bc18e..4bd87aa 100644 --- 
a/icing/index/main/doc-hit-info-iterator-term-main.cc +++ b/icing/index/main/doc-hit-info-iterator-term-main.cc @@ -34,9 +34,9 @@ namespace lib { namespace { std::string SectionIdMaskToString(SectionIdMask section_id_mask) { - std::string mask(kMaxSectionId + 1, '0'); + std::string mask(kTotalNumSections, '0'); for (SectionId i = kMaxSectionId; i >= 0; --i) { - if (section_id_mask & (1U << i)) { + if (section_id_mask & (UINT64_C(1) << i)) { mask[kMaxSectionId - i] = '1'; } } @@ -102,9 +102,10 @@ libtextclassifier3::Status DocHitInfoIteratorTermMainExact::RetrieveMoreHits() { posting_list_accessor_->GetNextHitsBatch()); ++num_blocks_inspected_; cached_doc_hit_infos_.reserve(hits.size() + 1); + cached_hit_term_frequency_.reserve(hits.size() + 1); for (const Hit& hit : hits) { // Check sections. - if (((1u << hit.section_id()) & section_restrict_mask_) == 0) { + if (((UINT64_C(1) << hit.section_id()) & section_restrict_mask_) == 0) { continue; } // We want exact hits, skip prefix-only hits. @@ -114,9 +115,10 @@ libtextclassifier3::Status DocHitInfoIteratorTermMainExact::RetrieveMoreHits() { if (cached_doc_hit_infos_.empty() || hit.document_id() != cached_doc_hit_infos_.back().document_id()) { cached_doc_hit_infos_.push_back(DocHitInfo(hit.document_id())); + cached_hit_term_frequency_.push_back(Hit::TermFrequencyArray()); } - cached_doc_hit_infos_.back().UpdateSection(hit.section_id(), - hit.term_frequency()); + cached_doc_hit_infos_.back().UpdateSection(hit.section_id()); + cached_hit_term_frequency_.back()[hit.section_id()] = hit.term_frequency(); } return libtextclassifier3::Status::OK; } @@ -142,18 +144,20 @@ DocHitInfoIteratorTermMainPrefix::RetrieveMoreHits() { ++num_blocks_inspected_; if (posting_list_accessor_ == nullptr) { - ICING_ASSIGN_OR_RETURN( - MainIndex::GetPrefixAccessorResult result, - main_index_->GetAccessorForPrefixTerm(term_)); + ICING_ASSIGN_OR_RETURN(MainIndex::GetPrefixAccessorResult result, + main_index_->GetAccessorForPrefixTerm(term_)); posting_list_accessor_ = std::move(result.accessor); exact_ = result.exact; } ICING_ASSIGN_OR_RETURN(std::vector<Hit> hits, posting_list_accessor_->GetNextHitsBatch()); cached_doc_hit_infos_.reserve(hits.size()); + if (need_hit_term_frequency_) { + cached_hit_term_frequency_.reserve(hits.size()); + } for (const Hit& hit : hits) { // Check sections. - if (((1u << hit.section_id()) & section_restrict_mask_) == 0) { + if (((UINT64_C(1) << hit.section_id()) & section_restrict_mask_) == 0) { continue; } // If we only want hits from prefix sections. 
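A note on the recurring change above from "1u << hit.section_id()" to "UINT64_C(1) << hit.section_id()": with SectionIdMask widened to 64 bits as part of extending the scale of Icing, shifting the 32-bit literal 1u by a section id of 32 or more is undefined behavior and silently drops hits from high sections. A minimal standalone sketch of the difference (treating SectionIdMask as a uint64_t alias is an assumption made here for illustration):

#include <cstdint>
#include <cstdio>

using SectionIdMask = uint64_t;  // assumed 64-bit mask, as the widened shifts imply

int main() {
  int section_id = 40;  // any section id >= 32 breaks the 32-bit form
  // Broken form: 1u is a 32-bit operand, so this shift is undefined behavior.
  // SectionIdMask bad = 1u << section_id;
  // Fixed form: force a 64-bit one before shifting, as the patch does.
  SectionIdMask good = UINT64_C(1) << section_id;
  printf("%016llx\n", static_cast<unsigned long long>(good));  // 0000010000000000
  return 0;
}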
@@ -163,9 +167,15 @@ DocHitInfoIteratorTermMainPrefix::RetrieveMoreHits() { if (cached_doc_hit_infos_.empty() || hit.document_id() != cached_doc_hit_infos_.back().document_id()) { cached_doc_hit_infos_.push_back(DocHitInfo(hit.document_id())); + if (need_hit_term_frequency_) { + cached_hit_term_frequency_.push_back(Hit::TermFrequencyArray()); + } + } + cached_doc_hit_infos_.back().UpdateSection(hit.section_id()); + if (need_hit_term_frequency_) { + cached_hit_term_frequency_.back()[hit.section_id()] = + hit.term_frequency(); } - cached_doc_hit_infos_.back().UpdateSection(hit.section_id(), - hit.term_frequency()); } return libtextclassifier3::Status::OK; } diff --git a/icing/index/main/doc-hit-info-iterator-term-main.h b/icing/index/main/doc-hit-info-iterator-term-main.h index f3cf701..c1b289f 100644 --- a/icing/index/main/doc-hit-info-iterator-term-main.h +++ b/icing/index/main/doc-hit-info-iterator-term-main.h @@ -33,14 +33,16 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { public: explicit DocHitInfoIteratorTermMain(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) : term_(term), main_index_(main_index), cached_doc_hit_infos_idx_(-1), num_advance_calls_(0), num_blocks_inspected_(0), next_posting_list_id_(PostingListIdentifier::kInvalid), - section_restrict_mask_(section_restrict_mask) {} + section_restrict_mask_(section_restrict_mask), + need_hit_term_frequency_(need_hit_term_frequency) {} libtextclassifier3::Status Advance() override; @@ -52,20 +54,23 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { void PopulateMatchedTermsStats( std::vector<TermMatchInfo>* matched_terms_stats, SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { - if (doc_hit_info_.document_id() == kInvalidDocumentId) { + if (cached_doc_hit_infos_idx_ == -1 || + cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) { // Current hit isn't valid, return. return; } SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask() & filtering_section_mask; SectionIdMask section_mask_copy = section_mask; - std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { - Hit::kNoTermFrequency}; + std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies = + {Hit::kNoTermFrequency}; while (section_mask_copy) { - SectionId section_id = __builtin_ctz(section_mask_copy); - section_term_frequencies.at(section_id) = - doc_hit_info_.hit_term_frequency(section_id); - section_mask_copy &= ~(1u << section_id); + SectionId section_id = __builtin_ctzll(section_mask_copy); + if (need_hit_term_frequency_) { + section_term_frequencies.at(section_id) = cached_hit_term_frequency_.at( + cached_doc_hit_infos_idx_)[section_id]; + } + section_mask_copy &= ~(UINT64_C(1) << section_id); } TermMatchInfo term_stats(term_, section_mask, std::move(section_term_frequencies)); @@ -93,6 +98,7 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { // that are present in the index. Current value pointed to by the Iterator is // tracked by cached_doc_hit_infos_idx_. std::vector<DocHitInfo> cached_doc_hit_infos_; + std::vector<Hit::TermFrequencyArray> cached_hit_term_frequency_; int cached_doc_hit_infos_idx_; int num_advance_calls_; int num_blocks_inspected_; @@ -100,14 +106,17 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { // Mask indicating which sections hits should be considered for. // Ex. 
0000 0000 0000 0010 means that only hits from section 1 are desired. const SectionIdMask section_restrict_mask_; + const bool need_hit_term_frequency_; }; class DocHitInfoIteratorTermMainExact : public DocHitInfoIteratorTermMain { public: explicit DocHitInfoIteratorTermMainExact(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) - : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask) {} + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) + : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask, + need_hit_term_frequency) {} std::string ToString() const override; @@ -119,8 +128,10 @@ class DocHitInfoIteratorTermMainPrefix : public DocHitInfoIteratorTermMain { public: explicit DocHitInfoIteratorTermMainPrefix(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) - : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask) {} + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) + : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask, + need_hit_term_frequency) {} std::string ToString() const override; diff --git a/icing/index/main/flash-index-storage-header.h b/icing/index/main/flash-index-storage-header.h index f81e99e..71ec816 100644 --- a/icing/index/main/flash-index-storage-header.h +++ b/icing/index/main/flash-index-storage-header.h @@ -33,7 +33,7 @@ class HeaderBlock { // The class used to access the actual header. struct Header { // A magic used to mark the beginning of a valid header. - static constexpr int kMagic = 0x6dfba6ae; + static constexpr int kMagic = 0xb0780cf4; int magic; int block_size; int last_indexed_docid; diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 9f591c0..0fdc1ac 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -234,24 +234,27 @@ MainIndex::FindTermsByPrefix(const std::string& prefix, flash_index_storage_.get(), posting_list_id)); ICING_ASSIGN_OR_RETURN(std::vector<Hit> hits, pl_accessor.GetNextHitsBatch()); - for (const Hit& hit : hits) { - DocumentId document_id = hit.document_id(); - if (document_id != last_document_id) { - last_document_id = document_id; - if (term_match_type == TermMatchType::EXACT_ONLY && - hit.is_prefix_hit()) { - continue; + while (!hits.empty()) { + for (const Hit& hit : hits) { + DocumentId document_id = hit.document_id(); + if (document_id != last_document_id) { + last_document_id = document_id; + if (term_match_type == TermMatchType::EXACT_ONLY && + hit.is_prefix_hit()) { + continue; + } + if (!namespace_checker->BelongsToTargetNamespaces(document_id)) { + // The document is removed or expired, or does not belong to the + // target namespaces. + continue; + } + // TODO(b/152934343) Add search type in SuggestionSpec to ask the user + // to input the search type, prefix or exact, and make a different + // scoring strategy based on that. + ++count; } - if (!namespace_checker->BelongsToTargetNamespaces(document_id)) { - // The document is removed or expired or not belongs to target - // namespaces. - continue; - } - // TODO(b/152934343) Add search type in SuggestionSpec to ask user to - // input search type, prefix or exact. And make different score strategy - // base on that.
- ++count; } + ICING_ASSIGN_OR_RETURN(hits, pl_accessor.GetNextHitsBatch()); } if (count > 0) { term_metadata_list.push_back(TermMetadata(term_iterator.GetKey(), count)); diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index 15030b0..4ed2e94 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -190,6 +190,9 @@ class MainIndex { // Reduces internal file sizes by reclaiming space of deleted documents. // + // This method will update the last_added_docid of the index to the largest + // document id that still appears in the index after compaction. + // // Returns: // OK on success // INTERNAL_ERROR on IO error, this indicates that the index may be in an diff --git a/icing/index/main/main-index_test.cc b/icing/index/main/main-index_test.cc index fa83d68..bfda014 100644 --- a/icing/index/main/main-index_test.cc +++ b/icing/index/main/main-index_test.cc @@ -56,7 +56,7 @@ std::vector<DocHitInfo> GetExactHits( MainIndex* main_index, const std::string& term, SectionIdMask section_mask = kSectionIdMaskAll) { auto iterator = std::make_unique<DocHitInfoIteratorTermMainExact>( - main_index, term, section_mask); + main_index, term, section_mask, /*need_hit_term_frequency=*/true); return GetHits(std::move(iterator)); } @@ -64,7 +64,7 @@ std::vector<DocHitInfo> GetPrefixHits( MainIndex* main_index, const std::string& term, SectionIdMask section_mask = kSectionIdMaskAll) { auto iterator = std::make_unique<DocHitInfoIteratorTermMainPrefix>( - main_index, term, section_mask); + main_index, term, section_mask, /*need_hit_term_frequency=*/true); return GetHits(std::move(iterator)); } diff --git a/icing/index/main/posting-list-accessor.cc b/icing/index/main/posting-list-accessor.cc index a4f8ca7..67e1ad5 100644 --- a/icing/index/main/posting-list-accessor.cc +++ b/icing/index/main/posting-list-accessor.cc @@ -61,7 +61,7 @@ PostingListAccessor::GetNextHitsBatch() { return std::vector<Hit>(); } return absl_ports::FailedPreconditionError( - "Cannot retrieve hits from a PostingListAccessor that was not creaated " + "Cannot retrieve hits from a PostingListAccessor that was not created " "from a preexisting posting list."); } ICING_ASSIGN_OR_RETURN(std::vector<Hit> batch, diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index c9e7127..6757e29 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -236,13 +236,26 @@ Java_com_google_android_icing_IcingSearchEngine_nativeGetAllNamespaces( JNIEXPORT jbyteArray JNICALL Java_com_google_android_icing_IcingSearchEngine_nativeGetNextPage( - JNIEnv* env, jclass clazz, jobject object, jlong next_page_token) { + JNIEnv* env, jclass clazz, jobject object, jlong next_page_token, + jlong java_to_native_start_timestamp_ms) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); + const std::unique_ptr<const icing::lib::Clock> clock = + std::make_unique<icing::lib::Clock>(); + // TODO(b/236412954): java_to_native_start_timestamp_ms can only be used after + // cl/469819190 is synced to Jetpack and exported back to google3. 
+ // int32 java_to_native_jni_latency_ms = + // clock->GetSystemTimeMilliseconds() - java_to_native_start_timestamp_ms; + icing::lib::SearchResultProto next_page_result_proto = icing->GetNextPage(next_page_token); + icing::lib::QueryStatsProto* query_stats = + next_page_result_proto.mutable_query_stats(); + // query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms); + query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds()); + return SerializeProtoToJniByteArray(env, next_page_result_proto); } @@ -260,7 +273,8 @@ Java_com_google_android_icing_IcingSearchEngine_nativeInvalidateNextPageToken( JNIEXPORT jbyteArray JNICALL Java_com_google_android_icing_IcingSearchEngine_nativeSearch( JNIEnv* env, jclass clazz, jobject object, jbyteArray search_spec_bytes, - jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes) { + jbyteArray scoring_spec_bytes, jbyteArray result_spec_bytes, + jlong java_to_native_start_timestamp_ms) { icing::lib::IcingSearchEngine* icing = GetIcingSearchEnginePointer(env, object); @@ -283,9 +297,21 @@ Java_com_google_android_icing_IcingSearchEngine_nativeSearch( return nullptr; } + const std::unique_ptr<const icing::lib::Clock> clock = + std::make_unique<icing::lib::Clock>(); + // TODO(b/236412954): java_to_native_start_timestamp_ms can only be used after + // cl/469819190 is synced to Jetpack and exported back to google3. + // int32 java_to_native_jni_latency_ms = + // clock->GetSystemTimeMilliseconds() - java_to_native_start_timestamp_ms; + icing::lib::SearchResultProto search_result_proto = icing->Search(search_spec_proto, scoring_spec_proto, result_spec_proto); + icing::lib::QueryStatsProto* query_stats = + search_result_proto.mutable_query_stats(); + // query_stats->set_java_to_native_jni_latency_ms(java_to_native_jni_latency_ms); + query_stats->set_native_to_java_start_timestamp_ms(clock->GetSystemTimeMilliseconds()); + return SerializeProtoToJniByteArray(env, search_result_proto); } diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 36c76db..8a942cf 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -69,7 +69,7 @@ struct ParserStateFrame { // If the last independent token was a property/section filter, then we need // to save the section name so we can create a section filter iterator. - std::string_view section_restrict = ""; + std::string section_restrict; }; // Combines any OR and AND iterators together into one iterator. 
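The section_restrict change just above, from a std::string_view member to an owning std::string, is one of the "Fix improper uses of std::string_view" items: the view pointed at token text whose backing storage can move or be reused while the parser frame is still alive. A contrived standalone sketch of that hazard (the variable names are hypothetical, not the parser's own):

#include <string>
#include <string_view>
#include <vector>

int main() {
  std::vector<std::string> tokens;
  tokens.push_back("subject");
  // A view into a vector element is only valid until the vector reallocates.
  std::string_view section_restrict = tokens[0];
  tokens.push_back("animal");  // may reallocate; section_restrict can now dangle
  // Reading section_restrict here would be undefined behavior.
  // Owning a copy up front, as the patched struct now does, avoids this:
  std::string owned_restrict = tokens[0];
  (void)owned_restrict;
  return 0;
}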
@@ -145,8 +145,11 @@ DocHitInfoIteratorFilter::Options QueryProcessor::getFilterOptions( } libtextclassifier3::StatusOr<QueryProcessor::QueryResults> -QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { - ICING_ASSIGN_OR_RETURN(QueryResults results, ParseRawQuery(search_spec)); +QueryProcessor::ParseSearch( + const SearchSpecProto& search_spec, + ScoringSpecProto::RankingStrategy::Code ranking_strategy) { + ICING_ASSIGN_OR_RETURN(QueryResults results, + ParseRawQuery(search_spec, ranking_strategy)); DocHitInfoIteratorFilter::Options options = getFilterOptions(search_spec); results.root_iterator = std::make_unique<DocHitInfoIteratorFilter>( @@ -157,7 +160,9 @@ QueryProcessor::ParseSearch(const SearchSpecProto& search_spec) { // TODO(cassiewang): Collect query stats to populate the SearchResultsProto libtextclassifier3::StatusOr<QueryProcessor::QueryResults> -QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { +QueryProcessor::ParseRawQuery( + const SearchSpecProto& search_spec, + ScoringSpecProto::RankingStrategy::Code ranking_strategy) { DocHitInfoIteratorFilter::Options options = getFilterOptions(search_spec); // Tokenize the incoming raw query @@ -220,7 +225,7 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { "Encountered empty stack of ParserStateFrames"); } - frames.top().section_restrict = token.text; + frames.top().section_restrict = std::string(token.text); break; } case Token::Type::REGULAR: { @@ -257,8 +262,11 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { ICING_ASSIGN_OR_RETURN( result_iterator, - index_.GetIterator(normalized_text, kSectionIdMaskAll, - search_spec.term_match_type())); + index_.GetIterator( + normalized_text, kSectionIdMaskAll, + search_spec.term_match_type(), + /*need_hit_term_frequency=*/ranking_strategy == + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Add term iterator and terms to match if this is not a negation term. 
// WARNING: setting query terms at this point is not compatible with @@ -268,14 +276,19 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { if (!frames.top().saw_exclude) { ICING_ASSIGN_OR_RETURN( std::unique_ptr<DocHitInfoIterator> term_iterator, - index_.GetIterator(normalized_text, kSectionIdMaskAll, - search_spec.term_match_type())); - - results.query_term_iterators[normalized_text] = - std::make_unique<DocHitInfoIteratorFilter>( - std::move(term_iterator), &document_store_, &schema_store_, - options); - + index_.GetIterator( + normalized_text, kSectionIdMaskAll, + search_spec.term_match_type(), + /*need_hit_term_frequency=*/ranking_strategy == + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); + + if (ranking_strategy == + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE) { + results.query_term_iterators[normalized_text] = + std::make_unique<DocHitInfoIteratorFilter>( + std::move(term_iterator), &document_store_, &schema_store_, + options); + } results.query_terms[frames.top().section_restrict].insert( std::move(normalized_text)); } @@ -330,7 +343,7 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { // the section restrict result_iterator = std::make_unique<DocHitInfoIteratorSectionRestrict>( std::move(result_iterator), &document_store_, &schema_store_, - frames.top().section_restrict); + std::move(frames.top().section_restrict)); frames.top().section_restrict = ""; } diff --git a/icing/query/query-processor.h b/icing/query/query-processor.h index bdf9ef2..f99576a 100644 --- a/icing/query/query-processor.h +++ b/icing/query/query-processor.h @@ -56,12 +56,18 @@ class QueryProcessor { // Hit iterators for the text terms in the query. These query_term_iterators // are completely separate from the iterators that make the iterator tree // beginning with root_iterator. + // This will only be populated when ranking_strategy == RELEVANCE_SCORE. std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; }; // Parse the search configurations (including the query, any additional // filters, etc.) in the SearchSpecProto into one DocHitInfoIterator. // + // When ranking_strategy == RELEVANCE_SCORE, the root_iterator and the + // query_term_iterators returned will keep term frequency information + // internally, so that term frequency stats will be collected when calling + // PopulateMatchedTermsStats on the iterators. + // // Returns: // On success, // - One iterator that represents the entire query @@ -69,7 +75,8 // INVALID_ARGUMENT if query syntax is incorrect and cannot be tokenized // INTERNAL_ERROR on all other errors libtextclassifier3::StatusOr<QueryResults> ParseSearch( - const SearchSpecProto& search_spec); + const SearchSpecProto& search_spec, + ScoringSpecProto::RankingStrategy::Code ranking_strategy); private: explicit QueryProcessor(Index* index, @@ -88,7 +95,8 @@ class QueryProcessor { // INVALID_ARGUMENT if query syntax is incorrect and cannot be tokenized // INTERNAL_ERROR on all other errors libtextclassifier3::StatusOr<QueryResults> ParseRawQuery( - const SearchSpecProto& search_spec); + const SearchSpecProto& search_spec, + ScoringSpecProto::RankingStrategy::Code ranking_strategy); // Return the options for the DocHitInfoIteratorFilter based on the // search_spec.
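With the new signature, every caller has to state whether relevance scoring is in play: passing RELEVANCE_SCORE is what opts the iterators into carrying term frequencies, and query_term_iterators stays empty under any other strategy. A fragment of the intended call pattern, assuming the query_processor and search_spec fixtures used by the tests and benchmarks below:

ICING_ASSERT_OK_AND_ASSIGN(
    QueryProcessor::QueryResults results,
    query_processor->ParseSearch(
        search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE));
while (results.root_iterator->Advance().ok()) {
  std::vector<TermMatchInfo> matched_terms_stats;
  results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats);
  // The term frequencies here are only meaningful under RELEVANCE_SCORE;
  // with RankingStrategy::NONE they come back as zeros.
}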
diff --git a/icing/query/query-processor_benchmark.cc b/icing/query/query-processor_benchmark.cc index b505ac5..3b3ea0d 100644 --- a/icing/query/query-processor_benchmark.cc +++ b/icing/query/query-processor_benchmark.cc @@ -155,7 +155,10 @@ void BM_QueryOneTerm(benchmark::State& state) { for (auto _ : state) { QueryProcessor::QueryResults results = - query_processor->ParseSearch(search_spec).ValueOrDie(); + query_processor + ->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE) + .ValueOrDie(); while (results.root_iterator->Advance().ok()) { results.root_iterator->doc_hit_info(); } @@ -290,7 +293,10 @@ void BM_QueryFiveTerms(benchmark::State& state) { for (auto _ : state) { QueryProcessor::QueryResults results = - query_processor->ParseSearch(search_spec).ValueOrDie(); + query_processor + ->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE) + .ValueOrDie(); while (results.root_iterator->Advance().ok()) { results.root_iterator->doc_hit_info(); } @@ -410,7 +416,10 @@ void BM_QueryDiacriticTerm(benchmark::State& state) { for (auto _ : state) { QueryProcessor::QueryResults results = - query_processor->ParseSearch(search_spec).ValueOrDie(); + query_processor + ->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE) + .ValueOrDie(); while (results.root_iterator->Advance().ok()) { results.root_iterator->doc_hit_info(); } @@ -530,7 +539,10 @@ void BM_QueryHiragana(benchmark::State& state) { for (auto _ : state) { QueryProcessor::QueryResults results = - query_processor->ParseSearch(search_spec).ValueOrDie(); + query_processor + ->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE) + .ValueOrDie(); while (results.root_iterator->Advance().ok()) { results.root_iterator->doc_hit_info(); } diff --git a/icing/query/query-processor_test.cc b/icing/query/query-processor_test.cc index d1cce87..d8b987a 100644 --- a/icing/query/query-processor_test.cc +++ b/icing/query/query-processor_test.cc @@ -213,8 +213,10 @@ TEST_F(QueryProcessorTest, EmptyGroupMatchAllDocuments) { SearchSpecProto search_spec; search_spec.set_query("()"); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), @@ -264,8 +266,10 @@ TEST_F(QueryProcessorTest, EmptyQueryMatchAllDocuments) { SearchSpecProto search_spec; search_spec.set_query(""); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocumentIds(results.root_iterator.get()), @@ -304,8 +308,7 @@ TEST_F(QueryProcessorTest, QueryTermNormalized) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -325,8 +328,10 @@ TEST_F(QueryProcessorTest, 
QueryTermNormalized) { search_spec.set_query("hElLo WORLD"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -380,8 +385,7 @@ TEST_F(QueryProcessorTest, OneTermPrefixMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::PREFIX; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -398,8 +402,80 @@ TEST_F(QueryProcessorTest, OneTermPrefixMatch) { search_spec.set_query("he"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); + + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(1)); // 1 term + EXPECT_EQ(matched_terms_stats.at(0).term, "he"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(term_frequencies)); + + EXPECT_THAT(results.query_terms, SizeIs(1)); + EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("he")); + EXPECT_THAT(results.query_term_iterators, SizeIs(1)); +} + +TEST_F(QueryProcessorTest, OneTermPrefixMatchWithMaxSectionID) { + // Create the schema and document store + SchemaProto schema = SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("email")) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + + // These documents don't actually match to the tokens in the index. We're + // inserting the documents to get the appropriate number of documents and + // namespaces populated. 
+ ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id, + document_store_->Put(DocumentBuilder() + .SetKey("namespace1", "1") + .SetSchema("email") + .Build())); + + // Populate the index + SectionId section_id = kMaxSectionId; + SectionIdMask section_id_mask = UINT64_C(1) << section_id; + TermMatchType::Code term_match_type = TermMatchType::PREFIX; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{}; + term_frequencies[kMaxSectionId] = 1; + + EXPECT_THAT( + AddTokenToIndex(document_id, section_id, term_match_type, "hello"), + IsOk()); + + // Perform query + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QueryProcessor> query_processor, + QueryProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get())); + + SearchSpecProto search_spec; + search_spec.set_query("he"); + search_spec.set_term_match_type(term_match_type); + + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -448,8 +524,7 @@ TEST_F(QueryProcessorTest, OneTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -466,8 +541,10 @@ TEST_F(QueryProcessorTest, OneTermExactMatch) { search_spec.set_query("hello"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -516,8 +593,7 @@ TEST_F(QueryProcessorTest, AndSameTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -534,8 +610,10 @@ TEST_F(QueryProcessorTest, AndSameTermExactMatch) { search_spec.set_query("hello hello"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -586,8 +664,7 @@ TEST_F(QueryProcessorTest, AndTwoTermExactMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -607,8 +684,10 @@ TEST_F(QueryProcessorTest, AndTwoTermExactMatch) { search_spec.set_query("hello world"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -661,8 +740,7 @@ TEST_F(QueryProcessorTest, AndSameTermPrefixMatch) { SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; TermMatchType::Code term_match_type = TermMatchType::PREFIX; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT( AddTokenToIndex(document_id, section_id, term_match_type, "hello"), @@ -679,8 +757,10 @@ TEST_F(QueryProcessorTest, AndSameTermPrefixMatch) { search_spec.set_query("he he"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); std::vector<TermMatchInfo> matched_terms_stats; ASSERT_THAT(results.root_iterator->Advance(), IsOk()); @@ -730,8 +810,7 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -752,8 +831,10 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixMatch) { search_spec.set_query("he wo"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds std::vector<TermMatchInfo> matched_terms_stats; @@ -806,8 +887,7 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT(AddTokenToIndex(document_id, section_id, @@ -828,8 +908,10 @@ TEST_F(QueryProcessorTest, AndTwoTermPrefixAndExactMatch) { search_spec.set_query("hello wo"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + 
search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds std::vector<TermMatchInfo> matched_terms_stats; @@ -887,8 +969,7 @@ TEST_F(QueryProcessorTest, OrTwoTermExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; EXPECT_THAT( @@ -909,8 +990,10 @@ TEST_F(QueryProcessorTest, OrTwoTermExactMatch) { search_spec.set_query("hello OR world"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds std::vector<TermMatchInfo> matched_terms_stats; @@ -976,8 +1059,7 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::PREFIX; EXPECT_THAT( @@ -998,8 +1080,10 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixMatch) { search_spec.set_query("he OR wo"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds std::vector<TermMatchInfo> matched_terms_stats; @@ -1064,8 +1148,7 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; EXPECT_THAT(AddTokenToIndex(document_id1, section_id, TermMatchType::EXACT_ONLY, "hello"), @@ -1085,8 +1168,10 @@ TEST_F(QueryProcessorTest, OrTwoTermPrefixAndExactMatch) { search_spec.set_query("hello OR wo"); search_spec.set_term_match_type(TermMatchType::PREFIX); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds std::vector<TermMatchInfo> matched_terms_stats; @@ -1150,8 +1235,7 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { // Populate the index SectionId section_id = 0; SectionIdMask section_id_mask = 1U << section_id; - std::array<Hit::TermFrequency, kMaxSectionId> term_frequencies{ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + std::array<Hit::TermFrequency, kTotalNumSections> term_frequencies{1}; TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; // Document 1 has content "animal puppy 
dog" @@ -1163,6 +1247,7 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { IsOk()); EXPECT_THAT(AddTokenToIndex(document_id1, section_id, term_match_type, "dog"), IsOk()); + index_->Merge(); // Document 2 has content "animal kitten cat" EXPECT_THAT( @@ -1188,8 +1273,10 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { search_spec.set_query("puppy OR kitten dog"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Only Document 1 matches since it has puppy AND dog std::vector<TermMatchInfo> matched_terms_stats; @@ -1222,8 +1309,10 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { search_spec.set_query("animal puppy OR kitten"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Both Document 1 and 2 match since Document 1 has animal AND puppy, and // Document 2 has animal AND kitten @@ -1275,8 +1364,10 @@ TEST_F(QueryProcessorTest, CombinedAndOrTerms) { search_spec.set_query("kitten foo OR bar OR cat"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Only Document 2 matches since it has both kitten and cat std::vector<TermMatchInfo> matched_terms_stats; @@ -1365,12 +1456,14 @@ TEST_F(QueryProcessorTest, OneGroup) { search_spec.set_query("puppy OR (kitten foo)"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id1); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -1441,14 +1534,16 @@ TEST_F(QueryProcessorTest, TwoGroups) { search_spec.set_query("(puppy dog) OR (kitten cat)"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo1(document_id1); - expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0); DocHitInfo expectedDocHitInfo2(document_id2); - expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + 
expectedDocHitInfo2.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -1519,12 +1614,14 @@ TEST_F(QueryProcessorTest, ManyLevelNestedGrouping) { search_spec.set_query("puppy OR ((((kitten foo))))"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id1); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -1594,14 +1691,16 @@ TEST_F(QueryProcessorTest, OneLevelNestedGrouping) { search_spec.set_query("puppy OR (kitten(cat))"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo1(document_id1); - expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0); DocHitInfo expectedDocHitInfo2(document_id2); - expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -1663,8 +1762,10 @@ TEST_F(QueryProcessorTest, ExcludeTerm) { search_spec.set_query("-hello"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // We don't have the section mask to indicate what section "world" // came from.
It doesn't matter which section it was in since the query doesn't @@ -1727,8 +1828,10 @@ TEST_F(QueryProcessorTest, ExcludeNonexistentTerm) { search_spec.set_query("-foo"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); // Descending order of valid DocumentIds EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), @@ -1799,8 +1902,10 @@ TEST_F(QueryProcessorTest, ExcludeAnd) { search_spec.set_query("-dog -cat"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // The query is interpreted as "exclude all documents that have animal, // and exclude all documents that have cat". Since both documents contain @@ -1815,8 +1920,10 @@ TEST_F(QueryProcessorTest, ExcludeAnd) { search_spec.set_query("-animal cat"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // The query is interpreted as "exclude all documents that have animal, // and include all documents that have cat". Since both documents contain @@ -1889,8 +1996,10 @@ TEST_F(QueryProcessorTest, ExcludeOr) { search_spec.set_query("-animal OR -cat"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // We don't have a section mask indicating which sections in this document // matched the query since it's not based on section-term matching. 
It's @@ -1906,24 +2015,141 @@ TEST_F(QueryProcessorTest, ExcludeOr) { search_spec.set_query("animal OR -cat"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo1(document_id1); - expectedDocHitInfo1.UpdateSection(/*section_id=*/0, - /*hit_term_frequency=*/1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0); DocHitInfo expectedDocHitInfo2(document_id2); - expectedDocHitInfo2.UpdateSection(/*section_id=*/0, - /*hit_term_frequency=*/1); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo2, expectedDocHitInfo1)); EXPECT_THAT(results.query_terms, SizeIs(1)); EXPECT_THAT(results.query_terms[""], UnorderedElementsAre("animal")); - EXPECT_THAT(results.query_term_iterators, SizeIs(1)); } } +TEST_F(QueryProcessorTest, WithoutTermFrequency) { + // Create the schema and document store + SchemaProto schema = SchemaBuilder() + .AddType(SchemaTypeConfigBuilder().SetType("email")) + .Build(); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); + ASSERT_THAT(schema_store_->SetSchema(schema), IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + + // These documents don't actually match to the tokens in the index. We're + // just inserting the documents so that the DocHitInfoIterators will see + // that the document exists and not filter out the DocumentId as deleted. + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id1, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "1") + .SetSchema("email") + .Build())); + ICING_ASSERT_OK_AND_ASSIGN(DocumentId document_id2, + document_store_->Put(DocumentBuilder() + .SetKey("namespace", "2") + .SetSchema("email") + .Build())); + + // Populate the index + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + TermMatchType::Code term_match_type = TermMatchType::EXACT_ONLY; + + // Document 1 has content "animal puppy dog", which is added to the main + // index. + EXPECT_THAT( + AddTokenToIndex(document_id1, section_id, term_match_type, "animal"), + IsOk()); + EXPECT_THAT( + AddTokenToIndex(document_id1, section_id, term_match_type, "puppy"), + IsOk()); + EXPECT_THAT(AddTokenToIndex(document_id1, section_id, term_match_type, "dog"), + IsOk()); + ASSERT_THAT(index_->Merge(), IsOk()); + + // Document 2 has content "animal kitten cat", which is added to the lite + // index. 
+ EXPECT_THAT( + AddTokenToIndex(document_id2, section_id, term_match_type, "animal"), + IsOk()); + EXPECT_THAT( + AddTokenToIndex(document_id2, section_id, term_match_type, "kitten"), + IsOk()); + EXPECT_THAT(AddTokenToIndex(document_id2, section_id, term_match_type, "cat"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<QueryProcessor> query_processor, + QueryProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get(), document_store_.get(), + schema_store_.get())); + + // OR gets precedence over AND, this is parsed as (animal AND (puppy OR + // kitten)) + SearchSpecProto search_spec; + search_spec.set_query("animal puppy OR kitten"); + search_spec.set_term_match_type(term_match_type); + + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); + // Since need_hit_term_frequency is false, the expected term frequencies + // should all be 0. + Hit::TermFrequencyArray exp_term_frequencies{0}; + + // Descending order of valid DocumentIds + // The first Document to match (Document 2) matches on 'animal' AND 'kitten' + std::vector<TermMatchInfo> matched_terms_stats; + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id2); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(exp_term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "kitten"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(exp_term_frequencies)); + + // The second Document to match (Document 1) matches on 'animal' AND 'puppy' + matched_terms_stats.clear(); + ASSERT_THAT(results.root_iterator->Advance(), IsOk()); + EXPECT_EQ(results.root_iterator->doc_hit_info().document_id(), document_id1); + EXPECT_EQ(results.root_iterator->doc_hit_info().hit_section_ids_mask(), + section_id_mask); + results.root_iterator->PopulateMatchedTermsStats(&matched_terms_stats); + ASSERT_THAT(matched_terms_stats, SizeIs(2)); // 2 terms + EXPECT_EQ(matched_terms_stats.at(0).term, "animal"); + EXPECT_EQ(matched_terms_stats.at(0).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(0).term_frequencies, + ElementsAreArray(exp_term_frequencies)); + EXPECT_EQ(matched_terms_stats.at(1).term, "puppy"); + EXPECT_EQ(matched_terms_stats.at(1).section_ids_mask, section_id_mask); + EXPECT_THAT(matched_terms_stats.at(1).term_frequencies, + ElementsAreArray(exp_term_frequencies)); + + // This should be empty because ranking_strategy != RELEVANCE_SCORE + EXPECT_THAT(results.query_term_iterators, IsEmpty()); +} + TEST_F(QueryProcessorTest, DeletedFilter) { // Create the schema and document store SchemaProto schema = SchemaBuilder() @@ -1985,12 +2211,14 @@ TEST_F(QueryProcessorTest, DeletedFilter) { search_spec.set_query("animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults 
results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id2); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2059,12 +2287,14 @@ TEST_F(QueryProcessorTest, NamespaceFilter) { search_spec.set_term_match_type(term_match_type); search_spec.add_namespace_filters("namespace1"); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id1); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2131,12 +2361,14 @@ TEST_F(QueryProcessorTest, SchemaTypeFilter) { search_spec.set_term_match_type(term_match_type); search_spec.add_schema_type_filters("email"); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id1); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2197,12 +2429,14 @@ TEST_F(QueryProcessorTest, SectionFilterForOneDocument) { search_spec.set_query("subject:animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Descending order of valid DocumentIds DocHitInfo expectedDocHitInfo(document_id); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2289,15 +2523,17 @@ TEST_F(QueryProcessorTest, SectionFilterAcrossSchemaTypes) { search_spec.set_query("foo:animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Ordered by descending DocumentId, so message comes first since it was // inserted last DocHitInfo expectedDocHitInfo1(message_document_id); - expectedDocHitInfo1.UpdateSection(/*section_id=*/0, 
/*hit_term_frequency=*/1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0); DocHitInfo expectedDocHitInfo2(email_document_id); - expectedDocHitInfo2.UpdateSection(/*section_id=*/1, /*hit_term_frequency=*/1); + expectedDocHitInfo2.UpdateSection(/*section_id=*/1); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo1, expectedDocHitInfo2)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2374,13 +2610,15 @@ TEST_F(QueryProcessorTest, SectionFilterWithinSchemaType) { search_spec.add_schema_type_filters("email"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Shouldn't include the message document since we're only looking at email // types DocHitInfo expectedDocHitInfo(email_document_id); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2459,13 +2697,15 @@ TEST_F(QueryProcessorTest, SectionFilterRespectsDifferentSectionIds) { search_spec.set_query("foo:animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it // doesn't match to the name of the section filter DocHitInfo expectedDocHitInfo(email_document_id); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); EXPECT_THAT(results.query_terms, SizeIs(1)); @@ -2520,8 +2760,10 @@ TEST_F(QueryProcessorTest, NonexistentSectionFilterReturnsEmptyResults) { search_spec.set_query("nonexistent:animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it // doesn't match to the name of the section filter @@ -2587,8 +2829,10 @@ TEST_F(QueryProcessorTest, UnindexedSectionFilterReturnsEmptyResults) { search_spec.set_query("foo:animal"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Even though the section id is the same, we should be able to tell that it // doesn't match to the name of the section filter @@ -2668,15 +2912,17 @@ TEST_F(QueryProcessorTest, SectionFilterTermAndUnrestrictedTerm) { search_spec.set_query("cat OR foo:animal"); 
search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch( + search_spec, ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE)); // Ordered by descending DocumentId, so message comes first since it was // inserted last DocHitInfo expectedDocHitInfo1(message_document_id); - expectedDocHitInfo1.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo1.UpdateSection(/*section_id=*/0); DocHitInfo expectedDocHitInfo2(email_document_id); - expectedDocHitInfo2.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo2.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo1, expectedDocHitInfo2)); EXPECT_THAT(results.query_terms, SizeIs(2)); @@ -2735,11 +2981,13 @@ TEST_F(QueryProcessorTest, DocumentBeforeTtlNotFilteredOut) { search_spec.set_query("hello"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); DocHitInfo expectedDocHitInfo(document_id); - expectedDocHitInfo.UpdateSection(/*section_id=*/0, /*hit_term_frequency=*/1); + expectedDocHitInfo.UpdateSection(/*section_id=*/0); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), ElementsAre(expectedDocHitInfo)); } @@ -2794,8 +3042,10 @@ TEST_F(QueryProcessorTest, DocumentPastTtlFilteredOut) { search_spec.set_query("hello"); search_spec.set_term_match_type(term_match_type); - ICING_ASSERT_OK_AND_ASSIGN(QueryProcessor::QueryResults results, - query_processor->ParseSearch(search_spec)); + ICING_ASSERT_OK_AND_ASSIGN( + QueryProcessor::QueryResults results, + query_processor->ParseSearch(search_spec, + ScoringSpecProto::RankingStrategy::NONE)); EXPECT_THAT(GetDocHitInfos(results.root_iterator.get()), IsEmpty()); } diff --git a/icing/query/query-terms.h b/icing/query/query-terms.h index 1c5ce02..b186218 100644 --- a/icing/query/query-terms.h +++ b/icing/query/query-terms.h @@ -26,7 +26,7 @@ namespace lib { // A map from section names to sets of terms restricted to those sections. // Query terms that are not restricted are found at the entry with key "". 
using SectionRestrictQueryTermsMap = - std::unordered_map<std::string_view, std::unordered_set<std::string>>; + std::unordered_map<std::string, std::unordered_set<std::string>>; } // namespace lib } // namespace icing diff --git a/icing/result/projection-tree.cc b/icing/result/projection-tree.cc index 67617a3..3347439 100644 --- a/icing/result/projection-tree.cc +++ b/icing/result/projection-tree.cc @@ -41,7 +41,7 @@ ProjectionTree::Node* ProjectionTree::AddChildNode( if (itr != current_children->end()) { return &(*itr); } - current_children->push_back(ProjectionTree::Node(property_name)); + current_children->push_back(ProjectionTree::Node(std::string(property_name))); return &current_children->back(); } diff --git a/icing/result/projection-tree.h b/icing/result/projection-tree.h index 8e38aaf..5916fe6 100644 --- a/icing/result/projection-tree.h +++ b/icing/result/projection-tree.h @@ -28,10 +28,9 @@ class ProjectionTree { static constexpr std::string_view kSchemaTypeWildcard = "*"; struct Node { - explicit Node(std::string_view name = "") : name(name) {} + explicit Node(std::string name = "") : name(std::move(name)) {} - // TODO: change string_view to string - std::string_view name; + std::string name; std::vector<Node> children; bool operator==(const Node& other) const { diff --git a/icing/result/result-retriever-v2_projection_test.cc b/icing/result/result-retriever-v2_projection_test.cc index bdd1715..cb0de0b 100644 --- a/icing/result/result-retriever-v2_projection_test.cc +++ b/icing/result/result-retriever-v2_projection_test.cc @@ -165,7 +165,7 @@ class ResultRetrieverV2ProjectionTest : public testing::Test { SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { SectionIdMask mask = 0; for (SectionId section_id : section_ids) { - mask |= (1u << section_id); + mask |= (UINT64_C(1) << section_id); } return mask; } diff --git a/icing/result/result-retriever-v2_snippet_test.cc b/icing/result/result-retriever-v2_snippet_test.cc index afb31cf..0643e9b 100644 --- a/icing/result/result-retriever-v2_snippet_test.cc +++ b/icing/result/result-retriever-v2_snippet_test.cc @@ -188,7 +188,7 @@ DocumentProto CreateDocument(int id) { SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { SectionIdMask mask = 0; for (SectionId section_id : section_ids) { - mask |= (1u << section_id); + mask |= (UINT64_C(1) << section_id); } return mask; } diff --git a/icing/result/result-retriever-v2_test.cc b/icing/result/result-retriever-v2_test.cc index 0998754..1ac56ff 100644 --- a/icing/result/result-retriever-v2_test.cc +++ b/icing/result/result-retriever-v2_test.cc @@ -196,7 +196,7 @@ DocumentProto CreateDocument(int id) { SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { SectionIdMask mask = 0; for (SectionId section_id : section_ids) { - mask |= (1u << section_id); + mask |= (UINT64_C(1) << section_id); } return mask; } diff --git a/icing/result/result-retriever_test.cc b/icing/result/result-retriever_test.cc index 0d812e4..1b2b359 100644 --- a/icing/result/result-retriever_test.cc +++ b/icing/result/result-retriever_test.cc @@ -177,7 +177,7 @@ DocumentProto CreateDocument(int id) { SectionIdMask CreateSectionIdMask(const std::vector<SectionId>& section_ids) { SectionIdMask mask = 0; for (SectionId section_id : section_ids) { - mask |= (1u << section_id); + mask |= (UINT64_C(1) << section_id); } return mask; } diff --git a/icing/result/result-state-manager.h b/icing/result/result-state-manager.h index 0684864..e2bc797 100644 ---
a/icing/result/result-state-manager.h +++ b/icing/result/result-state-manager.h @@ -102,6 +102,8 @@ class ResultStateManager { // Invalidates all result states / tokens currently in ResultStateManager. void InvalidateAllResultStates() ICING_LOCKS_EXCLUDED(mutex_); + int num_total_hits() const { return num_total_hits_; } + private: absl_ports::shared_mutex mutex_; diff --git a/icing/result/result-state-manager_thread-safety_test.cc b/icing/result/result-state-manager_thread-safety_test.cc new file mode 100644 index 0000000..523f84a --- /dev/null +++ b/icing/result/result-state-manager_thread-safety_test.cc @@ -0,0 +1,451 @@ +// Copyright (C) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <algorithm> +#include <optional> +#include <thread> // NOLINT + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/document-builder.h" +#include "icing/file/filesystem.h" +#include "icing/portable/equals-proto.h" +#include "icing/result/page-result.h" +#include "icing/result/result-retriever-v2.h" +#include "icing/result/result-state-manager.h" +#include "icing/schema/schema-store.h" +#include "icing/scoring/priority-queue-scored-document-hits-ranker.h" +#include "icing/scoring/scored-document-hits-ranker.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/icu-data-file-helper.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "icing/transform/normalizer.h" +#include "icing/util/clock.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { +namespace { + +using ::testing::Eq; +using ::testing::Ge; +using ::testing::Not; +using ::testing::SizeIs; +using PageResultInfo = std::pair<uint64_t, PageResult>; + +ScoringSpecProto CreateScoringSpec() { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); + return scoring_spec; +} + +ResultSpecProto CreateResultSpec(int num_per_page) { + ResultSpecProto result_spec; + result_spec.set_num_per_page(num_per_page); + return result_spec; +} + +DocumentProto CreateDocument(int document_id) { + return DocumentBuilder() + .SetNamespace("namespace") + .SetUri(std::to_string(document_id)) + .SetSchema("Document") + .SetCreationTimestampMs(1574365086666 + document_id) + .SetScore(document_id) + .Build(); +} + +class ResultStateManagerThreadSafetyTest : public testing::Test { + protected: + ResultStateManagerThreadSafetyTest() + : test_dir_(GetTestTempDir() + "/icing") { + filesystem_.CreateDirectoryRecursively(test_dir_.c_str()); + } + + void SetUp() override { + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + clock_ = std::make_unique<FakeClock>(); + + language_segmenter_factory::SegmenterOptions options(ULOC_US); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(std::move(options))); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, clock_.get())); + SchemaProto schema; + schema.add_types()->set_schema_type("Document"); + ICING_ASSERT_OK(schema_store_->SetSchema(std::move(schema))); + + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/10000)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult result, + DocumentStore::Create(&filesystem_, test_dir_, clock_.get(), + schema_store_.get())); + document_store_ = std::move(result.document_store); + + ICING_ASSERT_OK_AND_ASSIGN( + result_retriever_, ResultRetrieverV2::Create( + document_store_.get(), schema_store_.get(), + language_segmenter_.get(), normalizer_.get())); + } + + void TearDown() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + clock_.reset(); + } + + Filesystem filesystem_; + const std::string test_dir_; + std::unique_ptr<FakeClock> clock_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<ResultRetrieverV2> result_retriever_; +}; + +TEST_F(ResultStateManagerThreadSafetyTest, + RequestSameResultStateSimultaneously) { + // Create several threads to send GetNextPage requests with the same + // ResultState. + // + // This test verifies the usage of the per-instance ResultState lock. Only + // one thread is allowed to access a ResultState at a time, so there should + // be no crash and the result documents in a single page should be + // continuous (i.e. no interleaving). + + // Prepare documents. + constexpr int kNumDocuments = 10000; + std::vector<ScoredDocumentHit> scored_document_hits; + for (int i = 0; i < kNumDocuments; ++i) { + // Put a document with id and score = i. + ICING_ASSERT_OK(document_store_->Put(CreateDocument(/*document_id=*/i))); + scored_document_hits.push_back( + ScoredDocumentHit(/*document_id=*/i, kSectionIdMaskNone, /*score=*/i)); + } + + constexpr int kNumPerPage = 100; + ResultStateManager result_state_manager(/*max_total_hits=*/kNumDocuments, + *document_store_, clock_.get()); + + // Retrieve the first page. + // Documents are ordered by score *ascending*, so the first page should + // contain documents with scores [0, 1, 2, ..., kNumPerPage - 1]. + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_, + *result_retriever_)); + ASSERT_THAT(page_result_info1.second.results, SizeIs(kNumPerPage)); + for (int i = 0; i < kNumPerPage; ++i) { + ASSERT_THAT(page_result_info1.second.results[i].score(), Eq(i)); + } + + uint64_t next_page_token = page_result_info1.first; + ASSERT_THAT(next_page_token, Not(Eq(kInvalidNextPageToken))); + + // Create kNumThreads threads to call GetNextPage() with the same token at the + // same time. Each thread should get a valid result. + // Use page_results to store the result.
+ constexpr int kNumThreads = 50; + std::vector<std::optional<PageResultInfo>> page_results(kNumThreads); + auto callable = [&](int thread_id) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), + normalizer_.get())); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.GetNextPage(next_page_token, *result_retriever)); + page_results[thread_id] = + std::make_optional<PageResultInfo>(std::move(page_result_info)); + }; + + // Spawn threads for GetNextPage(). + std::vector<std::thread> thread_objs; + for (int i = 0; i < kNumThreads; ++i) { + thread_objs.emplace_back(callable, /*thread_id=*/i); + } + + // Join threads. + for (int i = 0; i < kNumThreads; ++i) { + thread_objs[i].join(); + EXPECT_THAT(page_results[i], Not(Eq(std::nullopt))); + EXPECT_THAT(page_results[i]->second.results, SizeIs(kNumPerPage)); + } + + // Since we have a per-instance lock for ResultState, only one thread is + // allowed to access a ResultState at a time. Therefore, every thread should + // get continuous scores instead of interleaved scores, regardless of the + // execution order. In other words, within a particular page the scores of + // all results should be ordered as: [N, N+1, N+2, N+3, ...] where N is + // dependent on the execution order. Also there should be no crash. + std::vector<int> first_doc_scores; + for (const auto& page_result_info : page_results) { + first_doc_scores.push_back(page_result_info->second.results[0].score()); + for (int i = 1; i < kNumPerPage; ++i) { + EXPECT_THAT(page_result_info->second.results[i].score(), + Eq(page_result_info->second.results[i - 1].score() + 1)); + } + } + + // Verify all first doc scores of page results are correct. Should be + // kNumPerPage * 1, kNumPerPage * 2, ..., etc. + // Note: the first score of the first page retrieved via GetNextPage should be + // kNumPerPage because the *actual* first page with first score = 0 was + // retrieved during CacheAndRetrieveFirstPage. + std::sort(first_doc_scores.begin(), first_doc_scores.end()); + for (int i = 0; i < kNumThreads; ++i) { + EXPECT_THAT(first_doc_scores[i], Eq(kNumPerPage * (i + 1))); + } +} + +TEST_F(ResultStateManagerThreadSafetyTest, InvalidateResultStateWhileUsing) { + // Create several threads to send GetNextPage requests with the same + // ResultState and another single thread to invalidate this ResultState. + // + // This test verifies the usage of std::shared_ptr. Even after invalidating + // the original copy of std::shared_ptr in the cache, the ResultState instance + // should still be valid and no crash should occur in threads that are still + // holding a copy of std::shared_ptr pointing to the same ResultState + // instance. + + // Prepare documents. + constexpr int kNumDocuments = 10000; + std::vector<ScoredDocumentHit> scored_document_hits; + for (int i = 0; i < kNumDocuments; ++i) { + // Put a document with id and score = i. + ICING_ASSERT_OK(document_store_->Put(CreateDocument(/*document_id=*/i))); + scored_document_hits.push_back( + ScoredDocumentHit(/*document_id=*/i, kSectionIdMaskNone, /*score=*/i)); + } + + constexpr int kNumPerPage = 100; + ResultStateManager result_state_manager(/*max_total_hits=*/kNumDocuments, + *document_store_, clock_.get()); + + // Retrieve the first page. + // Documents are ordered by score *ascending*, so the first page should + // contain documents with scores [0, 1, 2, ..., kNumPerPage - 1].
+ ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits), /*is_descending=*/false), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(kNumPerPage), *document_store_, + *result_retriever_)); + ASSERT_THAT(page_result_info1.second.results, SizeIs(kNumPerPage)); + for (int i = 0; i < kNumPerPage; ++i) { + ASSERT_THAT(page_result_info1.second.results[i].score(), Eq(i)); + } + + uint64_t next_page_token = page_result_info1.first; + ASSERT_THAT(next_page_token, Not(Eq(kInvalidNextPageToken))); + + // Create kNumThreads threads to call GetNextPage() with the same token at the + // same time. The ResultState might have been invalidated, so it is normal to + // get NOT_FOUND error. + // Use page_results to store the result. + constexpr int kNumThreads = 50; + std::vector<std::optional<PageResultInfo>> page_results(kNumThreads); + auto callable = [&](int thread_id) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), + normalizer_.get())); + + libtextclassifier3::StatusOr<PageResultInfo> page_result_info_or = + result_state_manager.GetNextPage(next_page_token, *result_retriever); + if (page_result_info_or.ok()) { + page_results[thread_id] = std::make_optional<PageResultInfo>( + std::move(page_result_info_or).ValueOrDie()); + } else { + EXPECT_THAT(page_result_info_or, + StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); + } + }; + + // Spawn threads for GetNextPage(). + std::vector<std::thread> thread_objs; + for (int i = 0; i < kNumThreads; ++i) { + thread_objs.emplace_back(callable, /*thread_id=*/i); + } + + // Spawn another single thread to invalidate the ResultState. + std::thread invalidating_thread([&]() -> void { + result_state_manager.InvalidateResultState(next_page_token); + }); + + // Join threads. + for (int i = 0; i < kNumThreads; ++i) { + thread_objs[i].join(); + if (page_results[i] != std::nullopt) { + EXPECT_THAT(page_results[i]->second.results, SizeIs(kNumPerPage)); + } + } + invalidating_thread.join(); + + // Threads fetching ResultState before invalidation will get normal results, + // while others will get NOT_FOUND error. + std::vector<int> first_doc_scores; + for (const auto& page_result_info : page_results) { + if (page_result_info == std::nullopt) { + continue; + } + + first_doc_scores.push_back(page_result_info->second.results[0].score()); + for (int i = 1; i < kNumPerPage; ++i) { + EXPECT_THAT(page_result_info->second.results[i].score(), + Eq(page_result_info->second.results[i - 1].score() + 1)); + } + } + + // Verify all first doc scores of page results are correct. Should be + // kNumPerPage * 1, kNumPerPage * 2, ..., etc. + std::sort(first_doc_scores.begin(), first_doc_scores.end()); + for (int i = 0; i < first_doc_scores.size(); ++i) { + EXPECT_THAT(first_doc_scores[i], Eq(kNumPerPage * (i + 1))); + } + + // Verify num_total_hits should be decremented correctly. + EXPECT_THAT(result_state_manager.num_total_hits(), Eq(0)); +} + +TEST_F(ResultStateManagerThreadSafetyTest, MultipleResultStates) { + // Create several threads to send GetNextPage requests with different + // ResultStates. + // + // This test verifies each ResultState should work independently and correctly + // with each thread. 
Also it verifies there should be no race condition for + // num_total_hits, which will be incremented/decremented by multiple threads. + + // Prepare documents. + constexpr int kNumDocuments = 2000; + std::vector<ScoredDocumentHit> scored_document_hits; + for (int i = 0; i < kNumDocuments; ++i) { + // Put a document with id and score = i. + ICING_ASSERT_OK(document_store_->Put(CreateDocument(/*document_id=*/i))); + scored_document_hits.push_back( + ScoredDocumentHit(/*document_id=*/i, kSectionIdMaskNone, /*score=*/i)); + } + + constexpr int kNumThreads = 50; + constexpr int kNumPerPage = 30; + ResultStateManager result_state_manager( + /*max_total_hits=*/kNumDocuments * kNumThreads, *document_store_, + clock_.get()); + + // Create kNumThreads threads to: + // - Call CacheAndRetrieveFirstPage() once to create its own ResultState. + // - Call GetNextPage() on its own ResultState for thread_id times. + // + // Each thread will get (thread_id + 1) pages, i.e. kNumPerPage * + // (thread_id + 1) docs. + ASSERT_THAT(kNumDocuments, Ge(kNumPerPage * kNumThreads)); + auto callable = [&](int thread_id) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ResultRetrieverV2> result_retriever, + ResultRetrieverV2::Create(document_store_.get(), schema_store_.get(), + language_segmenter_.get(), + normalizer_.get())); + + // Retrieve the first page. + // Documents are ordered by score *ascending*, so the first page should + // contain documents with scores [0, 1, 2, ..., kNumPerPage - 1]. + std::vector<ScoredDocumentHit> scored_document_hits_copy( + scored_document_hits); + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info1, + result_state_manager.CacheAndRetrieveFirstPage( + std::make_unique<PriorityQueueScoredDocumentHitsRanker>( + std::move(scored_document_hits_copy), /*is_descending=*/false), + /*query_terms=*/{}, SearchSpecProto::default_instance(), + CreateScoringSpec(), CreateResultSpec(kNumPerPage), + *document_store_, *result_retriever)); + EXPECT_THAT(page_result_info1.second.results, SizeIs(kNumPerPage)); + for (int i = 0; i < kNumPerPage; ++i) { + EXPECT_THAT(page_result_info1.second.results[i].score(), Eq(i)); + } + + uint64_t next_page_token = page_result_info1.first; + ASSERT_THAT(next_page_token, Not(Eq(kInvalidNextPageToken))); + + // Retrieve some of the subsequent pages. We use thread_id as how many + // subsequent pages should be retrieved (how many times GetNextPage should + // be called) for each thread in order to: + // - Vary the number of pages that we're retrieving in each thread. + // - Still make the total number of hits remaining (num_total_hits) a + // predictable number. + // Then, including the first page (retrieved by CacheAndRetrieveFirstPage), + // each thread should retrieve 1, 2, 3, ..., kNumThreads pages. + int num_subsequent_pages_to_retrieve = thread_id; + for (int i = 0; i < num_subsequent_pages_to_retrieve; ++i) { + ICING_ASSERT_OK_AND_ASSIGN( + PageResultInfo page_result_info, + result_state_manager.GetNextPage(next_page_token, *result_retriever)); + EXPECT_THAT(page_result_info.second.results, SizeIs(kNumPerPage)); + for (int j = 0; j < kNumPerPage; ++j) { + EXPECT_THAT(page_result_info.second.results[j].score(), + Eq(kNumPerPage * (i + 1) + j)); + } + } + }; + + // Spawn threads. + std::vector<std::thread> thread_objs; + for (int i = 0; i < kNumThreads; ++i) { + thread_objs.emplace_back(callable, /*thread_id=*/i); + } + + // Join threads. 
+ for (int i = 0; i < kNumThreads; ++i) { + thread_objs[i].join(); + } + + // There will be kNumThreads * kNumDocuments ScoredDocumentHits being created + // in the beginning, and kNumPerPage * (1 + 2 + ... + kNumThreads) docs should + // be returned after retrieval, since each thread should retrieve 1, 2, 3, + // ..., kNumThreads pages. Thus, all retrieved ScoredDocumentHits should be + // removed from the cache and num_total_hits should be decremented correctly. + int expected_remaining_hits = + kNumThreads * kNumDocuments - + kNumPerPage * (kNumThreads * (kNumThreads + 1) / 2); + EXPECT_THAT(result_state_manager.num_total_hits(), + Eq(expected_remaining_hits)); +} + +} // namespace +} // namespace lib +} // namespace icing diff --git a/icing/result/snippet-retriever.cc b/icing/result/snippet-retriever.cc index 2391900..ff9d6c5 100644 --- a/icing/result/snippet-retriever.cc +++ b/icing/result/snippet-retriever.cc @@ -613,9 +613,9 @@ SnippetProto SnippetRetriever::RetrieveSnippet( const std::unordered_set<std::string>& unrestricted_set = (itr != query_terms.end()) ? itr->second : empty_set; while (section_id_mask != kSectionIdMaskNone) { - SectionId section_id = __builtin_ctz(section_id_mask); + SectionId section_id = __builtin_ctzll(section_id_mask); // Remove this section from the mask. - section_id_mask &= ~(1u << section_id); + section_id_mask &= ~(UINT64_C(1) << section_id); MatchOptions match_options = {snippet_spec}; match_options.max_matches_remaining = diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index 653f34f..7af7351 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -563,7 +563,7 @@ SchemaStoreStorageInfoProto SchemaStore::GetStorageInfo() const { continue; } total_sections += sections_list_or.ValueOrDie()->size(); - if (sections_list_or.ValueOrDie()->size() == kMaxSectionId + 1) { + if (sections_list_or.ValueOrDie()->size() == kTotalNumSections) { ++num_types_sections_exhausted; } } diff --git a/icing/schema/schema-store_test.cc b/icing/schema/schema-store_test.cc index ffd1292..ae84358 100644 --- a/icing/schema/schema-store_test.cc +++ b/icing/schema/schema-store_test.cc @@ -1099,36 +1099,24 @@ TEST_F(SchemaStoreTest, SchemaStoreStorageInfoProto) { SchemaStore::Create(&filesystem_, schema_store_dir_, &fake_clock_)); // Create a schema with two types: one simple type and one type that uses all - // 16 sections. + // 64 sections. 
PropertyConfigProto prop = PropertyConfigBuilder() .SetName("subject") .SetDataTypeString(MATCH_EXACT, TOKENIZER_PLAIN) .SetCardinality(CARDINALITY_OPTIONAL) .Build(); + SchemaTypeConfigBuilder full_sections_type_builder = + SchemaTypeConfigBuilder().SetType("fullSectionsType"); + for (int i = 0; i < 64; ++i) { + full_sections_type_builder.AddProperty( + PropertyConfigBuilder(prop).SetName("prop" + std::to_string(i))); + } SchemaProto schema = SchemaBuilder() .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty( PropertyConfigBuilder(prop))) - .AddType( - SchemaTypeConfigBuilder() - .SetType("fullSectionsType") - .AddProperty(PropertyConfigBuilder(prop).SetName("prop0")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop1")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop2")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop3")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop4")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop5")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop6")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop7")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop8")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop9")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop10")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop11")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop12")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop13")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop14")) - .AddProperty(PropertyConfigBuilder(prop).SetName("prop15"))) + .AddType(full_sections_type_builder) .Build(); SchemaStore::SetSchemaResult result; @@ -1141,7 +1129,7 @@ TEST_F(SchemaStoreTest, SchemaStoreStorageInfoProto) { SchemaStoreStorageInfoProto storage_info = schema_store->GetStorageInfo(); EXPECT_THAT(storage_info.schema_store_size(), Ge(0)); EXPECT_THAT(storage_info.num_schema_types(), Eq(2)); - EXPECT_THAT(storage_info.num_total_sections(), Eq(17)); + EXPECT_THAT(storage_info.num_total_sections(), Eq(65)); EXPECT_THAT(storage_info.num_schema_types_sections_exhausted(), Eq(1)); } diff --git a/icing/schema/section.h b/icing/schema/section.h index 8b2ba55..34c8c58 100644 --- a/icing/schema/section.h +++ b/icing/schema/section.h @@ -28,17 +28,17 @@ namespace icing { namespace lib { using SectionId = int8_t; -// 4 bits for 16 values. NOTE: Increasing this value means that SectionIdMask -// must increase from an int16_t to an int32_t -inline constexpr int kSectionIdBits = 4; -inline constexpr SectionId kInvalidSectionId = (1 << kSectionIdBits); -inline constexpr SectionId kMaxSectionId = kInvalidSectionId - 1; +// 6 bits for 64 values. 
+inline constexpr int kSectionIdBits = 6; +inline constexpr SectionId kTotalNumSections = (1 << kSectionIdBits); +inline constexpr SectionId kInvalidSectionId = kTotalNumSections; +inline constexpr SectionId kMaxSectionId = kTotalNumSections - 1; inline constexpr SectionId kMinSectionId = 0; constexpr bool IsSectionIdValid(SectionId section_id) { return section_id >= kMinSectionId && section_id <= kMaxSectionId; } -using SectionIdMask = int16_t; +using SectionIdMask = int64_t; inline constexpr SectionIdMask kSectionIdMaskAll = ~SectionIdMask{0}; inline constexpr SectionIdMask kSectionIdMaskNone = SectionIdMask{0}; diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 28ee2ba..f169cda 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -115,8 +115,8 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, score += idf_weight * normalized_tf; } - ICING_VLOG(1) << "BM25F: corpus_id:" << data.corpus_id() << " docid:" - << hit_info.document_id() << " score:" << score; + ICING_VLOG(1) << "BM25F: corpus_id:" << data.corpus_id() + << " docid:" << hit_info.document_id() << " score:" << score; return score; } @@ -152,8 +152,8 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, float idf = nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); - ICING_VLOG(1) << "corpus_id:" << corpus_id << " term:" - << term << " N:" << num_docs << "nqi:" << nqi << " idf:" << idf; + ICING_VLOG(1) << "corpus_id:" << corpus_id << " term:" << term + << " N:" << num_docs << "nqi:" << nqi << " idf:" << idf; return idf; } @@ -200,9 +200,10 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); - ICING_VLOG(1) << "corpus_id:" << data.corpus_id() << " docid:" - << hit_info.document_id() << " dl:" << dl << " avgdl:" << avgdl << " f_q:" - << f_q << " norm_tf:" << normalized_tf; + ICING_VLOG(1) << "corpus_id:" << data.corpus_id() + << " docid:" << hit_info.document_id() << " dl:" << dl + << " avgdl:" << avgdl << " f_q:" << f_q + << " norm_tf:" << normalized_tf; return normalized_tf; } @@ -214,8 +215,8 @@ float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( SchemaTypeId schema_type_id = GetSchemaTypeId(document_id); while (sections != 0) { - SectionId section_id = __builtin_ctz(sections); - sections &= ~(1u << section_id); + SectionId section_id = __builtin_ctzll(sections); + sections &= ~(UINT64_C(1) << section_id); Hit::TermFrequency tf = term_match_info.term_frequencies[section_id]; double weighted_tf = tf * section_weights_->GetNormalizedSectionWeight( @@ -236,7 +237,7 @@ SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { // allocated document_ids, which shouldn't be possible since we're getting // this document_id from the posting lists. 
ICING_LOG(WARNING) << "No document filter data for document [" - << document_id << "]"; + << document_id << "]"; return kInvalidSchemaTypeId; } return filter_data_optional.value().schema_type_id(); diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index 44dda3c..bf12f96 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -59,7 +59,8 @@ // $ adb push blaze-bin/icing/scoring/score-and-rank_benchmark // /data/local/tmp/ // -// $ adb shell /data/local/tmp/score-and-rank_benchmark --benchmark_filter=all +// $ adb shell /data/local/tmp/score-and-rank_benchmark +// --benchmark_filter=all namespace icing { namespace lib { @@ -436,7 +437,7 @@ void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { SectionIdMask section_id_mask = 1U << section_id; // Puts documents into document store - std::vector<DocHitInfo> doc_hit_infos; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos; for (int i = 0; i < num_of_documents; i++) { ICING_ASSERT_OK_AND_ASSIGN( DocumentId document_id, @@ -444,7 +445,8 @@ void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { /*id=*/i, /*document_score=*/1, /*creation_timestamp_ms=*/1), /*num_tokens=*/10)); - DocHitInfo doc_hit = DocHitInfo(document_id, section_id_mask); + DocHitInfoTermFrequencyPair doc_hit = + DocHitInfo(document_id, section_id_mask); // Set five matches for term "foo" for each document hit. doc_hit.UpdateSection(section_id, /*hit_term_frequency=*/5); doc_hit_infos.push_back(doc_hit); diff --git a/icing/scoring/scored-document-hit.h b/icing/scoring/scored-document-hit.h index c2e51b8..079ba7e 100644 --- a/icing/scoring/scored-document-hit.h +++ b/icing/scoring/scored-document-hit.h @@ -53,8 +53,8 @@ class ScoredDocumentHit { double score_; } __attribute__((packed)); -static_assert(sizeof(ScoredDocumentHit) == 14, - "Size of ScoredDocHit should be 14"); +static_assert(sizeof(ScoredDocumentHit) == 20, + "Size of ScoredDocHit should be 20"); static_assert(icing_is_packed_pod<ScoredDocumentHit>::value, "go/icing-ubsan"); // A custom comparator for ScoredDocumentHit that determines which diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index b42ba31..9ca7dfd 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -329,19 +329,19 @@ TEST_F(ScoringProcessorTest, DocumentId document_id3, document_store()->Put(document3, /*num_tokens=*/50)); - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info3(document_id3); + DocHitInfoTermFrequencyPair doc_hit_info3 = DocHitInfo(document_id3); doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; + SectionIdMask section_id_mask = UINT64_C(1) << section_id; // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, - doc_hit_info3}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = { + doc_hit_info1, doc_hit_info2, doc_hit_info3}; // Creates a dummy 
DocHitInfoIterator with 3 results for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -398,19 +398,19 @@ TEST_F(ScoringProcessorTest, DocumentId document_id3, document_store()->Put(document3, /*num_tokens=*/10)); - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info3(document_id3); + DocHitInfoTermFrequencyPair doc_hit_info3 = DocHitInfo(document_id3); doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); SectionId section_id = 0; - SectionIdMask section_id_mask = 1U << section_id; + SectionIdMask section_id_mask = UINT64_C(1) << section_id; // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, - doc_hit_info3}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = { + doc_hit_info1, doc_hit_info2, doc_hit_info3}; // Creates a dummy DocHitInfoIterator with 3 results for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -466,13 +466,13 @@ TEST_F(ScoringProcessorTest, DocumentId document_id3, document_store()->Put(document3, /*num_tokens=*/10)); - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); // Document 1 contains the query term "foo" 5 times doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/5); - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); // Document 1 contains the query term "foo" 1 time doc_hit_info2.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); - DocHitInfo doc_hit_info3(document_id3); + DocHitInfoTermFrequencyPair doc_hit_info3 = DocHitInfo(document_id3); // Document 1 contains the query term "foo" 3 times doc_hit_info3.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/1); doc_hit_info3.UpdateSection(/*section_id*/ 1, /*hit_term_frequency=*/2); @@ -482,8 +482,8 @@ TEST_F(ScoringProcessorTest, SectionIdMask section_id_mask3 = 0b00000011; // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2, - doc_hit_info3}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = { + doc_hit_info1, doc_hit_info2, doc_hit_info3}; // Creates a dummy DocHitInfoIterator with 3 results for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -528,11 +528,11 @@ TEST_F(ScoringProcessorTest, document_store()->Put(document1, /*num_tokens=*/10)); // Document 1 contains the term "foo" 0 times in the "subject" property - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/0); // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1}; // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -580,16 +580,17 @@ TEST_F(ScoringProcessorTest, // Document 1 contains the term 
"foo" 1 time in the "body" property SectionId body_section_id = 0; - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); // Document 2 contains the term "foo" 1 time in the "subject" property SectionId subject_section_id = 1; - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1, + doc_hit_info2}; // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -651,16 +652,17 @@ TEST_F(ScoringProcessorTest, // Document 1 contains the term "foo" 1 time in the "body" property SectionId body_section_id = 0; - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); // Document 2 contains the term "foo" 1 time in the "subject" property SectionId subject_section_id = 1; - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1, + doc_hit_info2}; // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -719,11 +721,11 @@ TEST_F(ScoringProcessorTest, // Document 1 contains the term "foo" 1 time in the "body" property SectionId body_section_id = 0; - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1}; // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = @@ -809,16 +811,17 @@ TEST_F(ScoringProcessorTest, // Document 1 contains the term "foo" 1 time in the "body" property SectionId body_section_id = 0; - DocHitInfo doc_hit_info1(document_id1); + DocHitInfoTermFrequencyPair doc_hit_info1 = DocHitInfo(document_id1); doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); // Document 2 contains the term "foo" 1 time in the "subject" property SectionId subject_section_id = 1; - DocHitInfo doc_hit_info2(document_id2); + DocHitInfoTermFrequencyPair doc_hit_info2 = DocHitInfo(document_id2); doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); // Creates input doc_hit_infos and expected output scored_document_hits - std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos = {doc_hit_info1, + doc_hit_info2}; // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" 
std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = diff --git a/icing/store/document-id.h b/icing/store/document-id.h index cbe9959..3230819 100644 --- a/icing/store/document-id.h +++ b/icing/store/document-id.h @@ -23,9 +23,9 @@ namespace lib { // Id of a document using DocumentId = int32_t; -// We use 20 bits to encode document_ids and use the largest value (1M - 1) to +// We use 22 bits to encode document_ids and use the largest value (2^22 - 1) to // represent an invalid document_id. -inline constexpr int kDocumentIdBits = 20; +inline constexpr int kDocumentIdBits = 22; inline constexpr DocumentId kInvalidDocumentId = (1u << kDocumentIdBits) - 1; inline constexpr DocumentId kMinDocumentId = 0; inline constexpr DocumentId kMaxDocumentId = kInvalidDocumentId - 1; diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index 81f65b2..610eb71 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -24,6 +24,7 @@ #include "gtest/gtest.h" #include "icing/absl_ports/str_join.h" #include "icing/index/hit/doc-hit-info.h" +#include "icing/index/iterator/doc-hit-info-iterator-test-util.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/proto/search.pb.h" #include "icing/schema/schema-store.h" @@ -50,11 +51,12 @@ MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") { const DocHitInfo& actual = arg; SectionIdMask section_mask = kSectionIdMaskNone; for (SectionId section_id : section_ids) { - section_mask |= 1U << section_id; + section_mask |= UINT64_C(1) << section_id; } *result_listener << IcingStringUtil::StringPrintf( - "(actual is {document_id=%d, section_mask=%d}, but expected was " - "{document_id=%d, section_mask=%d}.)", + "(actual is {document_id=%d, section_mask=%" PRIu64 + "}, but expected was " + "{document_id=%d, section_mask=%" PRIu64 "}.)", actual.document_id(), actual.hit_section_ids_mask(), document_id, section_mask); return actual.document_id() == document_id && @@ -64,7 +66,7 @@ MATCHER_P2(EqualsDocHitInfo, document_id, section_ids, "") { // Used to match a DocHitInfo MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, section_ids_to_term_frequencies_map, "") { - const DocHitInfo& actual = arg; + const DocHitInfoTermFrequencyPair& actual = arg; SectionIdMask section_mask = kSectionIdMaskNone; bool term_frequency_as_expected = true; @@ -73,7 +75,7 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, for (auto itr = section_ids_to_term_frequencies_map.begin(); itr != section_ids_to_term_frequencies_map.end(); itr++) { SectionId section_id = itr->first; - section_mask |= 1U << section_id; + section_mask |= UINT64_C(1) << section_id; expected_tfs.push_back(itr->second); actual_tfs.push_back(actual.hit_term_frequency(section_id)); if (actual.hit_term_frequency(section_id) != itr->second) { @@ -88,14 +90,15 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, absl_ports::StrJoin(expected_tfs, ",", absl_ports::NumberFormatter()), "]"); *result_listener << IcingStringUtil::StringPrintf( - "(actual is {document_id=%d, section_mask=%d, term_frequencies=%s}, but " - "expected was " - "{document_id=%d, section_mask=%d, term_frequencies=%s}.)", - actual.document_id(), actual.hit_section_ids_mask(), + "(actual is {document_id=%d, section_mask=%" PRIu64 + ", term_frequencies=%s}, but expected was " + "{document_id=%d, section_mask=%" PRIu64 ", term_frequencies=%s}.)", + actual.doc_hit_info().document_id(), + actual.doc_hit_info().hit_section_ids_mask(), actual_term_frequencies.c_str(), 
document_id, section_mask, expected_term_frequencies.c_str()); - return actual.document_id() == document_id && - actual.hit_section_ids_mask() == section_mask && + return actual.doc_hit_info().document_id() == document_id && + actual.doc_hit_info().hit_section_ids_mask() == section_mask && term_frequency_as_expected; } diff --git a/icing/testing/fake-clock.h b/icing/testing/fake-clock.h index f9f3654..f451753 100644 --- a/icing/testing/fake-clock.h +++ b/icing/testing/fake-clock.h @@ -24,7 +24,7 @@ namespace lib { // every time it's requested. class FakeTimer : public Timer { public: - int64_t GetElapsedMilliseconds() override { + int64_t GetElapsedMilliseconds() const override { return fake_elapsed_milliseconds_; } diff --git a/icing/text_classifier/lib3/utils/java/jni-base.h b/icing/text_classifier/lib3/utils/java/jni-base.h index 65c64a5..f86434b 100644 --- a/icing/text_classifier/lib3/utils/java/jni-base.h +++ b/icing/text_classifier/lib3/utils/java/jni-base.h @@ -17,6 +17,7 @@ #include <jni.h> +#include <memory> #include <string> #include "icing/text_classifier/lib3/utils/base/statusor.h" diff --git a/icing/util/clock.h b/icing/util/clock.h index 9e57854..d987a4c 100644 --- a/icing/util/clock.h +++ b/icing/util/clock.h @@ -42,12 +42,12 @@ class Timer { virtual ~Timer() = default; // Returns the elapsed time from when timer started. - virtual int64_t GetElapsedMilliseconds() { + virtual int64_t GetElapsedMilliseconds() const { return GetElapsedNanoseconds() / 1000000; } // Returns the elapsed time from when timer started. - virtual int64_t GetElapsedNanoseconds() { + virtual int64_t GetElapsedNanoseconds() const { return GetSteadyTimeNanoseconds() - start_timestamp_nanoseconds_; } @@ -90,6 +90,8 @@ class ScopedTimer { } } + const Timer& timer() const { return *timer_; } + private: std::unique_ptr<Timer> timer_; std::function<void(int64_t)> callback_; diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 16a4a4a..b54b344 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -306,9 +306,14 @@ public class IcingSearchEngine implements Closeable { @NonNull ResultSpecProto resultSpec) { throwIfClosed(); + long javaToNativeStartTimestampMs = System.currentTimeMillis(); byte[] searchResultBytes = nativeSearch( - this, searchSpec.toByteArray(), scoringSpec.toByteArray(), resultSpec.toByteArray()); + this, + searchSpec.toByteArray(), + scoringSpec.toByteArray(), + resultSpec.toByteArray(), + javaToNativeStartTimestampMs); if (searchResultBytes == null) { Log.e(TAG, "Received null SearchResultProto from native."); return SearchResultProto.newBuilder() @@ -317,7 +322,10 @@ public class IcingSearchEngine implements Closeable { } try { - return SearchResultProto.parseFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); + SearchResultProto.Builder searchResultProtoBuilder = + SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); + setNativeToJavaJniLatency(searchResultProtoBuilder); + return searchResultProtoBuilder.build(); } catch (InvalidProtocolBufferException e) { Log.e(TAG, "Error parsing SearchResultProto.", e); return SearchResultProto.newBuilder() @@ -330,7 +338,7 @@ public class IcingSearchEngine implements Closeable { public SearchResultProto getNextPage(long nextPageToken) { throwIfClosed(); - byte[] searchResultBytes = nativeGetNextPage(this, nextPageToken); + byte[] searchResultBytes = 
nativeGetNextPage(this, nextPageToken, System.currentTimeMillis()); if (searchResultBytes == null) { Log.e(TAG, "Received null SearchResultProto from native."); return SearchResultProto.newBuilder() @@ -339,7 +347,10 @@ public class IcingSearchEngine implements Closeable { } try { - return SearchResultProto.parseFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); + SearchResultProto.Builder searchResultProtoBuilder = + SearchResultProto.newBuilder().mergeFrom(searchResultBytes, EXTENSION_REGISTRY_LITE); + setNativeToJavaJniLatency(searchResultProtoBuilder); + return searchResultProtoBuilder.build(); } catch (InvalidProtocolBufferException e) { Log.e(TAG, "Error parsing SearchResultProto.", e); return SearchResultProto.newBuilder() @@ -348,6 +359,16 @@ public class IcingSearchEngine implements Closeable { } } + private void setNativeToJavaJniLatency(SearchResultProto.Builder searchResultProtoBuilder) { + int nativeToJavaLatencyMs = + (int) + (System.currentTimeMillis() + - searchResultProtoBuilder.getQueryStats().getNativeToJavaStartTimestampMs()); + searchResultProtoBuilder.setQueryStats( + searchResultProtoBuilder.getQueryStats().toBuilder() + .setNativeToJavaJniLatencyMs(nativeToJavaLatencyMs)); + } + @NonNull public void invalidateNextPageToken(long nextPageToken) { throwIfClosed(); @@ -657,9 +678,11 @@ public class IcingSearchEngine implements Closeable { IcingSearchEngine instance, byte[] searchSpecBytes, byte[] scoringSpecBytes, - byte[] resultSpecBytes); + byte[] resultSpecBytes, + long javaToNativeStartTimestampMs); - private static native byte[] nativeGetNextPage(IcingSearchEngine instance, long nextPageToken); + private static native byte[] nativeGetNextPage( + IcingSearchEngine instance, long nextPageToken, long javaToNativeStartTimestampMs); private static native void nativeInvalidateNextPageToken( IcingSearchEngine instance, long nextPageToken); diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java index b55cfd1..556e537 100644 --- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java +++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java @@ -214,6 +214,16 @@ public final class IcingSearchEngineTest { assertStatusOk(searchResultProto.getStatus()); assertThat(searchResultProto.getResultsCount()).isEqualTo(1); assertThat(searchResultProto.getResults(0).getDocument()).isEqualTo(emailDocument); + + // TODO(b/236412954): Enable these JNI latency tests once cl/469819190 is synced to Jetpack + // Test that JNI latency has been set properly + // assertThat(searchResultProto.getQueryStats().hasNativeToJavaJniLatencyMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().hasNativeToJavaStartTimestampMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().hasJavaToNativeJniLatencyMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().getNativeToJavaJniLatencyMs()).isAtLeast(0); + // assertThat(searchResultProto.getQueryStats().getNativeToJavaStartTimestampMs()) + // .isGreaterThan(0); + // assertThat(searchResultProto.getQueryStats().getJavaToNativeJniLatencyMs()).isAtLeast(0); } @Test @@ -256,6 +266,16 @@ public final class IcingSearchEngineTest { DocumentProto resultDocument = searchResultProto.getResults(0).getDocument(); assertThat(resultDocument).isEqualTo(documents.remove(resultDocument.getUri())); + // TODO(b/236412954): Enable these JNI latency tests once 
cl/469819190 is synced to Jetpack + // Test that JNI latency has been set + // assertThat(searchResultProto.getQueryStats().hasNativeToJavaJniLatencyMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().hasNativeToJavaStartTimestampMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().hasJavaToNativeJniLatencyMs()).isTrue(); + // assertThat(searchResultProto.getQueryStats().getNativeToJavaJniLatencyMs()).isAtLeast(0); + // assertThat(searchResultProto.getQueryStats().getNativeToJavaStartTimestampMs()) + // .isGreaterThan(0); + // assertThat(searchResultProto.getQueryStats().getJavaToNativeJniLatencyMs()).isAtLeast(0); + + // fetch the rest of the pages for (int i = 1; i < 5; i++) { searchResultProto = icingSearchEngine.getNextPage(searchResultProto.getNextPageToken()); diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto index 0a7c4a6..6f168bd 100644 --- a/proto/icing/proto/logging.proto +++ b/proto/icing/proto/logging.proto @@ -131,7 +131,7 @@ message PutDocumentStatsProto { // Stats of the top-level function IcingSearchEngine::Search() and // IcingSearchEngine::GetNextPage(). -// Next tag: 17 +// Next tag: 21 message QueryStatsProto { // The UTF-8 length of the query string optional int32 query_length = 16; @@ -182,6 +182,19 @@ message QueryStatsProto { // time to snippet if ‘has_snippets’ is true. optional int32 document_retrieval_latency_ms = 14; + // Time passed while waiting to acquire the lock before query execution. + optional int32 lock_acquisition_latency_ms = 17; + + // Timestamp taken just before sending proto across the JNI boundary from + // native to java side. + optional int64 native_to_java_start_timestamp_ms = 18; + + // Time used to send protos across the JNI boundary from java to native side. + optional int32 java_to_native_jni_latency_ms = 19; + + // Time used to send protos across the JNI boundary from native to java side. + optional int32 native_to_java_jni_latency_ms = 20; + reserved 9; } diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 7a361d3..8592c2f 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -26,7 +26,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Client-supplied specifications on what documents to retrieve. -// Next tag: 5 +// Next tag: 6 message SearchSpecProto { // REQUIRED: The "raw" query string that users may type. For example, "cat" // will search for documents with the term cat in it. @@ -61,6 +61,10 @@ message SearchSpecProto { // applies to the entire 'query'. To issue different queries for different // schema types, separate Search()'s will need to be made. repeated string schema_type_filters = 4; + + // Timestamp taken just before sending proto across the JNI boundary from java + // to native side. + optional int64 java_to_native_start_timestamp_ms = 5; } // Client-supplied specifications on what to include/how to format the search diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index cd00254..2026297 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=466546985) +set(synced_AOSP_CL_number=473080785)
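The query-terms.h hunk in this sync ("Fix improper uses of std::string_view") changes the map key from std::string_view to std::string: with string_view keys the map only referenced storage owned elsewhere, so a key could dangle once its backing string was destroyed. A contrived sketch of the failure mode the change removes; the names below are local to the example, not Icing code:

```cpp
#include <string>
#include <unordered_map>
#include <unordered_set>

using RestrictMap =
    std::unordered_map<std::string, std::unordered_set<std::string>>;

RestrictMap BuildRestricts() {
  RestrictMap restricts;
  {
    // Section name materialized from a temporary (e.g. parsed out of a query).
    std::string section = "subject";
    restricts[section].insert("animal");  // the map copies and owns the key
  }  // with std::string_view keys, `section` dying here would leave the map
     // holding a dangling view into freed memory
  return restricts;
}

int main() { return BuildRestricts().count("subject") == 1 ? 0 : 1; }
```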
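The InvalidateResultStateWhileUsing test above relies on plain std::shared_ptr semantics: erasing the cache's copy does not free a state that another thread is still using. A toy model of that guarantee; the ResultState struct here is a stand-in for the real class, not its actual layout:

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <memory>

struct ResultState {  // stand-in for the cached per-query state
  int next_index = 0;
};

int main() {
  std::map<uint64_t, std::shared_ptr<ResultState>> cache;
  cache[/*token=*/42] = std::make_shared<ResultState>();

  // A reader takes its own reference, as a GetNextPage call would.
  std::shared_ptr<ResultState> in_use = cache.at(42);

  cache.erase(42);  // invalidation drops the cache's reference...

  // ...but the object lives until the last shared_ptr goes away, so the
  // reader can finish its page; later lookups simply report NOT_FOUND.
  assert(in_use.use_count() == 1);
  in_use->next_index += 100;
  return 0;
}
```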
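The final expectation in the MultipleResultStates test is plain arithmetic over its constants: each of the kNumThreads states caches kNumDocuments hits, and thread i retrieves (i + 1) pages of kNumPerPage hits each. A quick check with the test's values:

```cpp
#include <cstdio>

int main() {
  constexpr int kNumThreads = 50;
  constexpr int kNumDocuments = 2000;
  constexpr int kNumPerPage = 30;
  // Pages retrieved across all threads: 1 + 2 + ... + kNumThreads.
  constexpr int pages = kNumThreads * (kNumThreads + 1) / 2;  // 1275
  constexpr int remaining =
      kNumThreads * kNumDocuments - kNumPerPage * pages;      // 61750
  std::printf("expected num_total_hits after retrieval: %d\n", remaining);
  return 0;
}
```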
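The section.h hunk ("Extend the scale of Icing") widens SectionIdMask from int16_t to int64_t with 6 section-id bits for 64 sections, which is why the mask loops in snippet-retriever.cc and bm25f-calculator.cc switch from `1u << section_id` / `__builtin_ctz` to `UINT64_C(1) << section_id` / `__builtin_ctzll`. A standalone sketch of that set-bit walk, assuming a GCC/Clang toolchain (the builtins are compiler-specific); names are local to the example:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // A 64-bit mask with hits in sections 0, 5, and 63. UINT64_C keeps the
  // shift 64 bits wide; a plain 1u shift would be truncated to 32 bits.
  uint64_t section_id_mask =
      (UINT64_C(1) << 0) | (UINT64_C(1) << 5) | (UINT64_C(1) << 63);
  while (section_id_mask != 0) {
    // __builtin_ctzll operates on unsigned long long, so it sees all 64 bits;
    // __builtin_ctz would only look at the low 32.
    int section_id = __builtin_ctzll(section_id_mask);
    section_id_mask &= ~(UINT64_C(1) << section_id);  // clear the visited bit
    std::printf("hit in section %d\n", section_id);
  }
  return 0;
}
```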
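The scored-document-hit.h static_assert moves from 14 to 20 bytes because the packed struct gains 6 bytes when its section mask grows from int16_t to int64_t. A layout sketch that reproduces the arithmetic; the field names are illustrative, only the sizes are taken from the diff:

```cpp
#include <cstdint>

struct PackedHit {
  int32_t document_id;      // 4 bytes
  int64_t section_id_mask;  // 8 bytes (was int16_t / 2 bytes)
  double score;             // 8 bytes
} __attribute__((packed));  // GCC/Clang attribute, as in the Icing header

static_assert(sizeof(PackedHit) == 20, "4 + 8 + 8 when packed");

int main() { return 0; }
```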
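Likewise, the document-id.h hunk extends the id space from 20 to 22 bits, with the largest encodable value, 2^22 - 1 = 4194303, still reserved as the invalid id. A one-file check of the constants under that reading:

```cpp
#include <cstdint>

inline constexpr int kDocumentIdBits = 22;  // was 20
inline constexpr int32_t kInvalidDocumentId =
    (1u << kDocumentIdBits) - 1;            // 4194303
static_assert(kInvalidDocumentId == 4194303);
static_assert(kInvalidDocumentId - 1 == 4194302);  // kMaxDocumentId

int main() { return 0; }
```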