diff options
author | Alexander Dorokhine <adorokhine@google.com> | 2021-11-05 15:01:48 -0700 |
---|---|---|
committer | Alexander Dorokhine <adorokhine@google.com> | 2021-11-05 15:06:39 -0700 |
commit | d4760eba4d61908605fef358b0851cb89f3dd9d8 (patch) | |
tree | ef49d61e92007f8eb1a3027803206cf30e91d2fb | |
parent | 7ad2a74434a30d786ba247a5320a673d2d0dea63 (diff) | |
parent | ef33b5af6b6c19e22d23ef0d70402fa30da20bbd (diff) | |
download | icing-d4760eba4d61908605fef358b0851cb89f3dd9d8.tar.gz |
Merge remote-tracking branch 'goog/androidx-platform-dev' into master
* goog/androidx-platform-dev:
Sync from upstream.
Sync from upstream.
Add an OWNERS file for external/icing.
Sync from upstream.
Sync from upstream.
Merge androidx-platform-dev/external/icing upstream-master into upstream-master
Sync from upstream.
Sync from upstream.
Commit descriptions:
==================
Add Initialization Count marker file to break out of crash loops.
==================
Delete the Simple Normalizer.
==================
Add submatch information to identify a submatch within a document.
==================
Remove no-longer-used write paths for file-backed-proto-log.
================
Modify segmentation rules to consider any segment that begins with a non-Ascii
alphanumeric character as valid
=================
Implement CalculateNormalizedMatchLength for IcuNormalizer.
================
Add additional benchmark cases that were useful in developing
submatching and CalculateNormalizedMatchLength for IcuNormalizer
=================
Switch NormalizationMap from
static const std::unordered_map<char16_t, char16_t>& to
static const std::unordered_map<char16_t, char16_t> *const.
================
Implement ranking in FindTermByPrefix.
================
Fork proto's GzipStream into Icing Lib.
================
Remove token limit behavior from index-processor.
================
Replace refs to c lib headers w/ c++ stdlib equivalents.
================
Update IDF component of BM25F Calculator in IcingLib
================
Expose QuerySuggestions API.
================
Change the tokenizer used in QuerySuggest.
================
Add SectionWeights API to Icing.
================
Apply SectionWeights to BM25F Scoring.
================
Replaces uses of u_strTo/FromUTF32 w/ u_strTo/FromUTF8.
Bug: 147509515
Bug: 149610413
Bug: 152934343
Bug: 195720764
Bug: 196257995
Bug: 196771754
Bug: 202308641
Bug: 203700301
Test: Presubmit
Change-Id: I31e39a3e5fe5ecccafadf34ef72a7442902bc0cd
108 files changed, 4865 insertions, 2071 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 01ee8eb..8c8e439 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,6 +14,8 @@ cmake_minimum_required(VERSION 3.10.2) +project(icing) + add_definitions("-DICING_REVERSE_JNI_SEGMENTATION=1") set(VERSION_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/icing/jni.lds") set(CMAKE_SHARED_LINKER_FLAGS @@ -74,7 +76,7 @@ foreach(FILE ${Icing_PROTO_FILES}) "${Icing_PROTO_GEN_DIR}/${FILE_NOEXT}.pb.h" COMMAND ${Protobuf_PROTOC_PATH} --proto_path "${CMAKE_CURRENT_SOURCE_DIR}/proto" - --cpp_out ${Icing_PROTO_GEN_DIR} + --cpp_out "lite:${Icing_PROTO_GEN_DIR}" ${FILE} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/proto/${FILE} @@ -127,4 +129,4 @@ target_include_directories(icing PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_include_directories(icing PRIVATE ${Icing_PROTO_GEN_DIR}) target_include_directories(icing PRIVATE "${Protobuf_SOURCE_DIR}/src") target_include_directories(icing PRIVATE "${ICU_SOURCE_DIR}/include") -target_link_libraries(icing protobuf::libprotobuf libandroidicu log) +target_link_libraries(icing protobuf::libprotobuf-lite libandroidicu log z) @@ -0,0 +1,3 @@ +adorokhine@google.com +tjbarron@google.com +dsaadati@google.com diff --git a/build.gradle b/build.gradle index 0f60c5e..5b5f3a6 100644 --- a/build.gradle +++ b/build.gradle @@ -14,8 +14,6 @@ * limitations under the License. */ - -import androidx.build.dependencies.DependenciesKt import static androidx.build.SupportConfig.* buildscript { @@ -65,7 +63,7 @@ dependencies { protobuf { protoc { - artifact = DependenciesKt.getDependencyAsString(libs.protobufCompiler) + artifact = libs.protobufCompiler.get() } generateProtoTasks { diff --git a/icing/file/file-backed-proto-log.h b/icing/file/file-backed-proto-log.h index b2b37e8..686b4fb 100644 --- a/icing/file/file-backed-proto-log.h +++ b/icing/file/file-backed-proto-log.h @@ -14,16 +14,14 @@ // File-backed log of protos with append-only writes and position based reads. 
// -// There should only be one instance of a FileBackedProtoLog of the same file at -// a time; using multiple instances at the same time may lead to undefined -// behavior. +// The implementation in this file is deprecated and replaced by +// portable-file-backed-proto-log.h. // -// The entire checksum is computed on initialization to verify the contents are -// valid. On failure, the log will be truncated to the last verified state when -// PersistToDisk() was called. If the log cannot successfully restore the last -// state due to disk corruption or some other inconsistency, then the entire log -// will be lost. +// This deprecated implementation has been made read-only for the purposes of +// migration; writing and erasing this format of log is no longer supported and +// the methods to accomplish this have been removed. // +// The details of this format follow below: // Each proto written to the file will have a metadata written just before it. // The metadata consists of // { @@ -31,45 +29,24 @@ // 3 bytes of the proto size // n bytes of the proto itself // } -// -// Example usage: -// ICING_ASSERT_OK_AND_ASSIGN(auto create_result, -// FileBackedProtoLog<DocumentProto>::Create(filesystem, file_path_, -// options)); -// auto proto_log = create_result.proto_log; -// -// Document document; -// document.set_namespace("com.google.android.example"); -// document.set_uri("www.google.com"); -// -// int64_t document_offset = proto_log->WriteProto(document)); -// Document same_document = proto_log->ReadProto(document_offset)); -// proto_log->PersistToDisk(); -// // TODO(b/136514769): Add versioning to the header and a UpgradeToVersion // migration method. 
- #ifndef ICING_FILE_FILE_BACKED_PROTO_LOG_H_ #define ICING_FILE_FILE_BACKED_PROTO_LOG_H_ -#include <cstddef> #include <cstdint> -#include <cstring> #include <memory> #include <string> #include <string_view> -#include <utility> -#include <vector> -#include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" -#include <google/protobuf/io/gzip_stream.h> #include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" #include "icing/file/filesystem.h" #include "icing/file/memory-mapped-file.h" #include "icing/legacy/core/icing-string-util.h" +#include "icing/portable/gzip_stream.h" #include "icing/portable/platform.h" #include "icing/portable/zlib.h" #include "icing/util/crc32.h" @@ -112,10 +89,6 @@ class FileBackedProtoLog { // Header stored at the beginning of the file before the rest of the log // contents. Stores metadata on the log. - // - // TODO(b/139375388): Migrate the Header struct to a proto. This makes - // migrations easier since we don't need to worry about different size padding - // (which would affect the checksum) and different endians. struct Header { static constexpr int32_t kMagic = 0xf4c6f67a; @@ -195,20 +168,6 @@ class FileBackedProtoLog { FileBackedProtoLog(const FileBackedProtoLog&) = delete; FileBackedProtoLog& operator=(const FileBackedProtoLog&) = delete; - // This will update the checksum of the log as well. - ~FileBackedProtoLog(); - - // Writes the serialized proto to the underlying file. Writes are applied - // directly to the underlying file. Users do not need to sync the file after - // writing. 
- // - // Returns: - // Offset of the newly appended proto in file on success - // INVALID_ARGUMENT if proto is too large, as decided by - // Options.max_proto_size - // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<int64_t> WriteProto(const ProtoT& proto); - // Reads out a proto located at file_offset from the file. // // Returns: @@ -218,31 +177,6 @@ class FileBackedProtoLog { // INTERNAL_ERROR on IO error libtextclassifier3::StatusOr<ProtoT> ReadProto(int64_t file_offset) const; - // Erases the data of a proto located at file_offset from the file. - // - // Returns: - // OK on success - // OUT_OF_RANGE_ERROR if file_offset exceeds file size - // INTERNAL_ERROR on IO error - libtextclassifier3::Status EraseProto(int64_t file_offset); - - // Calculates and returns the disk usage in bytes. Rounds up to the nearest - // block size. - // - // Returns: - // Disk usage on success - // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<int64_t> GetDiskUsage() const; - - // Returns the file size of all the elements held in the log. File size is in - // bytes. This excludes the size of any internal metadata of the log, e.g. the - // log's header. - // - // Returns: - // File size on success - // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<int64_t> GetElementsFileSize() const; - // An iterator helping to find offsets of all the protos in file. // Example usage: // @@ -281,72 +215,6 @@ class FileBackedProtoLog { // behaviors could happen. Iterator GetIterator(); - // Persists all changes since initialization or the last call to - // PersistToDisk(). Any changes that aren't persisted may be lost if the - // system fails to close safely. 
- // - // Example use case: - // - // Document document; - // document.set_namespace("com.google.android.example"); - // document.set_uri("www.google.com"); - // - // { - // ICING_ASSERT_OK_AND_ASSIGN(auto create_result, - // FileBackedProtoLog<DocumentProto>::Create(filesystem, file_path, - // options)); - // auto proto_log = std::move(create_result.proto_log); - // - // int64_t document_offset = proto_log->WriteProto(document)); - // - // // We lose the document here since it wasn't persisted. - // // *SYSTEM CRASH* - // } - // - // { - // // Can still successfully create after a crash since the log can - // // rewind/truncate to recover into a previously good state - // ICING_ASSERT_OK_AND_ASSIGN(auto create_result, - // FileBackedProtoLog<DocumentProto>::Create(filesystem, file_path, - // options)); - // auto proto_log = std::move(create_result.proto_log); - // - // // Lost the proto since we didn't PersistToDisk before the crash - // proto_log->ReadProto(document_offset)); // INVALID_ARGUMENT error - // - // int64_t document_offset = proto_log->WriteProto(document)); - // - // // Persisted this time, so we should be ok. - // ICING_ASSERT_OK(proto_log->PersistToDisk()); - // } - // - // { - // ICING_ASSERT_OK_AND_ASSIGN(auto create_result, - // FileBackedProtoLog<DocumentProto>::Create(filesystem, file_path, - // options)); - // auto proto_log = std::move(create_result.proto_log); - // - // // SUCCESS - // Document same_document = proto_log->ReadProto(document_offset)); - // } - // - // NOTE: Since all protos are already written to the file directly, this - // just updates the checksum and rewind position. Without these updates, - // future initializations will truncate the file and discard unpersisted - // changes. - // - // Returns: - // OK on success - // INTERNAL_ERROR on IO error - libtextclassifier3::Status PersistToDisk(); - - // Calculates the checksum of the log contents. Excludes the header content. 
- // - // Returns: - // Crc of the log content - // INTERNAL_ERROR on IO error - libtextclassifier3::StatusOr<Crc32> ComputeChecksum(); - private: // Object can only be instantiated via the ::Create factory. FileBackedProtoLog(const Filesystem* filesystem, const std::string& file_path, @@ -424,9 +292,6 @@ class FileBackedProtoLog { static_assert(kMaxProtoSize <= 0x00FFFFFF, "kMaxProtoSize doesn't fit in 3 bytes"); - // Level of compression, BEST_SPEED = 1, BEST_COMPRESSION = 9 - static constexpr int kDeflateCompressionLevel = 3; - // Chunks of the file to mmap at a time, so we don't mmap the entire file. // Only used on 32-bit devices static constexpr int kMmapChunkSize = 4 * 1024 * 1024; // 4MiB @@ -438,9 +303,6 @@ class FileBackedProtoLog { }; template <typename ProtoT> -constexpr uint8_t FileBackedProtoLog<ProtoT>::kProtoMagic; - -template <typename ProtoT> FileBackedProtoLog<ProtoT>::FileBackedProtoLog(const Filesystem* filesystem, const std::string& file_path, std::unique_ptr<Header> header) @@ -451,15 +313,6 @@ FileBackedProtoLog<ProtoT>::FileBackedProtoLog(const Filesystem* filesystem, } template <typename ProtoT> -FileBackedProtoLog<ProtoT>::~FileBackedProtoLog() { - if (!PersistToDisk().ok()) { - ICING_LOG(WARNING) - << "Error persisting to disk during destruction of FileBackedProtoLog: " - << file_path_; - } -} - -template <typename ProtoT> libtextclassifier3::StatusOr<typename FileBackedProtoLog<ProtoT>::CreateResult> FileBackedProtoLog<ProtoT>::Create(const Filesystem* filesystem, const std::string& file_path, @@ -688,79 +541,6 @@ libtextclassifier3::StatusOr<Crc32> FileBackedProtoLog<ProtoT>::ComputeChecksum( } template <typename ProtoT> -libtextclassifier3::StatusOr<int64_t> FileBackedProtoLog<ProtoT>::WriteProto( - const ProtoT& proto) { - int64_t proto_size = proto.ByteSizeLong(); - int32_t metadata; - int metadata_size = sizeof(metadata); - int64_t current_position = filesystem_->GetCurrentPosition(fd_.get()); - - if (proto_size > 
header_->max_proto_size) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "proto_size, %lld, was too large to write. Max is %d", - static_cast<long long>(proto_size), header_->max_proto_size)); - } - - // At this point, we've guaranteed that proto_size is under kMaxProtoSize - // (see - // ::Create), so we can safely store it in an int. - int final_size = 0; - - std::string proto_str; - google::protobuf::io::StringOutputStream proto_stream(&proto_str); - - if (header_->compress) { - google::protobuf::io::GzipOutputStream::Options options; - options.format = google::protobuf::io::GzipOutputStream::ZLIB; - options.compression_level = kDeflateCompressionLevel; - - google::protobuf::io::GzipOutputStream compressing_stream(&proto_stream, - options); - - bool success = proto.SerializeToZeroCopyStream(&compressing_stream) && - compressing_stream.Close(); - - if (!success) { - return absl_ports::InternalError("Error compressing proto."); - } - - final_size = proto_str.size(); - - // In case the compressed proto is larger than the original proto, we also - // can't write it. - if (final_size > header_->max_proto_size) { - return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( - "Compressed proto size, %d, was greater than " - "max_proto_size, %d", - final_size, header_->max_proto_size)); - } - } else { - // Serialize the proto directly into the write buffer at an offset of the - // metadata. - proto.SerializeToZeroCopyStream(&proto_stream); - final_size = proto_str.size(); - } - - // 1st byte for magic, next 3 bytes for proto size. 
- metadata = (kProtoMagic << 24) | final_size; - - // Actually write metadata, has to be done after we know the possibly - // compressed proto size - if (!filesystem_->Write(fd_.get(), &metadata, metadata_size)) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to write proto metadata to: ", file_path_)); - } - - // Write the serialized proto - if (!filesystem_->Write(fd_.get(), proto_str.data(), proto_str.size())) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to write proto to: ", file_path_)); - } - - return current_position; -} - -template <typename ProtoT> libtextclassifier3::StatusOr<ProtoT> FileBackedProtoLog<ProtoT>::ReadProto( int64_t file_offset) const { int64_t file_size = filesystem_->GetFileSize(fd_.get()); @@ -796,7 +576,7 @@ libtextclassifier3::StatusOr<ProtoT> FileBackedProtoLog<ProtoT>::ReadProto( // Deserialize proto ProtoT proto; if (header_->compress) { - google::protobuf::io::GzipInputStream decompress_stream(&proto_stream); + protobuf_ports::GzipInputStream decompress_stream(&proto_stream); proto.ParseFromZeroCopyStream(&decompress_stream); } else { proto.ParseFromZeroCopyStream(&proto_stream); @@ -806,83 +586,6 @@ libtextclassifier3::StatusOr<ProtoT> FileBackedProtoLog<ProtoT>::ReadProto( } template <typename ProtoT> -libtextclassifier3::Status FileBackedProtoLog<ProtoT>::EraseProto( - int64_t file_offset) { - int64_t file_size = filesystem_->GetFileSize(fd_.get()); - if (file_offset >= file_size) { - // file_size points to the next byte to write at, so subtract one to get - // the inclusive, actual size of file. 
- return absl_ports::OutOfRangeError(IcingStringUtil::StringPrintf( - "Trying to erase data at a location, %lld, " - "out of range of the file size, %lld", - static_cast<long long>(file_offset), - static_cast<long long>(file_size - 1))); - } - - MemoryMappedFile mmapped_file( - *filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_WRITE_AUTO_SYNC); - - // Read out the metadata - ICING_ASSIGN_OR_RETURN( - int metadata, ReadProtoMetadata(&mmapped_file, file_offset, file_size)); - - ICING_RETURN_IF_ERROR(mmapped_file.Remap(file_offset + sizeof(metadata), - GetProtoSize(metadata))); - - // We need to update the crc checksum if the erased area is before the - // rewind position. - if (file_offset + sizeof(metadata) < header_->rewind_offset) { - // We need to calculate [original string xor 0s]. - // The xored string is the same as the original string because 0 xor 0 = - // 0, 1 xor 0 = 1. - const std::string_view xored_str(mmapped_file.region(), - mmapped_file.region_size()); - - Crc32 crc(header_->log_checksum); - ICING_ASSIGN_OR_RETURN( - uint32_t new_crc, - crc.UpdateWithXor( - xored_str, - /*full_data_size=*/header_->rewind_offset - sizeof(Header), - /*position=*/file_offset + sizeof(metadata) - sizeof(Header))); - - header_->log_checksum = new_crc; - header_->header_checksum = header_->CalculateHeaderChecksum(); - - if (!filesystem_->PWrite(fd_.get(), /*offset=*/0, header_.get(), - sizeof(Header))) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to update header to: ", file_path_)); - } - } - - memset(mmapped_file.mutable_region(), '\0', mmapped_file.region_size()); - return libtextclassifier3::Status::OK; -} - -template <typename ProtoT> -libtextclassifier3::StatusOr<int64_t> FileBackedProtoLog<ProtoT>::GetDiskUsage() - const { - int64_t size = filesystem_->GetDiskUsage(file_path_.c_str()); - if (size == Filesystem::kBadFileSize) { - return absl_ports::InternalError("Failed to get disk usage of proto log"); - } - return size; -} - 
-template <typename ProtoT> -libtextclassifier3::StatusOr<int64_t> -FileBackedProtoLog<ProtoT>::GetElementsFileSize() const { - int64_t total_file_size = filesystem_->GetFileSize(file_path_.c_str()); - if (total_file_size == Filesystem::kBadFileSize) { - return absl_ports::InternalError( - "Failed to get file size of elments in the proto log"); - } - return total_file_size - sizeof(Header); -} - -template <typename ProtoT> FileBackedProtoLog<ProtoT>::Iterator::Iterator(const Filesystem& filesystem, const std::string& file_path, int64_t initial_offset) @@ -964,51 +667,6 @@ libtextclassifier3::StatusOr<int> FileBackedProtoLog<ProtoT>::ReadProtoMetadata( return metadata; } -template <typename ProtoT> -libtextclassifier3::Status FileBackedProtoLog<ProtoT>::PersistToDisk() { - int64_t file_size = filesystem_->GetFileSize(file_path_.c_str()); - if (file_size == header_->rewind_offset) { - // No new protos appended, don't need to update the checksum. - return libtextclassifier3::Status::OK; - } - - int64_t new_content_size = file_size - header_->rewind_offset; - Crc32 crc; - if (new_content_size < 0) { - // File shrunk, recalculate the entire checksum. - ICING_ASSIGN_OR_RETURN( - crc, ComputeChecksum(filesystem_, file_path_, Crc32(), sizeof(Header), - file_size)); - } else { - // Append new changes to the existing checksum. 
- ICING_ASSIGN_OR_RETURN( - crc, - ComputeChecksum(filesystem_, file_path_, Crc32(header_->log_checksum), - header_->rewind_offset, file_size)); - } - - header_->log_checksum = crc.Get(); - header_->rewind_offset = file_size; - header_->header_checksum = header_->CalculateHeaderChecksum(); - - if (!filesystem_->PWrite(fd_.get(), /*offset=*/0, header_.get(), - sizeof(Header)) || - !filesystem_->DataSync(fd_.get())) { - return absl_ports::InternalError( - absl_ports::StrCat("Failed to update header to: ", file_path_)); - } - - return libtextclassifier3::Status::OK; -} - -template <typename ProtoT> -libtextclassifier3::StatusOr<Crc32> -FileBackedProtoLog<ProtoT>::ComputeChecksum() { - return FileBackedProtoLog<ProtoT>::ComputeChecksum( - filesystem_, file_path_, Crc32(), /*start=*/sizeof(Header), - /*end=*/filesystem_->GetFileSize(file_path_.c_str())); -} - } // namespace lib } // namespace icing diff --git a/icing/file/file-backed-proto-log_benchmark.cc b/icing/file/file-backed-proto-log_benchmark.cc deleted file mode 100644 index c09fd5a..0000000 --- a/icing/file/file-backed-proto-log_benchmark.cc +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include <cstdint> -#include <random> - -#include "testing/base/public/benchmark.h" -#include "gmock/gmock.h" -#include "icing/document-builder.h" -#include "icing/file/file-backed-proto-log.h" -#include "icing/file/filesystem.h" -#include "icing/legacy/core/icing-string-util.h" -#include "icing/proto/document.pb.h" -#include "icing/testing/common-matchers.h" -#include "icing/testing/random-string.h" -#include "icing/testing/tmp-directory.h" - -// go/microbenchmarks -// -// To build and run on a local machine: -// $ blaze build -c opt --dynamic_mode=off --copt=-gmlt -// icing/file:file-backed-proto-log_benchmark -// -// $ blaze-bin/icing/file/file-backed-proto-log_benchmark -// --benchmarks=all -// -// -// To build and run on an Android device (must be connected and rooted): -// $ blaze build --copt="-DGOOGLE_COMMANDLINEFLAGS_FULL_API=1" -// --config=android_arm64 -c opt --dynamic_mode=off --copt=-gmlt -// icing/file:file-backed-proto-log_benchmark -// -// $ adb root -// -// $ adb push -// blaze-bin/icing/file/file-backed-proto-log_benchmark -// /data/local/tmp/ -// -// $ adb shell /data/local/tmp/file-backed-proto-log-benchmark -// --benchmarks=all - -namespace icing { -namespace lib { - -namespace { - -static void BM_Write(benchmark::State& state) { - const Filesystem filesystem; - int string_length = state.range(0); - const std::string file_path = IcingStringUtil::StringPrintf( - "%s%s%d%s", GetTestTempDir().c_str(), "/proto_", string_length, ".log"); - int max_proto_size = (1 << 24) - 1; // 16 MiB - bool compress = true; - - // Make sure it doesn't already exist. 
- filesystem.DeleteFile(file_path.c_str()); - - auto proto_log = - FileBackedProtoLog<DocumentProto>::Create( - &filesystem, file_path, - FileBackedProtoLog<DocumentProto>::Options(compress, max_proto_size)) - .ValueOrDie() - .proto_log; - - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - std::default_random_engine random; - const std::string rand_str = - RandomString(kAlNumAlphabet, string_length, &random); - - auto document_properties = document.add_properties(); - document_properties->set_name("string property"); - document_properties->add_string_values(rand_str); - - for (auto _ : state) { - testing::DoNotOptimize(proto_log->WriteProto(document)); - } - state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * - string_length); - - // Cleanup after ourselves - filesystem.DeleteFile(file_path.c_str()); -} -BENCHMARK(BM_Write) - ->Arg(1) - ->Arg(32) - ->Arg(512) - ->Arg(1024) - ->Arg(4 * 1024) - ->Arg(8 * 1024) - ->Arg(16 * 1024) - ->Arg(32 * 1024) - ->Arg(256 * 1024) - ->Arg(2 * 1024 * 1024) - ->Arg(8 * 1024 * 1024) - ->Arg(15 * 1024 * 1024); // We do 15MiB here since our max proto size is - // 16MiB, and we need some extra space for the - // rest of the document properties - -static void BM_Read(benchmark::State& state) { - const Filesystem filesystem; - int string_length = state.range(0); - const std::string file_path = IcingStringUtil::StringPrintf( - "%s%s%d%s", GetTestTempDir().c_str(), "/proto_", string_length, ".log"); - int max_proto_size = (1 << 24) - 1; // 16 MiB - bool compress = true; - - // Make sure it doesn't already exist. 
- filesystem.DeleteFile(file_path.c_str()); - - auto proto_log = - FileBackedProtoLog<DocumentProto>::Create( - &filesystem, file_path, - FileBackedProtoLog<DocumentProto>::Options(compress, max_proto_size)) - .ValueOrDie() - .proto_log; - - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - std::default_random_engine random; - const std::string rand_str = - RandomString(kAlNumAlphabet, string_length, &random); - - auto document_properties = document.add_properties(); - document_properties->set_name("string property"); - document_properties->add_string_values(rand_str); - - ICING_ASSERT_OK_AND_ASSIGN(int64_t write_offset, - proto_log->WriteProto(document)); - - for (auto _ : state) { - testing::DoNotOptimize(proto_log->ReadProto(write_offset)); - } - state.SetBytesProcessed(static_cast<int64_t>(state.iterations()) * - string_length); - - // Cleanup after ourselves - filesystem.DeleteFile(file_path.c_str()); -} -BENCHMARK(BM_Read) - ->Arg(1) - ->Arg(32) - ->Arg(512) - ->Arg(1024) - ->Arg(4 * 1024) - ->Arg(8 * 1024) - ->Arg(16 * 1024) - ->Arg(32 * 1024) - ->Arg(256 * 1024) - ->Arg(2 * 1024 * 1024) - ->Arg(8 * 1024 * 1024) - ->Arg(15 * 1024 * 1024); // We do 15MiB here since our max proto size is - // 16MiB, and we need some extra space for the - // rest of the document properties - -static void BM_Erase(benchmark::State& state) { - const Filesystem filesystem; - const std::string file_path = IcingStringUtil::StringPrintf( - "%s%s", GetTestTempDir().c_str(), "/proto.log"); - int max_proto_size = (1 << 24) - 1; // 16 MiB - bool compress = true; - - // Make sure it doesn't already exist. 
- filesystem.DeleteFile(file_path.c_str()); - - auto proto_log = - FileBackedProtoLog<DocumentProto>::Create( - &filesystem, file_path, - FileBackedProtoLog<DocumentProto>::Options(compress, max_proto_size)) - .ValueOrDie() - .proto_log; - - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - std::default_random_engine random; - const std::string rand_str = RandomString(kAlNumAlphabet, /*len=*/1, &random); - - auto document_properties = document.add_properties(); - document_properties->set_name("string property"); - document_properties->add_string_values(rand_str); - - for (auto _ : state) { - state.PauseTiming(); - ICING_ASSERT_OK_AND_ASSIGN(int64_t write_offset, - proto_log->WriteProto(document)); - state.ResumeTiming(); - - testing::DoNotOptimize(proto_log->EraseProto(write_offset)); - } - - // Cleanup after ourselves - filesystem.DeleteFile(file_path.c_str()); -} -BENCHMARK(BM_Erase); - -static void BM_ComputeChecksum(benchmark::State& state) { - const Filesystem filesystem; - const std::string file_path = GetTestTempDir() + "/proto.log"; - int max_proto_size = (1 << 24) - 1; // 16 MiB - bool compress = true; - - // Make sure it doesn't already exist. 
- filesystem.DeleteFile(file_path.c_str()); - - auto proto_log = - FileBackedProtoLog<DocumentProto>::Create( - &filesystem, file_path, - FileBackedProtoLog<DocumentProto>::Options(compress, max_proto_size)) - .ValueOrDie() - .proto_log; - - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - // Make each document 1KiB - int string_length = 1024; - std::default_random_engine random; - const std::string rand_str = - RandomString(kAlNumAlphabet, string_length, &random); - - auto document_properties = document.add_properties(); - document_properties->set_name("string property"); - document_properties->add_string_values(rand_str); - - int num_docs = state.range(0); - for (int i = 0; i < num_docs; ++i) { - ICING_ASSERT_OK(proto_log->WriteProto(document)); - } - - for (auto _ : state) { - testing::DoNotOptimize(proto_log->ComputeChecksum()); - } - - // Cleanup after ourselves - filesystem.DeleteFile(file_path.c_str()); -} -BENCHMARK(BM_ComputeChecksum)->Range(1024, 1 << 20); - -} // namespace -} // namespace lib -} // namespace icing diff --git a/icing/file/file-backed-proto-log_test.cc b/icing/file/file-backed-proto-log_test.cc index d429277..eccb0c7 100644 --- a/icing/file/file-backed-proto-log_test.cc +++ b/icing/file/file-backed-proto-log_test.cc @@ -19,10 +19,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "icing/document-builder.h" #include "icing/file/filesystem.h" -#include "icing/file/mock-filesystem.h" -#include "icing/portable/equals-proto.h" #include "icing/proto/document.pb.h" #include "icing/testing/common-matchers.h" #include "icing/testing/tmp-directory.h" @@ -32,14 +29,7 @@ namespace lib { namespace { -using ::icing::lib::portable_equals_proto::EqualsProto; -using ::testing::A; -using ::testing::Eq; -using ::testing::Gt; -using ::testing::Not; using ::testing::NotNull; -using ::testing::Pair; -using ::testing::Return; class FileBackedProtoLogTest : public ::testing::Test { protected: @@ -87,193 +77,6 @@ 
TEST_F(FileBackedProtoLogTest, Initialize) { StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); } -TEST_F(FileBackedProtoLogTest, WriteProtoTooLarge) { - int max_proto_size = 1; - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - // Proto is too large for the max_proto_size_in - ASSERT_THAT(proto_log->WriteProto(document), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); -} - -TEST_F(FileBackedProtoLogTest, ReadProtoWrongKProtoMagic) { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write a proto - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - - ICING_ASSERT_OK_AND_ASSIGN(int64_t file_offset, - proto_log->WriteProto(document)); - - // The 4 bytes of metadata that just doesn't have the same kProtoMagic - // specified in file-backed-proto-log.h - uint32_t wrong_magic = 0x7E000000; - - // Sanity check that we opened the file correctly - int fd = filesystem_.OpenForWrite(file_path_.c_str()); - ASSERT_GT(fd, 0); - - // Write the wrong kProtoMagic in, kProtoMagics are stored at the beginning of - // a proto entry. 
- filesystem_.PWrite(fd, file_offset, &wrong_magic, sizeof(wrong_magic)); - - ASSERT_THAT(proto_log->ReadProto(file_offset), - StatusIs(libtextclassifier3::StatusCode::INTERNAL)); -} - -TEST_F(FileBackedProtoLogTest, ReadWriteUncompressedProto) { - int last_offset; - { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options( - /*compress_in=*/false, max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write the first proto - DocumentProto document1 = - DocumentBuilder().SetKey("namespace1", "uri1").Build(); - - ICING_ASSERT_OK_AND_ASSIGN(int written_position, - proto_log->WriteProto(document1)); - - int document1_offset = written_position; - - // Check that what we read is what we wrote - ASSERT_THAT(proto_log->ReadProto(written_position), - IsOkAndHolds(EqualsProto(document1))); - - // Write a second proto that's close to the max size. Leave some room for - // the rest of the proto properties. - std::string long_str(max_proto_size_ - 1024, 'a'); - DocumentProto document2 = DocumentBuilder() - .SetKey("namespace2", "uri2") - .AddStringProperty("long_str", long_str) - .Build(); - - ICING_ASSERT_OK_AND_ASSIGN(written_position, - proto_log->WriteProto(document2)); - - int document2_offset = written_position; - last_offset = written_position; - ASSERT_GT(document2_offset, document1_offset); - - // Check the second proto - ASSERT_THAT(proto_log->ReadProto(written_position), - IsOkAndHolds(EqualsProto(document2))); - - ICING_ASSERT_OK(proto_log->PersistToDisk()); - } - - { - // Make a new proto_log with the same file_path, and make sure we - // can still write to the same underlying file. 
- ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options( - /*compress_in=*/false, max_proto_size_))); - auto recreated_proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write a third proto - DocumentProto document3 = - DocumentBuilder().SetKey("namespace3", "uri3").Build(); - - ASSERT_THAT(recreated_proto_log->WriteProto(document3), - IsOkAndHolds(Gt(last_offset))); - } -} - -TEST_F(FileBackedProtoLogTest, ReadWriteCompressedProto) { - int last_offset; - - { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options( - /*compress_in=*/true, max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write the first proto - DocumentProto document1 = - DocumentBuilder().SetKey("namespace1", "uri1").Build(); - - ICING_ASSERT_OK_AND_ASSIGN(int written_position, - proto_log->WriteProto(document1)); - - int document1_offset = written_position; - - // Check that what we read is what we wrote - ASSERT_THAT(proto_log->ReadProto(written_position), - IsOkAndHolds(EqualsProto(document1))); - - // Write a second proto that's close to the max size. Leave some room for - // the rest of the proto properties. 
- std::string long_str(max_proto_size_ - 1024, 'a'); - DocumentProto document2 = DocumentBuilder() - .SetKey("namespace2", "uri2") - .AddStringProperty("long_str", long_str) - .Build(); - - ICING_ASSERT_OK_AND_ASSIGN(written_position, - proto_log->WriteProto(document2)); - - int document2_offset = written_position; - last_offset = written_position; - ASSERT_GT(document2_offset, document1_offset); - - // Check the second proto - ASSERT_THAT(proto_log->ReadProto(written_position), - IsOkAndHolds(EqualsProto(document2))); - - ICING_ASSERT_OK(proto_log->PersistToDisk()); - } - - { - // Make a new proto_log with the same file_path, and make sure we - // can still write to the same underlying file. - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options( - /*compress_in=*/true, max_proto_size_))); - auto recreated_proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write a third proto - DocumentProto document3 = - DocumentBuilder().SetKey("namespace3", "uri3").Build(); - - ASSERT_THAT(recreated_proto_log->WriteProto(document3), - IsOkAndHolds(Gt(last_offset))); - } -} - TEST_F(FileBackedProtoLogTest, CorruptHeader) { { ICING_ASSERT_OK_AND_ASSIGN( @@ -303,382 +106,6 @@ TEST_F(FileBackedProtoLogTest, CorruptHeader) { } } -TEST_F(FileBackedProtoLogTest, CorruptContent) { - { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - EXPECT_FALSE(create_result.has_data_loss()); - - DocumentProto document = - DocumentBuilder().SetKey("namespace1", "uri1").Build(); - - // Write and persist an document. 
- ICING_ASSERT_OK_AND_ASSIGN(int document_offset, - proto_log->WriteProto(document)); - ICING_ASSERT_OK(proto_log->PersistToDisk()); - - // "Corrupt" the content written in the log. - document.set_uri("invalid"); - std::string serialized_document = document.SerializeAsString(); - filesystem_.PWrite(file_path_.c_str(), document_offset, - serialized_document.data(), serialized_document.size()); - } - - { - // We can recover, but we have data loss. - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_TRUE(create_result.has_data_loss()); - ASSERT_THAT(create_result.data_loss, Eq(DataLoss::COMPLETE)); - - // Lost everything in the log since the rewind position doesn't help if - // there's been data corruption within the persisted region - ASSERT_EQ(filesystem_.GetFileSize(file_path_.c_str()), - sizeof(FileBackedProtoLog<DocumentProto>::Header)); - } -} - -TEST_F(FileBackedProtoLogTest, PersistToDisk) { - DocumentProto document1 = - DocumentBuilder().SetKey("namespace1", "uri1").Build(); - DocumentProto document2 = - DocumentBuilder().SetKey("namespace2", "uri2").Build(); - int document1_offset, document2_offset; - int log_size; - - { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Write and persist the first proto - ICING_ASSERT_OK_AND_ASSIGN(document1_offset, - proto_log->WriteProto(document1)); - ICING_ASSERT_OK(proto_log->PersistToDisk()); - - // Write, but don't explicitly persist the second proto - 
ICING_ASSERT_OK_AND_ASSIGN(document2_offset, - proto_log->WriteProto(document2)); - - // Check that what we read is what we wrote - ASSERT_THAT(proto_log->ReadProto(document1_offset), - IsOkAndHolds(EqualsProto(document1))); - ASSERT_THAT(proto_log->ReadProto(document2_offset), - IsOkAndHolds(EqualsProto(document2))); - - log_size = filesystem_.GetFileSize(file_path_.c_str()); - ASSERT_GT(log_size, 0); - } - - { - // The header rewind position and checksum aren't updated in this "system - // crash" scenario. - - std::string bad_proto = - "some incomplete proto that we didn't finish writing before the system " - "crashed"; - filesystem_.PWrite(file_path_.c_str(), log_size, bad_proto.data(), - bad_proto.size()); - - // Double check that we actually wrote something to the underlying file - ASSERT_GT(filesystem_.GetFileSize(file_path_.c_str()), log_size); - } - - { - // We can recover, but we have data loss - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_TRUE(create_result.has_data_loss()); - ASSERT_THAT(create_result.data_loss, Eq(DataLoss::PARTIAL)); - - // Check that everything was persisted across instances - ASSERT_THAT(proto_log->ReadProto(document1_offset), - IsOkAndHolds(EqualsProto(document1))); - ASSERT_THAT(proto_log->ReadProto(document2_offset), - IsOkAndHolds(EqualsProto(document2))); - - // We correctly rewound to the last good state. 
- ASSERT_EQ(log_size, filesystem_.GetFileSize(file_path_.c_str())); - } -} - -TEST_F(FileBackedProtoLogTest, Iterator) { - DocumentProto document1 = - DocumentBuilder().SetKey("namespace", "uri1").Build(); - DocumentProto document2 = - DocumentBuilder().SetKey("namespace", "uri2").Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - { - // Empty iterator - auto iterator = proto_log->GetIterator(); - ASSERT_THAT(iterator.Advance(), - StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); - } - - { - // Iterates through some documents - ICING_ASSERT_OK(proto_log->WriteProto(document1)); - ICING_ASSERT_OK(proto_log->WriteProto(document2)); - auto iterator = proto_log->GetIterator(); - // 1st proto - ICING_ASSERT_OK(iterator.Advance()); - ASSERT_THAT(proto_log->ReadProto(iterator.GetOffset()), - IsOkAndHolds(EqualsProto(document1))); - // 2nd proto - ICING_ASSERT_OK(iterator.Advance()); - ASSERT_THAT(proto_log->ReadProto(iterator.GetOffset()), - IsOkAndHolds(EqualsProto(document2))); - // Tries to advance - ASSERT_THAT(iterator.Advance(), - StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); - } - - { - // Iterator with bad filesystem - MockFilesystem mock_filesystem; - ON_CALL(mock_filesystem, GetFileSize(A<const char *>())) - .WillByDefault(Return(Filesystem::kBadFileSize)); - FileBackedProtoLog<DocumentProto>::Iterator bad_iterator( - mock_filesystem, file_path_, /*initial_offset=*/0); - ASSERT_THAT(bad_iterator.Advance(), - StatusIs(libtextclassifier3::StatusCode::OUT_OF_RANGE)); - } -} - -TEST_F(FileBackedProtoLogTest, ComputeChecksum) { - DocumentProto document = DocumentBuilder().SetKey("namespace", "uri").Build(); - Crc32 checksum; - - { - 
ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - ICING_EXPECT_OK(proto_log->WriteProto(document)); - - ICING_ASSERT_OK_AND_ASSIGN(checksum, proto_log->ComputeChecksum()); - - // Calling it twice with no changes should get us the same checksum - EXPECT_THAT(proto_log->ComputeChecksum(), IsOkAndHolds(Eq(checksum))); - } - - { - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Checksum should be consistent across instances - EXPECT_THAT(proto_log->ComputeChecksum(), IsOkAndHolds(Eq(checksum))); - - // PersistToDisk shouldn't affect the checksum value - ICING_EXPECT_OK(proto_log->PersistToDisk()); - EXPECT_THAT(proto_log->ComputeChecksum(), IsOkAndHolds(Eq(checksum))); - - // Check that modifying the log leads to a different checksum - ICING_EXPECT_OK(proto_log->WriteProto(document)); - EXPECT_THAT(proto_log->ComputeChecksum(), IsOkAndHolds(Not(Eq(checksum)))); - } -} - -TEST_F(FileBackedProtoLogTest, EraseProtoShouldSetZero) { - DocumentProto document1 = - DocumentBuilder().SetKey("namespace", "uri1").Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Writes and 
erases proto - ICING_ASSERT_OK_AND_ASSIGN(int64_t document1_offset, - proto_log->WriteProto(document1)); - ICING_ASSERT_OK(proto_log->EraseProto(document1_offset)); - - // Checks if the erased area is set to 0. - int64_t file_size = filesystem_.GetFileSize(file_path_.c_str()); - MemoryMappedFile mmapped_file(filesystem_, file_path_, - MemoryMappedFile::Strategy::READ_ONLY); - - // document1_offset + sizeof(int) is the start byte of the proto where - // sizeof(int) is the size of the proto metadata. - mmapped_file.Remap(document1_offset + sizeof(int), file_size - 1); - for (size_t i = 0; i < mmapped_file.region_size(); ++i) { - ASSERT_THAT(mmapped_file.region()[i], Eq(0)); - } -} - -TEST_F(FileBackedProtoLogTest, EraseProtoShouldReturnNotFound) { - DocumentProto document1 = - DocumentBuilder().SetKey("namespace", "uri1").Build(); - DocumentProto document2 = - DocumentBuilder().SetKey("namespace", "uri2").Build(); - - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Writes 2 protos - ICING_ASSERT_OK_AND_ASSIGN(int64_t document1_offset, - proto_log->WriteProto(document1)); - ICING_ASSERT_OK_AND_ASSIGN(int64_t document2_offset, - proto_log->WriteProto(document2)); - - // Erases the first proto - ICING_ASSERT_OK(proto_log->EraseProto(document1_offset)); - - // The first proto has been erased. - ASSERT_THAT(proto_log->ReadProto(document1_offset), - StatusIs(libtextclassifier3::StatusCode::NOT_FOUND)); - // The second proto should be returned. 
- ASSERT_THAT(proto_log->ReadProto(document2_offset), - IsOkAndHolds(EqualsProto(document2))); -} - -TEST_F(FileBackedProtoLogTest, ChecksumShouldBeCorrectWithErasedProto) { - DocumentProto document1 = - DocumentBuilder().SetKey("namespace", "uri1").Build(); - DocumentProto document2 = - DocumentBuilder().SetKey("namespace", "uri2").Build(); - DocumentProto document3 = - DocumentBuilder().SetKey("namespace", "uri3").Build(); - DocumentProto document4 = - DocumentBuilder().SetKey("namespace", "uri4").Build(); - - int64_t document2_offset; - int64_t document3_offset; - - { - // Erase data after the rewind position. This won't update the checksum - // immediately. - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Writes 3 protos - ICING_ASSERT_OK_AND_ASSIGN(int64_t document1_offset, - proto_log->WriteProto(document1)); - ICING_ASSERT_OK_AND_ASSIGN(document2_offset, - proto_log->WriteProto(document2)); - ICING_ASSERT_OK_AND_ASSIGN(document3_offset, - proto_log->WriteProto(document3)); - - // Erases the 1st proto, checksum won't be updated immediately because the - // rewind position is 0. - ICING_ASSERT_OK(proto_log->EraseProto(document1_offset)); - - EXPECT_THAT(proto_log->ComputeChecksum(), - IsOkAndHolds(Eq(Crc32(2293202502)))); - } // New checksum is updated in destructor. - - { - // Erase data before the rewind position. This will update the checksum - // immediately. 
- ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Erases the 2nd proto that is now before the rewind position. Checksum is - // updated. - ICING_ASSERT_OK(proto_log->EraseProto(document2_offset)); - - EXPECT_THAT(proto_log->ComputeChecksum(), - IsOkAndHolds(Eq(Crc32(639634028)))); - } - - { - // Append data and erase data before the rewind position. This will update - // the checksum twice: in EraseProto() and destructor. - ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - ASSERT_FALSE(create_result.has_data_loss()); - - // Append a new document which is after the rewind position. - ICING_ASSERT_OK(proto_log->WriteProto(document4)); - - // Erases the 3rd proto that is now before the rewind position. Checksum is - // updated. - ICING_ASSERT_OK(proto_log->EraseProto(document3_offset)); - - EXPECT_THAT(proto_log->ComputeChecksum(), - IsOkAndHolds(Eq(Crc32(1990198693)))); - } // Checksum is updated with the newly appended document. - - { - // A successful creation means that the checksum matches. 
- ICING_ASSERT_OK_AND_ASSIGN( - FileBackedProtoLog<DocumentProto>::CreateResult create_result, - FileBackedProtoLog<DocumentProto>::Create( - &filesystem_, file_path_, - FileBackedProtoLog<DocumentProto>::Options(compress_, - max_proto_size_))); - auto proto_log = std::move(create_result.proto_log); - EXPECT_FALSE(create_result.has_data_loss()); - } -} - } // namespace } // namespace lib } // namespace icing diff --git a/icing/file/file-backed-vector.h b/icing/file/file-backed-vector.h index 0989935..00bdc7e 100644 --- a/icing/file/file-backed-vector.h +++ b/icing/file/file-backed-vector.h @@ -56,10 +56,9 @@ #ifndef ICING_FILE_FILE_BACKED_VECTOR_H_ #define ICING_FILE_FILE_BACKED_VECTOR_H_ -#include <inttypes.h> -#include <stdint.h> #include <sys/mman.h> +#include <cinttypes> #include <cstdint> #include <memory> #include <string> diff --git a/icing/file/file-backed-vector_test.cc b/icing/file/file-backed-vector_test.cc index b05ce2d..7c02af9 100644 --- a/icing/file/file-backed-vector_test.cc +++ b/icing/file/file-backed-vector_test.cc @@ -14,9 +14,8 @@ #include "icing/file/file-backed-vector.h" -#include <errno.h> - #include <algorithm> +#include <cerrno> #include <cstdint> #include <memory> #include <string_view> diff --git a/icing/file/filesystem.cc b/icing/file/filesystem.cc index 0655cb9..82b8d98 100644 --- a/icing/file/filesystem.cc +++ b/icing/file/filesystem.cc @@ -16,7 +16,6 @@ #include <dirent.h> #include <dlfcn.h> -#include <errno.h> #include <fcntl.h> #include <fnmatch.h> #include <pthread.h> @@ -26,6 +25,7 @@ #include <unistd.h> #include <algorithm> +#include <cerrno> #include <cstdint> #include <unordered_set> diff --git a/icing/file/filesystem.h b/icing/file/filesystem.h index 6bed8e6..ca8c4a8 100644 --- a/icing/file/filesystem.h +++ b/icing/file/filesystem.h @@ -17,11 +17,9 @@ #ifndef ICING_FILE_FILESYSTEM_H_ #define ICING_FILE_FILESYSTEM_H_ -#include <stdint.h> -#include <stdio.h> -#include <string.h> - #include <cstdint> +#include <cstdio> +#include 
<cstring> #include <memory> #include <string> #include <unordered_set> diff --git a/icing/file/portable-file-backed-proto-log.h b/icing/file/portable-file-backed-proto-log.h index 5284b6e..f676dc5 100644 --- a/icing/file/portable-file-backed-proto-log.h +++ b/icing/file/portable-file-backed-proto-log.h @@ -64,7 +64,6 @@ #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/statusor.h" -#include <google/protobuf/io/gzip_stream.h> #include <google/protobuf/io/zero_copy_stream_impl_lite.h> #include "icing/absl_ports/canonical_errors.h" #include "icing/absl_ports/str_cat.h" @@ -72,6 +71,7 @@ #include "icing/file/memory-mapped-file.h" #include "icing/legacy/core/icing-string-util.h" #include "icing/portable/endian.h" +#include "icing/portable/gzip_stream.h" #include "icing/portable/platform.h" #include "icing/portable/zlib.h" #include "icing/util/bit-util.h" @@ -141,42 +141,50 @@ class PortableFileBackedProtoLog { return crc.Get(); } - int32_t GetMagic() const { return gntohl(magic_nbytes_); } + int32_t GetMagic() const { return GNetworkToHostL(magic_nbytes_); } - void SetMagic(int32_t magic_in) { magic_nbytes_ = ghtonl(magic_in); } + void SetMagic(int32_t magic_in) { + magic_nbytes_ = GHostToNetworkL(magic_in); + } int32_t GetFileFormatVersion() const { - return gntohl(file_format_version_nbytes_); + return GNetworkToHostL(file_format_version_nbytes_); } void SetFileFormatVersion(int32_t file_format_version_in) { - file_format_version_nbytes_ = ghtonl(file_format_version_in); + file_format_version_nbytes_ = GHostToNetworkL(file_format_version_in); } - int32_t GetMaxProtoSize() const { return gntohl(max_proto_size_nbytes_); } + int32_t GetMaxProtoSize() const { + return GNetworkToHostL(max_proto_size_nbytes_); + } void SetMaxProtoSize(int32_t max_proto_size_in) { - max_proto_size_nbytes_ = ghtonl(max_proto_size_in); + max_proto_size_nbytes_ = GHostToNetworkL(max_proto_size_in); } - int32_t GetLogChecksum() const { 
return gntohl(log_checksum_nbytes_); } + int32_t GetLogChecksum() const { + return GNetworkToHostL(log_checksum_nbytes_); + } void SetLogChecksum(int32_t log_checksum_in) { - log_checksum_nbytes_ = ghtonl(log_checksum_in); + log_checksum_nbytes_ = GHostToNetworkL(log_checksum_in); } - int64_t GetRewindOffset() const { return gntohll(rewind_offset_nbytes_); } + int64_t GetRewindOffset() const { + return GNetworkToHostLL(rewind_offset_nbytes_); + } void SetRewindOffset(int64_t rewind_offset_in) { - rewind_offset_nbytes_ = ghtonll(rewind_offset_in); + rewind_offset_nbytes_ = GHostToNetworkLL(rewind_offset_in); } int32_t GetHeaderChecksum() const { - return gntohl(header_checksum_nbytes_); + return GNetworkToHostL(header_checksum_nbytes_); } void SetHeaderChecksum(int32_t header_checksum_in) { - header_checksum_nbytes_ = ghtonl(header_checksum_in); + header_checksum_nbytes_ = GHostToNetworkL(header_checksum_in); } bool GetCompressFlag() const { return GetFlag(kCompressBit); } @@ -209,7 +217,7 @@ class PortableFileBackedProtoLog { // Holds the magic as a quick sanity check against file corruption. // // Field is in network-byte order. - int32_t magic_nbytes_ = ghtonl(kMagic); + int32_t magic_nbytes_ = GHostToNetworkL(kMagic); // Must be at the beginning after kMagic. Contains the crc checksum of // the following fields. @@ -223,7 +231,7 @@ class PortableFileBackedProtoLog { // valid instead of throwing away the entire log. // // Field is in network-byte order. - int64_t rewind_offset_nbytes_ = ghtonll(kHeaderReservedBytes); + int64_t rewind_offset_nbytes_ = GHostToNetworkLL(kHeaderReservedBytes); // Version number tracking how we serialize the file to disk. 
If we change // how/what we write to disk, this version should be updated and this class @@ -568,9 +576,6 @@ class PortableFileBackedProtoLog { }; template <typename ProtoT> -constexpr uint8_t PortableFileBackedProtoLog<ProtoT>::kProtoMagic; - -template <typename ProtoT> PortableFileBackedProtoLog<ProtoT>::PortableFileBackedProtoLog( const Filesystem* filesystem, const std::string& file_path, std::unique_ptr<Header> header) @@ -725,7 +730,7 @@ PortableFileBackedProtoLog<ProtoT>::InitializeExistingFile( return absl_ports::InternalError(IcingStringUtil::StringPrintf( "Failed to truncate '%s' to size %lld", file_path.data(), static_cast<long long>(header->GetRewindOffset()))); - }; + } data_loss = DataLoss::PARTIAL; } @@ -881,12 +886,11 @@ PortableFileBackedProtoLog<ProtoT>::WriteProto(const ProtoT& proto) { google::protobuf::io::StringOutputStream proto_stream(&proto_str); if (header_->GetCompressFlag()) { - google::protobuf::io::GzipOutputStream::Options options; - options.format = google::protobuf::io::GzipOutputStream::ZLIB; + protobuf_ports::GzipOutputStream::Options options; + options.format = protobuf_ports::GzipOutputStream::ZLIB; options.compression_level = kDeflateCompressionLevel; - google::protobuf::io::GzipOutputStream compressing_stream(&proto_stream, - options); + protobuf_ports::GzipOutputStream compressing_stream(&proto_stream, options); bool success = proto.SerializeToZeroCopyStream(&compressing_stream) && compressing_stream.Close(); @@ -966,7 +970,7 @@ PortableFileBackedProtoLog<ProtoT>::ReadProto(int64_t file_offset) const { // Deserialize proto ProtoT proto; if (header_->GetCompressFlag()) { - google::protobuf::io::GzipInputStream decompress_stream(&proto_stream); + protobuf_ports::GzipInputStream decompress_stream(&proto_stream); proto.ParseFromZeroCopyStream(&decompress_stream); } else { proto.ParseFromZeroCopyStream(&proto_stream); @@ -1148,7 +1152,7 @@ PortableFileBackedProtoLog<ProtoT>::ReadProtoMetadata( memcpy(&portable_metadata, 
mmapped_file->region(), metadata_size); // Need to switch it back to host order endianness after reading from disk. - int32_t host_order_metadata = gntohl(portable_metadata); + int32_t host_order_metadata = GNetworkToHostL(portable_metadata); // Checks magic number uint8_t stored_k_proto_magic = GetProtoMagic(host_order_metadata); @@ -1166,7 +1170,7 @@ libtextclassifier3::Status PortableFileBackedProtoLog<ProtoT>::WriteProtoMetadata( const Filesystem* filesystem, int fd, int32_t host_order_metadata) { // Convert it into portable endian format before writing to disk - int32_t portable_metadata = ghtonl(host_order_metadata); + int32_t portable_metadata = GHostToNetworkL(host_order_metadata); int portable_metadata_size = sizeof(portable_metadata); // Write metadata diff --git a/icing/icing-search-engine.cc b/icing/icing-search-engine.cc index 1efad9b..9aa833b 100644 --- a/icing/icing-search-engine.cc +++ b/icing/icing-search-engine.cc @@ -35,6 +35,7 @@ #include "icing/index/index.h" #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/legacy/index/icing-filesystem.h" +#include "icing/portable/endian.h" #include "icing/proto/document.pb.h" #include "icing/proto/initialize.pb.h" #include "icing/proto/internal/optimize.pb.h" @@ -46,6 +47,7 @@ #include "icing/proto/search.pb.h" #include "icing/proto/status.pb.h" #include "icing/query/query-processor.h" +#include "icing/query/suggestion-processor.h" #include "icing/result/projection-tree.h" #include "icing/result/projector.h" #include "icing/result/result-retriever.h" @@ -77,8 +79,14 @@ constexpr std::string_view kDocumentSubfolderName = "document_dir"; constexpr std::string_view kIndexSubfolderName = "index_dir"; constexpr std::string_view kSchemaSubfolderName = "schema_dir"; constexpr std::string_view kSetSchemaMarkerFilename = "set_schema_marker"; +constexpr std::string_view kInitMarkerFilename = "init_marker"; constexpr std::string_view kOptimizeStatusFilename = "optimize_status"; +// The maximum 
number of unsuccessful initialization attempts from the current +// state that we will tolerate before deleting all data and starting from a +// fresh state. +constexpr int kMaxUnsuccessfulInitAttempts = 5; + libtextclassifier3::Status ValidateOptions( const IcingSearchEngineOptions& options) { // These options are only used in IndexProcessor, which won't be created @@ -127,14 +135,24 @@ libtextclassifier3::Status ValidateSearchSpec( return libtextclassifier3::Status::OK; } -IndexProcessor::Options CreateIndexProcessorOptions( - const IcingSearchEngineOptions& options) { - IndexProcessor::Options index_processor_options; - index_processor_options.max_tokens_per_document = - options.max_tokens_per_doc(); - index_processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kSuppressError; - return index_processor_options; +libtextclassifier3::Status ValidateSuggestionSpec( + const SuggestionSpecProto& suggestion_spec, + const PerformanceConfiguration& configuration) { + if (suggestion_spec.prefix().empty()) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("SuggestionSpecProto.prefix is empty!")); + } + if (suggestion_spec.num_to_return() <= 0) { + return absl_ports::InvalidArgumentError(absl_ports::StrCat( + "SuggestionSpecProto.num_to_return must be positive.")); + } + if (suggestion_spec.prefix().size() > configuration.max_query_length) { + return absl_ports::InvalidArgumentError( + absl_ports::StrCat("SuggestionSpecProto.prefix is longer than the " + "maximum allowed prefix length: ", + std::to_string(configuration.max_query_length))); + } + return libtextclassifier3::Status::OK; } // Document store files are in a standalone subfolder for easier file @@ -164,10 +182,15 @@ std::string MakeIndexDirectoryPath(const std::string& base_dir) { std::string MakeSchemaDirectoryPath(const std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kSchemaSubfolderName); } + std::string MakeSetSchemaMarkerFilePath(const 
std::string& base_dir) { return absl_ports::StrCat(base_dir, "/", kSetSchemaMarkerFilename); } +std::string MakeInitMarkerFilePath(const std::string& base_dir) { + return absl_ports::StrCat(base_dir, "/", kInitMarkerFilename); +} + void TransformStatus(const libtextclassifier3::Status& internal_status, StatusProto* status_proto) { StatusProto::Code code; @@ -276,6 +299,66 @@ InitializeResultProto IcingSearchEngine::Initialize() { return InternalInitialize(); } +void IcingSearchEngine::ResetMembers() { + schema_store_.reset(); + document_store_.reset(); + language_segmenter_.reset(); + normalizer_.reset(); + index_.reset(); +} + +libtextclassifier3::Status IcingSearchEngine::CheckInitMarkerFile( + InitializeStatsProto* initialize_stats) { + // Check to see if the marker file exists and if we've already passed our max + // number of init attempts. + std::string marker_filepath = MakeInitMarkerFilePath(options_.base_dir()); + bool file_exists = filesystem_->FileExists(marker_filepath.c_str()); + int network_init_attempts = 0; + int host_init_attempts = 0; + + // Read the number of previous failed init attempts from the file. If it + // fails, then just assume the value is zero (the most likely reason for + // failure would be non-existence because the last init was successful + // anyways). + ScopedFd marker_file_fd(filesystem_->OpenForWrite(marker_filepath.c_str())); + libtextclassifier3::Status status; + if (file_exists && + filesystem_->PRead(marker_file_fd.get(), &network_init_attempts, + sizeof(network_init_attempts), /*offset=*/0)) { + host_init_attempts = GNetworkToHostL(network_init_attempts); + if (host_init_attempts > kMaxUnsuccessfulInitAttempts) { + // We're tried and failed to init too many times. We need to throw + // everything out and start from scratch. 
+ ResetMembers(); + if (!filesystem_->DeleteDirectoryRecursively( + options_.base_dir().c_str())) { + return absl_ports::InternalError("Failed to delete icing base dir!"); + } + status = absl_ports::DataLossError( + "Encountered failed initialization limit. Cleared all data."); + host_init_attempts = 0; + } + } + + // Use network_init_attempts here because we might have set host_init_attempts + // to 0 if it exceeded the max threshold. + initialize_stats->set_num_previous_init_failures( + GNetworkToHostL(network_init_attempts)); + + ++host_init_attempts; + network_init_attempts = GHostToNetworkL(host_init_attempts); + // Write the updated number of attempts before we get started. + if (!filesystem_->PWrite(marker_file_fd.get(), /*offset=*/0, + &network_init_attempts, + sizeof(network_init_attempts)) || + !filesystem_->DataSync(marker_file_fd.get())) { + return absl_ports::InternalError( + "Failed to write and sync init marker file"); + } + + return status; +} + InitializeResultProto IcingSearchEngine::InternalInitialize() { ICING_VLOG(1) << "Initializing IcingSearchEngine in dir: " << options_.base_dir(); @@ -296,9 +379,17 @@ InitializeResultProto IcingSearchEngine::InternalInitialize() { return result_proto; } + // Now go ahead and try to initialize. libtextclassifier3::Status status = InitializeMembers(initialize_stats); if (status.ok() || absl_ports::IsDataLoss(status)) { - initialized_ = true; + // We successfully initialized. We should delete the init marker file to + // indicate a successful init. 
+ std::string marker_filepath = MakeInitMarkerFilePath(options_.base_dir()); + if (!filesystem_->DeleteFile(marker_filepath.c_str())) { + status = absl_ports::InternalError("Failed to delete init marker file!"); + } else { + initialized_ = true; + } } TransformStatus(status, result_status); initialize_stats->set_latency_ms(initialize_timer->GetElapsedMilliseconds()); @@ -308,7 +399,21 @@ InitializeResultProto IcingSearchEngine::InternalInitialize() { libtextclassifier3::Status IcingSearchEngine::InitializeMembers( InitializeStatsProto* initialize_stats) { ICING_RETURN_ERROR_IF_NULL(initialize_stats); - ICING_RETURN_IF_ERROR(InitializeOptions()); + ICING_RETURN_IF_ERROR(ValidateOptions(options_)); + + // Make sure the base directory exists + if (!filesystem_->CreateDirectoryRecursively(options_.base_dir().c_str())) { + return absl_ports::InternalError(absl_ports::StrCat( + "Could not create directory: ", options_.base_dir())); + } + + // Check to see if the marker file exists and if we've already passed our max + // number of init attempts. + libtextclassifier3::Status status = CheckInitMarkerFile(initialize_stats); + if (!status.ok() && !absl_ports::IsDataLoss(status)) { + return status; + } + ICING_RETURN_IF_ERROR(InitializeSchemaStore(initialize_stats)); // TODO(b/156383798) : Resolve how to specify the locale. @@ -322,7 +427,7 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( std::string marker_filepath = MakeSetSchemaMarkerFilePath(options_.base_dir()); - libtextclassifier3::Status status; + libtextclassifier3::Status index_init_status; if (absl_ports::IsNotFound(schema_store_->GetSchema().status())) { // The schema was either lost or never set before. Wipe out the doc store // and index directories and initialize them from scratch. 
@@ -336,7 +441,10 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( } ICING_RETURN_IF_ERROR(InitializeDocumentStore( /*force_recovery_and_revalidate_documents=*/false, initialize_stats)); - status = InitializeIndex(initialize_stats); + index_init_status = InitializeIndex(initialize_stats); + if (!index_init_status.ok() && !absl_ports::IsDataLoss(index_init_status)) { + return index_init_status; + } } else if (filesystem_->FileExists(marker_filepath.c_str())) { // If the marker file is still around then something wonky happened when we // last tried to set the schema. @@ -360,12 +468,12 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( std::unique_ptr<Timer> restore_timer = clock_->GetNewTimer(); IndexRestorationResult restore_result = RestoreIndexIfNeeded(); - status = std::move(restore_result.status); + index_init_status = std::move(restore_result.status); // DATA_LOSS means that we have successfully initialized and re-added // content to the index. Some indexed content was lost, but otherwise the // index is in a valid state and can be queried. 
- if (!status.ok() && !absl_ports::IsDataLoss(status)) { - return status; + if (!index_init_status.ok() && !absl_ports::IsDataLoss(index_init_status)) { + return index_init_status; } // Delete the marker file to indicate that everything is now in sync with @@ -379,30 +487,22 @@ libtextclassifier3::Status IcingSearchEngine::InitializeMembers( } else { ICING_RETURN_IF_ERROR(InitializeDocumentStore( /*force_recovery_and_revalidate_documents=*/false, initialize_stats)); - status = InitializeIndex(initialize_stats); - if (!status.ok() && !absl_ports::IsDataLoss(status)) { - return status; + index_init_status = InitializeIndex(initialize_stats); + if (!index_init_status.ok() && !absl_ports::IsDataLoss(index_init_status)) { + return index_init_status; } } + if (status.ok()) { + status = index_init_status; + } + result_state_manager_ = std::make_unique<ResultStateManager>( performance_configuration_.max_num_total_hits, *document_store_); return status; } -libtextclassifier3::Status IcingSearchEngine::InitializeOptions() { - ICING_RETURN_IF_ERROR(ValidateOptions(options_)); - - // Make sure the base directory exists - if (!filesystem_->CreateDirectoryRecursively(options_.base_dir().c_str())) { - return absl_ports::InternalError(absl_ports::StrCat( - "Could not create directory: ", options_.base_dir())); - } - - return libtextclassifier3::Status::OK; -} - libtextclassifier3::Status IcingSearchEngine::InitializeSchemaStore( InitializeStatsProto* initialize_stats) { ICING_RETURN_ERROR_IF_NULL(initialize_stats); @@ -710,9 +810,8 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { } DocumentId document_id = document_id_or.ValueOrDie(); - auto index_processor_or = IndexProcessor::Create( - normalizer_.get(), index_.get(), CreateIndexProcessorOptions(options_), - clock_.get()); + auto index_processor_or = + IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { TransformStatus(index_processor_or.status(), 
result_status); put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); @@ -723,6 +822,17 @@ PutResultProto IcingSearchEngine::Put(DocumentProto&& document) { auto status = index_processor->IndexDocument(tokenized_document, document_id, put_document_stats); + if (!status.ok()) { + // If we encountered a failure while indexing this document, then mark it as + // deleted. + libtextclassifier3::Status delete_status = + document_store_->Delete(document_id); + if (!delete_status.ok()) { + // This is pretty dire (and, hopefully, unlikely). We can't roll back the + // document that we just added. Wipeout the whole index. + ResetInternal(); + } + } TransformStatus(status, result_status); put_document_stats->set_latency_ms(put_timer->GetElapsedMilliseconds()); @@ -1308,8 +1418,8 @@ SearchResultProto IcingSearchEngine::Search( component_timer = clock_->GetNewTimer(); // Scores but does not rank the results. libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> - scoring_processor_or = - ScoringProcessor::Create(scoring_spec, document_store_.get()); + scoring_processor_or = ScoringProcessor::Create( + scoring_spec, document_store_.get(), schema_store_.get()); if (!scoring_processor_or.ok()) { TransformStatus(scoring_processor_or.status(), result_status); return result_proto; @@ -1620,9 +1730,8 @@ IcingSearchEngine::RestoreIndexIfNeeded() { return {libtextclassifier3::Status::OK, false}; } - auto index_processor_or = IndexProcessor::Create( - normalizer_.get(), index_.get(), CreateIndexProcessorOptions(options_), - clock_.get()); + auto index_processor_or = + IndexProcessor::Create(normalizer_.get(), index_.get(), clock_.get()); if (!index_processor_or.ok()) { return {index_processor_or.status(), true}; } @@ -1700,22 +1809,18 @@ libtextclassifier3::StatusOr<bool> IcingSearchEngine::LostPreviousSchema() { } ResetResultProto IcingSearchEngine::Reset() { + absl_ports::unique_lock l(&mutex_); + return ResetInternal(); +} + +ResetResultProto 
IcingSearchEngine::ResetInternal() { ICING_VLOG(1) << "Resetting IcingSearchEngine"; ResetResultProto result_proto; StatusProto* result_status = result_proto.mutable_status(); - absl_ports::unique_lock l(&mutex_); - initialized_ = false; - - // Resets members variables - schema_store_.reset(); - document_store_.reset(); - language_segmenter_.reset(); - normalizer_.reset(); - index_.reset(); - + ResetMembers(); if (!filesystem_->DeleteDirectoryRecursively(options_.base_dir().c_str())) { result_status->set_code(StatusProto::INTERNAL); return result_proto; @@ -1741,5 +1846,62 @@ ResetResultProto IcingSearchEngine::Reset() { return result_proto; } +SuggestionResponse IcingSearchEngine::SearchSuggestions( + const SuggestionSpecProto& suggestion_spec) { + // TODO(b/146008613) Explore ideas to make this function read-only. + absl_ports::unique_lock l(&mutex_); + SuggestionResponse response; + StatusProto* response_status = response.mutable_status(); + if (!initialized_) { + response_status->set_code(StatusProto::FAILED_PRECONDITION); + response_status->set_message("IcingSearchEngine has not been initialized!"); + return response; + } + + libtextclassifier3::Status status = + ValidateSuggestionSpec(suggestion_spec, performance_configuration_); + if (!status.ok()) { + TransformStatus(status, response_status); + return response; + } + + // Create the suggestion processor. 
+ auto suggestion_processor_or = SuggestionProcessor::Create( + index_.get(), language_segmenter_.get(), normalizer_.get()); + if (!suggestion_processor_or.ok()) { + TransformStatus(suggestion_processor_or.status(), response_status); + return response; + } + std::unique_ptr<SuggestionProcessor> suggestion_processor = + std::move(suggestion_processor_or).ValueOrDie(); + + std::vector<NamespaceId> namespace_ids; + namespace_ids.reserve(suggestion_spec.namespace_filters_size()); + for (std::string_view name_space : suggestion_spec.namespace_filters()) { + auto namespace_id_or = document_store_->GetNamespaceId(name_space); + if (!namespace_id_or.ok()) { + continue; + } + namespace_ids.push_back(namespace_id_or.ValueOrDie()); + } + + // Run suggestion based on given SuggestionSpec. + libtextclassifier3::StatusOr<std::vector<TermMetadata>> terms_or = + suggestion_processor->QuerySuggestions(suggestion_spec, namespace_ids); + if (!terms_or.ok()) { + TransformStatus(terms_or.status(), response_status); + return response; + } + + // Convert vector<TermMetaData> into final SuggestionResponse proto. + for (TermMetadata& term : terms_or.ValueOrDie()) { + SuggestionResponse::Suggestion suggestion; + suggestion.set_query(std::move(term.content)); + response.mutable_suggestions()->Add(std::move(suggestion)); + } + response_status->set_code(StatusProto::OK); + return response; +} + } // namespace lib } // namespace icing diff --git a/icing/icing-search-engine.h b/icing/icing-search-engine.h index 855401f..0a79714 100644 --- a/icing/icing-search-engine.h +++ b/icing/icing-search-engine.h @@ -302,6 +302,17 @@ class IcingSearchEngine { const ResultSpecProto& result_spec) ICING_LOCKS_EXCLUDED(mutex_); + // Retrieves, scores, ranks and returns the suggested query string according + // to the specs. Results can be empty. 
+ // + // Returns a SuggestionResponse with status: + // OK with results on success + // INVALID_ARGUMENT if any of specs is invalid + // FAILED_PRECONDITION IcingSearchEngine has not been initialized yet + // INTERNAL_ERROR on any other errors + SuggestionResponse SearchSuggestions( + const SuggestionSpecProto& suggestion_spec) ICING_LOCKS_EXCLUDED(mutex_); + // Fetches the next page of results of a previously executed query. Results // can be empty if next-page token is invalid. Invalid next page tokens are // tokens that are either zero or were previously passed to @@ -452,6 +463,25 @@ class IcingSearchEngine { // Pointer to JNI class references const std::unique_ptr<const JniCache> jni_cache_; + // Resets all members that are created during Initialize. + void ResetMembers() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Resets all members that are created during Initialize, deletes all + // underlying files and initializes a fresh index. + ResetResultProto ResetInternal() ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Checks for the existence of the init marker file. If the failed init count + // exceeds kMaxUnsuccessfulInitAttempts, all data is deleted and the index is + // initialized from scratch. The updated count (original failed init count + 1 + // ) is written to the marker file. + // + // RETURNS + // OK on success + // INTERNAL if an IO error occurs while trying to update the marker file. + libtextclassifier3::Status CheckInitMarkerFile( + InitializeStatsProto* initialize_stats) + ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); + // Helper method to do the actual work to persist data to disk. We need this // separate method so that other public methods don't need to call // PersistToDisk(). 
Public methods calling each other may cause deadlock @@ -477,15 +507,6 @@ class IcingSearchEngine { InitializeStatsProto* initialize_stats) ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Do any validation/setup required for the given IcingSearchEngineOptions - // - // Returns: - // OK on success - // INVALID_ARGUMENT if options has invalid values - // INTERNAL on I/O error - libtextclassifier3::Status InitializeOptions() - ICING_EXCLUSIVE_LOCKS_REQUIRED(mutex_); - // Do any initialization/recovery necessary to create a SchemaStore instance. // // Returns: diff --git a/icing/icing-search-engine_benchmark.cc b/icing/icing-search-engine_benchmark.cc index ba9aed1..5e610d5 100644 --- a/icing/icing-search-engine_benchmark.cc +++ b/icing/icing-search-engine_benchmark.cc @@ -43,7 +43,6 @@ #include "icing/testing/common-matchers.h" #include "icing/testing/document-generator.h" #include "icing/testing/random-string.h" -#include "icing/testing/recorder-test-utils.h" #include "icing/testing/schema-generator.h" #include "icing/testing/tmp-directory.h" @@ -178,12 +177,12 @@ class DestructibleDirectory { }; std::vector<DocumentProto> GenerateRandomDocuments( - EvenDistributionTypeSelector* type_selector, int num_docs) { + EvenDistributionTypeSelector* type_selector, int num_docs, + const std::vector<std::string>& language) { std::vector<std::string> namespaces = CreateNamespaces(kAvgNumNamespaces); EvenDistributionNamespaceSelector namespace_selector(namespaces); std::default_random_engine random; - std::vector<std::string> language = CreateLanguages(kLanguageSize, &random); UniformDistributionLanguageTokenGenerator<std::default_random_engine> token_generator(language, &random); @@ -227,8 +226,9 @@ void BM_IndexLatency(benchmark::State& state) { ASSERT_THAT(icing->SetSchema(schema).status(), ProtoIsOk()); int num_docs = state.range(0); + std::vector<std::string> language = CreateLanguages(kLanguageSize, &random); const std::vector<DocumentProto> random_docs = - 
GenerateRandomDocuments(&type_selector, num_docs); + GenerateRandomDocuments(&type_selector, num_docs, language); Timer timer; for (const DocumentProto& doc : random_docs) { ASSERT_THAT(icing->Put(doc).status(), ProtoIsOk()); @@ -271,6 +271,56 @@ BENCHMARK(BM_IndexLatency) ->ArgPair(1 << 15, 10) ->ArgPair(1 << 17, 10); +void BM_QueryLatency(benchmark::State& state) { + // Initialize the filesystem + std::string test_dir = GetTestTempDir() + "/icing/benchmark"; + Filesystem filesystem; + DestructibleDirectory ddir(filesystem, test_dir); + + // Create the schema. + std::default_random_engine random; + int num_types = kAvgNumNamespaces * kAvgNumTypes; + ExactStringPropertyGenerator property_generator; + SchemaGenerator<ExactStringPropertyGenerator> schema_generator( + /*num_properties=*/state.range(1), &property_generator); + SchemaProto schema = schema_generator.GenerateSchema(num_types); + EvenDistributionTypeSelector type_selector(schema); + + // Create the index. + IcingSearchEngineOptions options; + options.set_base_dir(test_dir); + options.set_index_merge_size(kIcingFullIndexSize); + std::unique_ptr<IcingSearchEngine> icing = + std::make_unique<IcingSearchEngine>(options); + + ASSERT_THAT(icing->Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing->SetSchema(schema).status(), ProtoIsOk()); + + int num_docs = state.range(0); + std::vector<std::string> language = CreateLanguages(kLanguageSize, &random); + const std::vector<DocumentProto> random_docs = + GenerateRandomDocuments(&type_selector, num_docs, language); + for (const DocumentProto& doc : random_docs) { + ASSERT_THAT(icing->Put(doc).status(), ProtoIsOk()); + } + + SearchSpecProto search_spec = CreateSearchSpec( + language.at(0), std::vector<std::string>(), TermMatchType::PREFIX); + ResultSpecProto result_spec = CreateResultSpec(1000000, 1000000, 1000000); + ScoringSpecProto scoring_spec = + CreateScoringSpec(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); + for (auto _ : state) { + 
SearchResultProto results = icing->Search( + search_spec, ScoringSpecProto::default_instance(), result_spec); + } +} +BENCHMARK(BM_QueryLatency) + // Arguments: num_indexed_documents, num_sections + ->ArgPair(32, 2) + ->ArgPair(128, 2) + ->ArgPair(1 << 10, 2) + ->ArgPair(1 << 13, 2); + void BM_IndexThroughput(benchmark::State& state) { // Initialize the filesystem std::string test_dir = GetTestTempDir() + "/icing/benchmark"; @@ -297,8 +347,9 @@ void BM_IndexThroughput(benchmark::State& state) { ASSERT_THAT(icing->SetSchema(schema).status(), ProtoIsOk()); int num_docs = state.range(0); + std::vector<std::string> language = CreateLanguages(kLanguageSize, &random); const std::vector<DocumentProto> random_docs = - GenerateRandomDocuments(&type_selector, num_docs); + GenerateRandomDocuments(&type_selector, num_docs, language); for (auto s : state) { for (const DocumentProto& doc : random_docs) { ASSERT_THAT(icing->Put(doc).status(), ProtoIsOk()); diff --git a/icing/icing-search-engine_test.cc b/icing/icing-search-engine_test.cc index ef4625a..b5206cd 100644 --- a/icing/icing-search-engine_test.cc +++ b/icing/icing-search-engine_test.cc @@ -29,6 +29,7 @@ #include "icing/file/mock-filesystem.h" #include "icing/helpers/icu/icu-data-file-helper.h" #include "icing/legacy/index/icing-mock-filesystem.h" +#include "icing/portable/endian.h" #include "icing/portable/equals-proto.h" #include "icing/portable/platform.h" #include "icing/proto/document.pb.h" @@ -507,6 +508,217 @@ TEST_F(IcingSearchEngineTest, FailToCreateDocStore) { HasSubstr("Could not create directory")); } +TEST_F(IcingSearchEngineTest, InitMarkerFilePreviousFailuresAtThreshold) { + Filesystem filesystem; + DocumentProto email1 = + CreateEmailDocument("namespace", "uri1", 100, "subject1", "body1"); + email1.set_creation_timestamp_ms(10000); + DocumentProto email2 = + CreateEmailDocument("namespace", "uri2", 50, "subject2", "body2"); + email2.set_creation_timestamp_ms(10000); + + { + // Create an index with a few 
documents. + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoIsOk()); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(0)); + ASSERT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); + } + + // Write an init marker file with 5 previously failed attempts. + std::string marker_filepath = GetTestBaseDir() + "/init_marker"; + + { + ScopedFd marker_file_fd(filesystem.OpenForWrite(marker_filepath.c_str())); + int network_init_attempts = GHostToNetworkL(5); + // Write the updated number of attempts before we get started. + ASSERT_TRUE(filesystem.PWrite(marker_file_fd.get(), 0, + &network_init_attempts, + sizeof(network_init_attempts))); + ASSERT_TRUE(filesystem.DataSync(marker_file_fd.get())); + } + + { + // Create the index again and verify that initialization succeeds and no + // data is thrown out. + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoIsOk()); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(5)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()) + .document(), + EqualsProto(email1)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .document(), + EqualsProto(email2)); + } + + // The successful init should have thrown out the marker file. 
+ ASSERT_FALSE(filesystem.FileExists(marker_filepath.c_str())); +} + +TEST_F(IcingSearchEngineTest, InitMarkerFilePreviousFailuresBeyondThreshold) { + Filesystem filesystem; + DocumentProto email1 = + CreateEmailDocument("namespace", "uri1", 100, "subject1", "body1"); + DocumentProto email2 = + CreateEmailDocument("namespace", "uri2", 50, "subject2", "body2"); + + { + // Create an index with a few documents. + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoIsOk()); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(0)); + ASSERT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); + } + + // Write an init marker file with 6 previously failed attempts. + std::string marker_filepath = GetTestBaseDir() + "/init_marker"; + + { + ScopedFd marker_file_fd(filesystem.OpenForWrite(marker_filepath.c_str())); + int network_init_attempts = GHostToNetworkL(6); + // Write the updated number of attempts before we get started. + ASSERT_TRUE(filesystem.PWrite(marker_file_fd.get(), 0, + &network_init_attempts, + sizeof(network_init_attempts))); + ASSERT_TRUE(filesystem.DataSync(marker_file_fd.get())); + } + + { + // Create the index again and verify that initialization succeeds and all + // data is thrown out. 
+ IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), + ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(6)); + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + } + + // The successful init should have thrown out the marker file. + ASSERT_FALSE(filesystem.FileExists(marker_filepath.c_str())); +} + +TEST_F(IcingSearchEngineTest, SuccessiveInitFailuresIncrementsInitMarker) { + Filesystem filesystem; + DocumentProto email1 = + CreateEmailDocument("namespace", "uri1", 100, "subject1", "body1"); + DocumentProto email2 = + CreateEmailDocument("namespace", "uri2", 50, "subject2", "body2"); + + { + // 1. Create an index with a few documents. + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoIsOk()); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(0)); + ASSERT_THAT(icing.SetSchema(CreateEmailSchema()).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(email2).status(), ProtoIsOk()); + } + + { + // 2. Create an index that will encounter an IO failure when trying to + // create the document log. 
+ IcingSearchEngineOptions icing_options = GetDefaultIcingOptions(); + + auto mock_filesystem = std::make_unique<MockFilesystem>(); + std::string document_log_filepath = + icing_options.base_dir() + "/document_dir/document_log_v1"; + auto get_filesize_lambda = [this, + &document_log_filepath](const char* filename) { + if (strncmp(document_log_filepath.c_str(), filename, + document_log_filepath.length()) == 0) { + return Filesystem::kBadFileSize; + } + return this->filesystem()->GetFileSize(filename); + }; + ON_CALL(*mock_filesystem, GetFileSize(A<const char*>())) + .WillByDefault(get_filesize_lambda); + + TestIcingSearchEngine icing(icing_options, std::move(mock_filesystem), + std::make_unique<IcingFilesystem>(), + std::make_unique<FakeClock>(), + GetTestJniCache()); + + // Fail to initialize six times in a row. + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(0)); + + init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(1)); + + init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(2)); + + init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(3)); + + init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(4)); + + init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), ProtoStatusIs(StatusProto::INTERNAL)); + 
ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(5)); + } + + { + // 3. Create the index again and verify that initialization succeeds and all + // data is thrown out. + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + InitializeResultProto init_result = icing.Initialize(); + ASSERT_THAT(init_result.status(), + ProtoStatusIs(StatusProto::WARNING_DATA_LOSS)); + ASSERT_THAT(init_result.initialize_stats().num_previous_init_failures(), + Eq(6)); + + EXPECT_THAT( + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + EXPECT_THAT( + icing.Get("namespace", "uri2", GetResultSpecProto::default_instance()) + .status(), + ProtoStatusIs(StatusProto::NOT_FOUND)); + } + + // The successful init should have thrown out the marker file. + std::string marker_filepath = GetTestBaseDir() + "/init_marker"; + ASSERT_FALSE(filesystem.FileExists(marker_filepath.c_str())); +} + TEST_F(IcingSearchEngineTest, CircularReferenceCreateSectionManagerReturnsInvalidArgument) { // Create a type config with a circular reference. @@ -7227,10 +7439,6 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexingStats) { // No merge should happen. EXPECT_THAT(put_result_proto.put_document_stats().index_merge_latency_ms(), Eq(0)); - // Number of tokens should not exceed. - EXPECT_FALSE(put_result_proto.put_document_stats() - .tokenization_stats() - .exceeded_max_token_num()); // The input document has 2 tokens. EXPECT_THAT(put_result_proto.put_document_stats() .tokenization_stats() @@ -7238,33 +7446,6 @@ TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexingStats) { Eq(2)); } -TEST_F(IcingSearchEngineTest, PutDocumentShouldLogWhetherNumTokensExceeds) { - // Create a document with 2 tokens. 
- DocumentProto document = DocumentBuilder() - .SetKey("icing", "fake_type/0") - .SetSchema("Message") - .AddStringProperty("body", "message body") - .Build(); - - // Create an icing instance with max_tokens_per_doc = 1. - IcingSearchEngineOptions icing_options = GetDefaultIcingOptions(); - icing_options.set_max_tokens_per_doc(1); - IcingSearchEngine icing(icing_options, GetTestJniCache()); - ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); - ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); - - PutResultProto put_result_proto = icing.Put(document); - EXPECT_THAT(put_result_proto.status(), ProtoIsOk()); - // Number of tokens(2) exceeds the max allowed value(1). - EXPECT_TRUE(put_result_proto.put_document_stats() - .tokenization_stats() - .exceeded_max_token_num()); - EXPECT_THAT(put_result_proto.put_document_stats() - .tokenization_stats() - .num_tokens_indexed(), - Eq(1)); -} - TEST_F(IcingSearchEngineTest, PutDocumentShouldLogIndexMergeLatency) { DocumentProto document1 = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -7832,6 +8013,147 @@ TEST_F(IcingSearchEngineTest, CJKSnippetTest) { EXPECT_THAT(match_proto.exact_match_utf16_length(), Eq(2)); } +TEST_F(IcingSearchEngineTest, PutDocumentIndexFailureDeletion) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreateMessageSchema()).status(), ProtoIsOk()); + + // Testing has shown that adding ~600,000 terms generated this way will + // fill up the hit buffer. + std::vector<std::string> terms = GenerateUniqueTerms(600000); + std::string content = absl_ports::StrJoin(terms, " "); + DocumentProto document = DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Message") + .AddStringProperty("body", "foo " + content) + .Build(); + // We failed to add the document to the index fully. This means that we should + // reject the document from Icing entirely. 
+ ASSERT_THAT(icing.Put(document).status(), + ProtoStatusIs(StatusProto::OUT_OF_SPACE)); + + // Make sure that the document isn't searchable. + SearchSpecProto search_spec; + search_spec.set_query("foo"); + search_spec.set_term_match_type(MATCH_PREFIX); + + SearchResultProto search_results = + icing.Search(search_spec, ScoringSpecProto::default_instance(), + ResultSpecProto::default_instance()); + ASSERT_THAT(search_results.status(), ProtoIsOk()); + ASSERT_THAT(search_results.results(), IsEmpty()); + + // Make sure that the document isn't retrievable. + GetResultProto get_result = + icing.Get("namespace", "uri1", GetResultSpecProto::default_instance()); + ASSERT_THAT(get_result.status(), ProtoStatusIs(StatusProto::NOT_FOUND)); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + ASSERT_THAT(icing.SetSchema(CreatePersonAndEmailSchema()).status(), + ProtoIsOk()); + + // Creates and inserts 6 documents, and index 6 termSix, 5 termFive, 4 + // termFour, 3 termThree, 2 termTwo and one termOne. 
+ DocumentProto document1 = + DocumentBuilder() + .SetKey("namespace", "uri1") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty( + "subject", "termOne termTwo termThree termFour termFive termSix") + .Build(); + DocumentProto document2 = + DocumentBuilder() + .SetKey("namespace", "uri2") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", + "termTwo termThree termFour termFive termSix") + .Build(); + DocumentProto document3 = + DocumentBuilder() + .SetKey("namespace", "uri3") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termThree termFour termFive termSix") + .Build(); + DocumentProto document4 = + DocumentBuilder() + .SetKey("namespace", "uri4") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termFour termFive termSix") + .Build(); + DocumentProto document5 = + DocumentBuilder() + .SetKey("namespace", "uri5") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termFive termSix") + .Build(); + DocumentProto document6 = DocumentBuilder() + .SetKey("namespace", "uri6") + .SetSchema("Email") + .SetCreationTimestampMs(10) + .AddStringProperty("subject", "termSix") + .Build(); + ASSERT_THAT(icing.Put(document1).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document2).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document3).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document4).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document5).status(), ProtoIsOk()); + ASSERT_THAT(icing.Put(document6).status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("t"); + suggestion_spec.set_num_to_return(10); + + // Query all suggestions, and they will be ranked. 
+ SuggestionResponse response = icing.SearchSuggestions(suggestion_spec); + ASSERT_THAT(response.status(), ProtoIsOk()); + ASSERT_THAT(response.suggestions().at(0).query(), "termsix"); + ASSERT_THAT(response.suggestions().at(1).query(), "termfive"); + ASSERT_THAT(response.suggestions().at(2).query(), "termfour"); + ASSERT_THAT(response.suggestions().at(3).query(), "termthree"); + ASSERT_THAT(response.suggestions().at(4).query(), "termtwo"); + ASSERT_THAT(response.suggestions().at(5).query(), "termone"); + + // Query first three suggestions, and they will be ranked. + suggestion_spec.set_num_to_return(3); + response = icing.SearchSuggestions(suggestion_spec); + ASSERT_THAT(response.status(), ProtoIsOk()); + ASSERT_THAT(response.suggestions().at(0).query(), "termsix"); + ASSERT_THAT(response.suggestions().at(1).query(), "termfive"); + ASSERT_THAT(response.suggestions().at(2).query(), "termfour"); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_emptyPrefix) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix(""); + suggestion_spec.set_num_to_return(10); + + ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + +TEST_F(IcingSearchEngineTest, SearchSuggestionsTest_NonPositiveNumToReturn) { + IcingSearchEngine icing(GetDefaultIcingOptions(), GetTestJniCache()); + ASSERT_THAT(icing.Initialize().status(), ProtoIsOk()); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("prefix"); + suggestion_spec.set_num_to_return(0); + + ASSERT_THAT(icing.SearchSuggestions(suggestion_spec).status(), + ProtoStatusIs(StatusProto::INVALID_ARGUMENT)); +} + #ifndef ICING_JNI_TEST // We skip this test case when we're running in a jni_test since the data files // will be stored in the android-instrumented storage location, rather than the diff --git 
a/icing/index/index-processor.cc b/icing/index/index-processor.cc index 6d8632f..1aae732 100644 --- a/icing/index/index-processor.cc +++ b/icing/index/index-processor.cc @@ -43,14 +43,13 @@ namespace lib { libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> IndexProcessor::Create(const Normalizer* normalizer, Index* index, - const IndexProcessor::Options& options, const Clock* clock) { ICING_RETURN_ERROR_IF_NULL(normalizer); ICING_RETURN_ERROR_IF_NULL(index); ICING_RETURN_ERROR_IF_NULL(clock); return std::unique_ptr<IndexProcessor>( - new IndexProcessor(normalizer, index, options, clock)); + new IndexProcessor(normalizer, index, clock)); } libtextclassifier3::Status IndexProcessor::IndexDocument( @@ -66,53 +65,34 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( } index_->set_last_added_document_id(document_id); uint32_t num_tokens = 0; - libtextclassifier3::Status overall_status; + libtextclassifier3::Status status; for (const TokenizedSection& section : tokenized_document.sections()) { // TODO(b/152934343): pass real namespace ids in Index::Editor editor = index_->Edit(document_id, section.metadata.id, section.metadata.term_match_type, /*namespace_id=*/0); for (std::string_view token : section.token_sequence) { - if (++num_tokens > options_.max_tokens_per_document) { - // Index all tokens buffered so far. - editor.IndexAllBufferedTerms(); - if (put_document_stats != nullptr) { - put_document_stats->mutable_tokenization_stats() - ->set_exceeded_max_token_num(true); - put_document_stats->mutable_tokenization_stats() - ->set_num_tokens_indexed(options_.max_tokens_per_document); - } - switch (options_.token_limit_behavior) { - case Options::TokenLimitBehavior::kReturnError: - return absl_ports::ResourceExhaustedError( - "Max number of tokens reached!"); - case Options::TokenLimitBehavior::kSuppressError: - return overall_status; - } - } + ++num_tokens; std::string term = normalizer_.NormalizeTerm(token); - // Add this term to Hit buffer. 
Even if adding this hit fails, we keep - // trying to add more hits because it's possible that future hits could - // still be added successfully. For instance if the lexicon is full, we - // might fail to add a hit for a new term, but should still be able to - // add hits for terms that are already in the index. - auto status = editor.BufferTerm(term.c_str()); - if (overall_status.ok() && !status.ok()) { - // If we've succeeded to add everything so far, set overall_status to - // represent this new failure. If we've already failed, no need to - // update the status - we're already going to return a resource - // exhausted error. - overall_status = status; + // Add this term to Hit buffer. + status = editor.BufferTerm(term.c_str()); + if (!status.ok()) { + // We've encountered a failure. Bail out. We'll mark this doc as deleted + // and signal a failure to the client. + ICING_LOG(WARNING) << "Failed to buffer term in lite lexicon due to: " + << status.error_message(); + break; } } + if (!status.ok()) { + break; + } // Add all the seen terms to the index with their term frequency. - auto status = editor.IndexAllBufferedTerms(); - if (overall_status.ok() && !status.ok()) { - // If we've succeeded so far, set overall_status to - // represent this new failure. If we've already failed, no need to - // update the status - we're already going to return a resource - // exhausted error. - overall_status = status; + status = editor.IndexAllBufferedTerms(); + if (!status.ok()) { + ICING_LOG(WARNING) << "Failed to add hits in lite index due to: " + << status.error_message(); + break; } } @@ -123,9 +103,11 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( num_tokens); } - // Merge if necessary. - if (overall_status.ok() && index_->WantsMerge()) { - ICING_VLOG(1) << "Merging the index at docid " << document_id << "."; + // If we're either successful or we've hit resource exhausted, then attempt a + // merge. 
+ if ((status.ok() || absl_ports::IsResourceExhausted(status)) && + index_->WantsMerge()) { + ICING_LOG(ERROR) << "Merging the index at docid " << document_id << "."; std::unique_ptr<Timer> merge_timer = clock_.GetNewTimer(); libtextclassifier3::Status merge_status = index_->Merge(); @@ -150,7 +132,7 @@ libtextclassifier3::Status IndexProcessor::IndexDocument( } } - return overall_status; + return status; } } // namespace lib diff --git a/icing/index/index-processor.h b/icing/index/index-processor.h index 6b07c98..c4b77b5 100644 --- a/icing/index/index-processor.h +++ b/icing/index/index-processor.h @@ -32,23 +32,6 @@ namespace lib { class IndexProcessor { public: - struct Options { - int32_t max_tokens_per_document; - - // Indicates how a document exceeding max_tokens_per_document should be - // handled. - enum class TokenLimitBehavior { - // When set, the first max_tokens_per_document will be indexed. If the - // token count exceeds max_tokens_per_document, a ResourceExhausted error - // will be returned. - kReturnError, - // When set, the first max_tokens_per_document will be indexed. If the - // token count exceeds max_tokens_per_document, OK will be returned. - kSuppressError, - }; - TokenLimitBehavior token_limit_behavior; - }; - // Factory function to create an IndexProcessor which does not take ownership // of any input components, and all pointers must refer to valid objects that // outlive the created IndexProcessor instance. @@ -57,8 +40,7 @@ class IndexProcessor { // An IndexProcessor on success // FAILED_PRECONDITION if any of the pointers is null. static libtextclassifier3::StatusOr<std::unique_ptr<IndexProcessor>> Create( - const Normalizer* normalizer, Index* index, const Options& options, - const Clock* clock); + const Normalizer* normalizer, Index* index, const Clock* clock); // Add tokenized document to the index, associated with document_id. 
If the // number of tokens in the document exceeds max_tokens_per_document, then only @@ -84,18 +66,13 @@ class IndexProcessor { PutDocumentStatsProto* put_document_stats = nullptr); private: - IndexProcessor(const Normalizer* normalizer, Index* index, - const Options& options, const Clock* clock) - : normalizer_(*normalizer), - index_(index), - options_(options), - clock_(*clock) {} + IndexProcessor(const Normalizer* normalizer, Index* index, const Clock* clock) + : normalizer_(*normalizer), index_(index), clock_(*clock) {} std::string NormalizeToken(const Token& token); const Normalizer& normalizer_; Index* const index_; - const Options options_; const Clock& clock_; }; diff --git a/icing/index/index-processor_benchmark.cc b/icing/index/index-processor_benchmark.cc index afeac4d..6e072c7 100644 --- a/icing/index/index-processor_benchmark.cc +++ b/icing/index/index-processor_benchmark.cc @@ -168,17 +168,6 @@ void CleanUp(const Filesystem& filesystem, const std::string& index_dir) { filesystem.DeleteDirectoryRecursively(index_dir.c_str()); } -std::unique_ptr<IndexProcessor> CreateIndexProcessor( - const Normalizer* normalizer, Index* index, const Clock* clock) { - IndexProcessor::Options processor_options{}; - processor_options.max_tokens_per_document = 1024 * 1024 * 10; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - return IndexProcessor::Create(normalizer, index, processor_options, clock) - .ValueOrDie(); -} - void BM_IndexDocumentWithOneProperty(benchmark::State& state) { bool run_via_adb = absl::GetFlag(FLAGS_adb); if (!run_via_adb) { @@ -200,9 +189,9 @@ void BM_IndexDocumentWithOneProperty(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); - + 
ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithOneProperty(state.range(0)); TokenizedDocument tokenized_document(std::move( TokenizedDocument::Create(schema_store.get(), language_segmenter.get(), @@ -254,8 +243,9 @@ void BM_IndexDocumentWithTenProperties(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithTenProperties(state.range(0)); @@ -309,8 +299,9 @@ void BM_IndexDocumentWithDiacriticLetters(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto input_document = CreateDocumentWithDiacriticLetters(state.range(0)); @@ -364,8 +355,9 @@ void BM_IndexDocumentWithHiragana(benchmark::State& state) { std::unique_ptr<Normalizer> normalizer = CreateNormalizer(); Clock clock; std::unique_ptr<SchemaStore> schema_store = CreateSchemaStore(&clock); - std::unique_ptr<IndexProcessor> index_processor = - CreateIndexProcessor(normalizer.get(), index.get(), &clock); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<IndexProcessor> index_processor, + IndexProcessor::Create(normalizer.get(), index.get(), &clock)); DocumentProto 
input_document = CreateDocumentWithHiragana(state.range(0)); TokenizedDocument tokenized_document(std::move( diff --git a/icing/index/index-processor_test.cc b/icing/index/index-processor_test.cc index 8a6a9f5..449bc3e 100644 --- a/icing/index/index-processor_test.cc +++ b/icing/index/index-processor_test.cc @@ -27,6 +27,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "icing/absl_ports/str_cat.h" +#include "icing/absl_ports/str_join.h" #include "icing/document-builder.h" #include "icing/file/filesystem.h" #include "icing/helpers/icu/icu-data-file-helper.h" @@ -48,6 +49,7 @@ #include "icing/store/document-id.h" #include "icing/testing/common-matchers.h" #include "icing/testing/fake-clock.h" +#include "icing/testing/random-string.h" #include "icing/testing/test-data.h" #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" @@ -193,15 +195,9 @@ class IndexProcessorTest : public Test { .Build(); ICING_ASSERT_OK(schema_store_->SetSchema(schema)); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), - processor_options, &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); mock_icing_filesystem_ = std::make_unique<IcingMockFilesystem>(); } @@ -232,17 +228,12 @@ std::vector<DocHitInfo> GetHits(std::unique_ptr<DocHitInfoIterator> iterator) { } TEST_F(IndexProcessorTest, CreationWithNullPointerShouldFail) { - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - EXPECT_THAT(IndexProcessor::Create(/*normalizer=*/nullptr, index_.get(), - processor_options, &fake_clock_), + &fake_clock_), 
StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); EXPECT_THAT(IndexProcessor::Create(normalizer_.get(), /*index=*/nullptr, - processor_options, &fake_clock_), + &fake_clock_), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } @@ -434,103 +425,68 @@ TEST_F(IndexProcessorTest, DocWithRepeatedProperty) { kDocumentId0, std::vector<SectionId>{kRepeatedSectionId}))); } -TEST_F(IndexProcessorTest, TooManyTokensReturnError) { - // Only allow the first four tokens ("hello", "world", "good", "night") to be - // indexed. - IndexProcessor::Options options; - options.max_tokens_per_document = 4; - options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; +// TODO(b/196771754) This test is disabled on Android because it takes too long +// to generate all of the unique terms and the test times out. Try storing these +// unique terms in a file that the test can read from. +#ifndef __ANDROID__ - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer_.get(), index_.get(), - options, &fake_clock_)); +TEST_F(IndexProcessorTest, HitBufferExhaustedTest) { + // Testing has shown that adding ~600,000 hits will fill up the hit buffer. 
+ std::vector<std::string> unique_terms_ = GenerateUniqueTerms(200000); + std::string content = absl_ports::StrJoin(unique_terms_, " "); DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactProperty), "hello world") - .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") + .AddStringProperty(std::string(kExactProperty), content) + .AddStringProperty(std::string(kPrefixedProperty), content) + .AddStringProperty(std::string(kRepeatedProperty), content) .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), - StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED, + testing::HasSubstr("Hit buffer is full!"))); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - // "night" should have been indexed. - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("night", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); - - // "moon" should not have been. - ICING_ASSERT_OK_AND_ASSIGN(itr, - index_->GetIterator("moon", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); } -TEST_F(IndexProcessorTest, TooManyTokensSuppressError) { - // Only allow the first four tokens ("hello", "world", "good", "night") to be - // indexed. 
- IndexProcessor::Options options; - options.max_tokens_per_document = 4; - options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kSuppressError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer_.get(), index_.get(), - options, &fake_clock_)); +TEST_F(IndexProcessorTest, LexiconExhaustedTest) { + // Testing has shown that adding ~300,000 terms generated this way will + // fill up the lexicon. + std::vector<std::string> unique_terms_ = GenerateUniqueTerms(300000); + std::string content = absl_ports::StrJoin(unique_terms_, " "); DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) - .AddStringProperty(std::string(kExactProperty), "hello world") - .AddStringProperty(std::string(kPrefixedProperty), "good night moon!") + .AddStringProperty(std::string(kExactProperty), content) .Build(); ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), document)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), - IsOk()); + StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED, + testing::HasSubstr("Unable to add term"))); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - // "night" should have been indexed. - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("night", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); - - // "moon" should not have been. 
- ICING_ASSERT_OK_AND_ASSIGN(itr, - index_->GetIterator("moon", kSectionIdMaskAll, - TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), IsEmpty()); } +#endif // __ANDROID__ + TEST_F(IndexProcessorTest, TooLongTokens) { // Only allow the tokens of length four, truncating "hello", "world" and // "night". - IndexProcessor::Options options; - options.max_tokens_per_document = 1000; - ICING_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Normalizer> normalizer, normalizer_factory::Create( /*max_term_byte_size=*/4)); ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, IndexProcessor::Create(normalizer.get(), index_.get(), - options, &fake_clock_)); + index_processor_, + IndexProcessor::Create(normalizer.get(), index_.get(), &fake_clock_)); DocumentProto document = DocumentBuilder() @@ -692,16 +648,6 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { lang_segmenter_, language_segmenter_factory::Create(std::move(segmenter_options))); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); - DocumentProto document = DocumentBuilder() .SetKey("icing", "fake_type/1") @@ -727,23 +673,13 @@ TEST_F(IndexProcessorTest, NonAsciiIndexing) { TEST_F(IndexProcessorTest, LexiconFullIndexesSmallerTokensReturnsResourceExhausted) { - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - - ICING_ASSERT_OK_AND_ASSIGN( - index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); - // This is the maximum token length that an empty lexicon constructed for a // lite index with merge size of 1MiB can support. 
constexpr int kMaxTokenLength = 16777217; // Create a string "ppppppp..." with a length that is too large to fit into // the lexicon. std::string enormous_string(kMaxTokenLength + 1, 'p'); - DocumentProto document = + DocumentProto document_one = DocumentBuilder() .SetKey("icing", "fake_type/1") .SetSchema(std::string(kFakeType)) @@ -754,24 +690,10 @@ TEST_F(IndexProcessorTest, ICING_ASSERT_OK_AND_ASSIGN( TokenizedDocument tokenized_document, TokenizedDocument::Create(schema_store_.get(), lang_segmenter_.get(), - document)); + document_one)); EXPECT_THAT(index_processor_->IndexDocument(tokenized_document, kDocumentId0), StatusIs(libtextclassifier3::StatusCode::RESOURCE_EXHAUSTED)); EXPECT_THAT(index_->last_added_document_id(), Eq(kDocumentId0)); - - ICING_ASSERT_OK_AND_ASSIGN( - std::unique_ptr<DocHitInfoIterator> itr, - index_->GetIterator("foo", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kExactSectionId}))); - - ICING_ASSERT_OK_AND_ASSIGN( - itr, - index_->GetIterator("baz", kSectionIdMaskAll, TermMatchType::EXACT_ONLY)); - EXPECT_THAT(GetHits(std::move(itr)), - ElementsAre(EqualsDocHitInfo( - kDocumentId0, std::vector<SectionId>{kPrefixedSectionId}))); } TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { @@ -795,15 +717,9 @@ TEST_F(IndexProcessorTest, IndexingDocAutomaticMerge) { ICING_ASSERT_OK_AND_ASSIGN( index_, Index::Create(options, &filesystem_, &icing_filesystem_)); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); DocumentId doc_id = 0; // Have determined experimentally that 
indexing 3373 documents with this text // will cause the LiteIndex to fill up. Further indexing will fail unless the @@ -857,15 +773,9 @@ TEST_F(IndexProcessorTest, IndexingDocMergeFailureResets) { index_, Index::Create(options, &filesystem_, mock_icing_filesystem_.get())); - IndexProcessor::Options processor_options; - processor_options.max_tokens_per_document = 1000; - processor_options.token_limit_behavior = - IndexProcessor::Options::TokenLimitBehavior::kReturnError; - ICING_ASSERT_OK_AND_ASSIGN( index_processor_, - IndexProcessor::Create(normalizer_.get(), index_.get(), processor_options, - &fake_clock_)); + IndexProcessor::Create(normalizer_.get(), index_.get(), &fake_clock_)); // 3. Index one document. This should fit in the LiteIndex without requiring a // merge. diff --git a/icing/index/index.cc b/icing/index/index.cc index db59ad2..1bdab21 100644 --- a/icing/index/index.cc +++ b/icing/index/index.cc @@ -36,6 +36,7 @@ #include "icing/legacy/index/icing-filesystem.h" #include "icing/proto/term.pb.h" #include "icing/schema/section.h" +#include "icing/scoring/ranker.h" #include "icing/store/document-id.h" #include "icing/util/logging.h" #include "icing/util/status-macros.h" @@ -89,20 +90,24 @@ bool IsTermInNamespaces( } enum class MergeAction { kTakeLiteTerm, kTakeMainTerm, kMergeTerms }; -std::vector<TermMetadata> MergeTermMetadatas( + +// Merge the TermMetadata from lite index and main index. If the term exists in +// both index, sum up its hit count and push it to the term heap. +// The heap is a min-heap. So that we can avoid some push operation but the time +// complexity is O(NlgK) which N is total number of term and K is num_to_return. 
+std::vector<TermMetadata> MergeAndRankTermMetadatas( std::vector<TermMetadata> lite_term_metadata_list, std::vector<TermMetadata> main_term_metadata_list, int num_to_return) { - std::vector<TermMetadata> merged_term_metadata_list; - merged_term_metadata_list.reserve( + std::vector<TermMetadata> merged_term_metadata_heap; + merged_term_metadata_heap.reserve( std::min(lite_term_metadata_list.size() + main_term_metadata_list.size(), static_cast<size_t>(num_to_return))); auto lite_term_itr = lite_term_metadata_list.begin(); auto main_term_itr = main_term_metadata_list.begin(); MergeAction merge_action; - while (merged_term_metadata_list.size() < num_to_return && - (lite_term_itr != lite_term_metadata_list.end() || - main_term_itr != main_term_metadata_list.end())) { + while (lite_term_itr != lite_term_metadata_list.end() || + main_term_itr != main_term_metadata_list.end()) { // Get pointers to the next metadatas in each group, if available // Determine how to merge. if (main_term_itr == main_term_metadata_list.end()) { @@ -119,23 +124,32 @@ std::vector<TermMetadata> MergeTermMetadatas( } switch (merge_action) { case MergeAction::kTakeLiteTerm: - merged_term_metadata_list.push_back(std::move(*lite_term_itr)); + PushToTermHeap(std::move(*lite_term_itr), num_to_return, + merged_term_metadata_heap); ++lite_term_itr; break; case MergeAction::kTakeMainTerm: - merged_term_metadata_list.push_back(std::move(*main_term_itr)); + PushToTermHeap(std::move(*main_term_itr), num_to_return, + merged_term_metadata_heap); ++main_term_itr; break; case MergeAction::kMergeTerms: int total_est_hit_count = lite_term_itr->hit_count + main_term_itr->hit_count; - merged_term_metadata_list.emplace_back( - std::move(lite_term_itr->content), total_est_hit_count); + PushToTermHeap(TermMetadata(std::move(lite_term_itr->content), + total_est_hit_count), + num_to_return, merged_term_metadata_heap); ++lite_term_itr; ++main_term_itr; break; } } + // Reverse the list since we pop them from a min heap and 
we need to return in + // decreasing order. + std::vector<TermMetadata> merged_term_metadata_list = + PopAllTermsFromHeap(merged_term_metadata_heap); + std::reverse(merged_term_metadata_list.begin(), + merged_term_metadata_list.end()); return merged_term_metadata_list; } @@ -214,8 +228,7 @@ Index::GetIterator(const std::string& term, SectionIdMask section_id_mask, libtextclassifier3::StatusOr<std::vector<TermMetadata>> Index::FindLiteTermsByPrefix(const std::string& prefix, - const std::vector<NamespaceId>& namespace_ids, - int num_to_return) { + const std::vector<NamespaceId>& namespace_ids) { // Finds all the terms that start with the given prefix in the lexicon. IcingDynamicTrie::Iterator term_iterator(lite_index_->lexicon(), prefix.c_str()); @@ -224,7 +237,7 @@ Index::FindLiteTermsByPrefix(const std::string& prefix, IcingDynamicTrie::PropertyReadersAll property_reader(lite_index_->lexicon()); std::vector<TermMetadata> term_metadata_list; - while (term_iterator.IsValid() && term_metadata_list.size() < num_to_return) { + while (term_iterator.IsValid()) { uint32_t term_value_index = term_iterator.GetValueIndex(); // Skips the terms that don't exist in the given namespaces. We won't skip @@ -244,13 +257,6 @@ Index::FindLiteTermsByPrefix(const std::string& prefix, term_iterator.Advance(); } - if (term_iterator.IsValid()) { - // We exited the loop above because we hit the num_to_return limit. - ICING_LOG(WARNING) << "Ran into limit of " << num_to_return - << " retrieving suggestions for " << prefix - << ". Some suggestions may not be returned and others " - "may be misranked."; - } return term_metadata_list; } @@ -264,17 +270,15 @@ Index::FindTermsByPrefix(const std::string& prefix, } // Get results from the LiteIndex. 
- ICING_ASSIGN_OR_RETURN( - std::vector<TermMetadata> lite_term_metadata_list, - FindLiteTermsByPrefix(prefix, namespace_ids, num_to_return)); - + ICING_ASSIGN_OR_RETURN(std::vector<TermMetadata> lite_term_metadata_list, + FindLiteTermsByPrefix(prefix, namespace_ids)); // Append results from the MainIndex. - ICING_ASSIGN_OR_RETURN( - std::vector<TermMetadata> main_term_metadata_list, - main_index_->FindTermsByPrefix(prefix, namespace_ids, num_to_return)); + ICING_ASSIGN_OR_RETURN(std::vector<TermMetadata> main_term_metadata_list, + main_index_->FindTermsByPrefix(prefix, namespace_ids)); - return MergeTermMetadatas(std::move(lite_term_metadata_list), - std::move(main_term_metadata_list), num_to_return); + return MergeAndRankTermMetadatas(std::move(lite_term_metadata_list), + std::move(main_term_metadata_list), + num_to_return); } IndexStorageInfoProto Index::GetStorageInfo() const { diff --git a/icing/index/index.h b/icing/index/index.h index eab5be8..693cf04 100644 --- a/icing/index/index.h +++ b/icing/index/index.h @@ -267,8 +267,7 @@ class Index { filesystem_(filesystem) {} libtextclassifier3::StatusOr<std::vector<TermMetadata>> FindLiteTermsByPrefix( - const std::string& prefix, const std::vector<NamespaceId>& namespace_ids, - int num_to_return); + const std::string& prefix, const std::vector<NamespaceId>& namespace_ids); std::unique_ptr<LiteIndex> lite_index_; std::unique_ptr<MainIndex> main_index_; diff --git a/icing/index/index_test.cc b/icing/index/index_test.cc index 16593ef..00d5ad6 100644 --- a/icing/index/index_test.cc +++ b/icing/index/index_test.cc @@ -88,6 +88,11 @@ constexpr DocumentId kDocumentId4 = 4; constexpr DocumentId kDocumentId5 = 5; constexpr DocumentId kDocumentId6 = 6; constexpr DocumentId kDocumentId7 = 7; +constexpr DocumentId kDocumentId8 = 8; +constexpr DocumentId kDocumentId9 = 9; +constexpr DocumentId kDocumentId10 = 10; +constexpr DocumentId kDocumentId11 = 11; +constexpr DocumentId kDocumentId12 = 12; constexpr SectionId 
kSectionId2 = 2; constexpr SectionId kSectionId3 = 3; @@ -1105,11 +1110,10 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCorrectHitCount) { EXPECT_THAT(edit2.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit, 'fool' has 2 hits. - EXPECT_THAT( - index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre(EqualsTermMetadata("foo", 1), - EqualsTermMetadata("fool", 2)))); + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 2), + EqualsTermMetadata("foo", 1)))); ICING_ASSERT_OK(index_->Merge()); @@ -1122,6 +1126,155 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCorrectHitCount) { EqualsTermMetadata("fool", kMinSizePlApproxHits)))); } +TEST_F(IndexTest, FindTermByPrefixShouldReturnInOrder) { + // Push 6 term-six, 5 term-five, 4 term-four, 3 term-three, 2 term-two and one + // term-one into lite index. + Index::Editor edit1 = + index_->Edit(kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit1.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit1.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit1.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit2 = + index_->Edit(kDocumentId2, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit2.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit2.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit2.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit3 = + index_->Edit(kDocumentId3, kSectionId2, 
TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit3.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit3.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit3.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit4 = + index_->Edit(kDocumentId4, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit4.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit4.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit4.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit4.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit5 = + index_->Edit(kDocumentId5, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit5.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit5.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit5.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit6 = + index_->Edit(kDocumentId6, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit6.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit6.IndexAllBufferedTerms(), IsOk()); + + // verify the order in lite index is correct. + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("term-six", 6), + EqualsTermMetadata("term-five", 5), + EqualsTermMetadata("term-four", 4), + EqualsTermMetadata("term-three", 3), + EqualsTermMetadata("term-two", 2), + EqualsTermMetadata("term-one", 1)))); + + ICING_ASSERT_OK(index_->Merge()); + + // Since most of term has same approx hit count, we don't verify order in the + // main index. 
+ EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(UnorderedElementsAre( + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits), + EqualsTermMetadata("term-five", kSecondSmallestPlApproxHits), + EqualsTermMetadata("term-four", kMinSizePlApproxHits), + EqualsTermMetadata("term-three", kMinSizePlApproxHits), + EqualsTermMetadata("term-two", kMinSizePlApproxHits), + EqualsTermMetadata("term-one", kMinSizePlApproxHits)))); + + // keep push terms to the lite index. For term 1-4, since they has same hit + // count kMinSizePlApproxHits, we will push 4 term-one, 3 term-two, 2 + // term-three and one term-four to make them in reverse order. And for term + // 5 & 6, we will push 2 term-five and one term-six. + Index::Editor edit7 = + index_->Edit(kDocumentId7, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit7.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit7.BufferTerm("term-four"), IsOk()); + EXPECT_THAT(edit7.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit8 = + index_->Edit(kDocumentId8, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit8.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit8.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit8.BufferTerm("term-three"), IsOk()); + EXPECT_THAT(edit8.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit9 = + index_->Edit(kDocumentId9, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit9.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit9.BufferTerm("term-two"), IsOk()); + EXPECT_THAT(edit9.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit10 = + index_->Edit(kDocumentId10, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit10.BufferTerm("term-one"), IsOk()); + EXPECT_THAT(edit10.IndexAllBufferedTerms(), 
IsOk()); + + Index::Editor edit11 = + index_->Edit(kDocumentId11, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit11.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit11.BufferTerm("term-six"), IsOk()); + EXPECT_THAT(edit11.IndexAllBufferedTerms(), IsOk()); + + Index::Editor edit12 = + index_->Edit(kDocumentId12, kSectionId2, TermMatchType::EXACT_ONLY, + /*namespace_id=*/0); + EXPECT_THAT(edit12.BufferTerm("term-five"), IsOk()); + EXPECT_THAT(edit12.IndexAllBufferedTerms(), IsOk()); + + // verify the combination of lite index and main index is in correct order. + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("term-five", + kSecondSmallestPlApproxHits + 2), // 9 + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits + 1), // 8 + EqualsTermMetadata("term-one", kMinSizePlApproxHits + 4), // 7 + EqualsTermMetadata("term-two", kMinSizePlApproxHits + 3), // 6 + EqualsTermMetadata("term-three", kMinSizePlApproxHits + 2), // 5 + EqualsTermMetadata("term-four", kMinSizePlApproxHits + 1)))); // 4 + + // Get the first three terms. + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"t", /*namespace_ids=*/{0}, + /*num_to_return=*/3), + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("term-five", + kSecondSmallestPlApproxHits + 2), // 9 + EqualsTermMetadata("term-six", kSecondSmallestPlApproxHits + 1), // 8 + EqualsTermMetadata("term-one", kMinSizePlApproxHits + 4)))); // 7 +} + TEST_F(IndexTest, FindTermByPrefixShouldReturnApproximateHitCountForMain) { Index::Editor edit = index_->Edit(kDocumentId0, kSectionId2, TermMatchType::EXACT_ONLY, @@ -1160,11 +1313,10 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnApproximateHitCountForMain) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit, 'fool' has 8 hits. 
- EXPECT_THAT( - index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre(EqualsTermMetadata("foo", 1), - EqualsTermMetadata("fool", 8)))); + EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("fool", 8), + EqualsTermMetadata("foo", 1)))); ICING_ASSERT_OK(index_->Merge()); @@ -1195,9 +1347,9 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnCombinedHitCount) { // 1 hit in the lite index. EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre( - EqualsTermMetadata("foo", kMinSizePlApproxHits), - EqualsTermMetadata("fool", kMinSizePlApproxHits + 1)))); + IsOkAndHolds(ElementsAre( + EqualsTermMetadata("fool", kMinSizePlApproxHits + 1), + EqualsTermMetadata("foo", kMinSizePlApproxHits)))); } TEST_F(IndexTest, FindTermByPrefixShouldReturnTermsFromBothIndices) { @@ -1215,11 +1367,11 @@ TEST_F(IndexTest, FindTermByPrefixShouldReturnTermsFromBothIndices) { EXPECT_THAT(edit.IndexAllBufferedTerms(), IsOk()); // 'foo' has 1 hit in the main index, 'fool' has 1 hit in the lite index. 
- EXPECT_THAT(index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, - /*num_to_return=*/10), - IsOkAndHolds(UnorderedElementsAre( - EqualsTermMetadata("foo", kMinSizePlApproxHits), - EqualsTermMetadata("fool", 1)))); + EXPECT_THAT( + index_->FindTermsByPrefix(/*prefix=*/"f", /*namespace_ids=*/{0}, + /*num_to_return=*/10), + IsOkAndHolds(ElementsAre(EqualsTermMetadata("foo", kMinSizePlApproxHits), + EqualsTermMetadata("fool", 1)))); } TEST_F(IndexTest, GetElementsSize) { diff --git a/icing/index/iterator/doc-hit-info-iterator-and.cc b/icing/index/iterator/doc-hit-info-iterator-and.cc index 39aa969..543e9ef 100644 --- a/icing/index/iterator/doc-hit-info-iterator-and.cc +++ b/icing/index/iterator/doc-hit-info-iterator-and.cc @@ -14,8 +14,7 @@ #include "icing/index/iterator/doc-hit-info-iterator-and.h" -#include <stddef.h> - +#include <cstddef> #include <cstdint> #include <memory> #include <string> diff --git a/icing/index/lite/lite-index.cc b/icing/index/lite/lite-index.cc index fb23934..9e4ac28 100644 --- a/icing/index/lite/lite-index.cc +++ b/icing/index/lite/lite-index.cc @@ -14,12 +14,11 @@ #include "icing/index/lite/lite-index.h" -#include <inttypes.h> -#include <stddef.h> -#include <stdint.h> #include <sys/mman.h> #include <algorithm> +#include <cinttypes> +#include <cstddef> #include <cstdint> #include <memory> #include <string> diff --git a/icing/index/main/flash-index-storage.cc b/icing/index/main/flash-index-storage.cc index f125b6d..3c52375 100644 --- a/icing/index/main/flash-index-storage.cc +++ b/icing/index/main/flash-index-storage.cc @@ -14,11 +14,11 @@ #include "icing/index/main/flash-index-storage.h" -#include <errno.h> -#include <inttypes.h> #include <sys/types.h> #include <algorithm> +#include <cerrno> +#include <cinttypes> #include <cstdint> #include <memory> #include <unordered_set> diff --git a/icing/index/main/flash-index-storage_test.cc b/icing/index/main/flash-index-storage_test.cc index 7e15524..25fcaad 100644 --- 
a/icing/index/main/flash-index-storage_test.cc +++ b/icing/index/main/flash-index-storage_test.cc @@ -14,10 +14,10 @@ #include "icing/index/main/flash-index-storage.h" -#include <stdlib.h> #include <unistd.h> #include <algorithm> +#include <cstdlib> #include <limits> #include <utility> #include <vector> diff --git a/icing/index/main/index-block.cc b/icing/index/main/index-block.cc index 4590d06..c6ab345 100644 --- a/icing/index/main/index-block.cc +++ b/icing/index/main/index-block.cc @@ -14,9 +14,8 @@ #include "icing/index/main/index-block.h" -#include <inttypes.h> - #include <algorithm> +#include <cinttypes> #include <limits> #include "icing/text_classifier/lib3/utils/base/statusor.h" diff --git a/icing/index/main/index-block.h b/icing/index/main/index-block.h index edf9a79..5d75a2a 100644 --- a/icing/index/main/index-block.h +++ b/icing/index/main/index-block.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_INDEX_BLOCK_H_ #define ICING_INDEX_MAIN_INDEX_BLOCK_H_ -#include <string.h> #include <sys/mman.h> #include <algorithm> +#include <cstring> #include <limits> #include <memory> #include <string> diff --git a/icing/index/main/main-index.cc b/icing/index/main/main-index.cc index 8ae6b27..b185138 100644 --- a/icing/index/main/main-index.cc +++ b/icing/index/main/main-index.cc @@ -217,8 +217,7 @@ bool IsTermInNamespaces( libtextclassifier3::StatusOr<std::vector<TermMetadata>> MainIndex::FindTermsByPrefix(const std::string& prefix, - const std::vector<NamespaceId>& namespace_ids, - int num_to_return) { + const std::vector<NamespaceId>& namespace_ids) { // Finds all the terms that start with the given prefix in the lexicon. 
IcingDynamicTrie::Iterator term_iterator(*main_lexicon_, prefix.c_str()); @@ -226,7 +225,7 @@ MainIndex::FindTermsByPrefix(const std::string& prefix, IcingDynamicTrie::PropertyReadersAll property_reader(*main_lexicon_); std::vector<TermMetadata> term_metadata_list; - while (term_iterator.IsValid() && term_metadata_list.size() < num_to_return) { + while (term_iterator.IsValid()) { uint32_t term_value_index = term_iterator.GetValueIndex(); // Skips the terms that don't exist in the given namespaces. We won't skip @@ -250,13 +249,6 @@ MainIndex::FindTermsByPrefix(const std::string& prefix, term_iterator.Advance(); } - if (term_iterator.IsValid()) { - // We exited the loop above because we hit the num_to_return limit. - ICING_LOG(WARNING) << "Ran into limit of " << num_to_return - << " retrieving suggestions for " << prefix - << ". Some suggestions may not be returned and others " - "may be misranked."; - } return term_metadata_list; } diff --git a/icing/index/main/main-index.h b/icing/index/main/main-index.h index 43635ca..919a5c5 100644 --- a/icing/index/main/main-index.h +++ b/icing/index/main/main-index.h @@ -81,8 +81,7 @@ class MainIndex { // A list of TermMetadata on success // INTERNAL_ERROR if failed to access term data. 
libtextclassifier3::StatusOr<std::vector<TermMetadata>> FindTermsByPrefix( - const std::string& prefix, const std::vector<NamespaceId>& namespace_ids, - int num_to_return); + const std::string& prefix, const std::vector<NamespaceId>& namespace_ids); struct LexiconMergeOutputs { // Maps from main_lexicon tvi for new branching point to the main_lexicon diff --git a/icing/index/main/posting-list-free.h b/icing/index/main/posting-list-free.h index 4f06057..75b99d7 100644 --- a/icing/index/main/posting-list-free.h +++ b/icing/index/main/posting-list-free.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_POSTING_LIST_FREE_H_ #define ICING_INDEX_MAIN_POSTING_LIST_FREE_H_ -#include <string.h> #include <sys/mman.h> #include <cstdint> +#include <cstring> #include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/absl_ports/canonical_errors.h" diff --git a/icing/index/main/posting-list-used.h b/icing/index/main/posting-list-used.h index 1b2e24e..8944034 100644 --- a/icing/index/main/posting-list-used.h +++ b/icing/index/main/posting-list-used.h @@ -15,10 +15,10 @@ #ifndef ICING_INDEX_MAIN_POSTING_LIST_USED_H_ #define ICING_INDEX_MAIN_POSTING_LIST_USED_H_ -#include <string.h> #include <sys/mman.h> #include <algorithm> +#include <cstring> #include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" diff --git a/icing/jni/icing-search-engine-jni.cc b/icing/jni/icing-search-engine-jni.cc index ea2bcf7..51d3423 100644 --- a/icing/jni/icing-search-engine-jni.cc +++ b/icing/jni/icing-search-engine-jni.cc @@ -420,4 +420,23 @@ Java_com_google_android_icing_IcingSearchEngine_nativeReset( return SerializeProtoToJniByteArray(env, reset_result_proto); } +JNIEXPORT jbyteArray JNICALL +Java_com_google_android_icing_IcingSearchEngine_nativeSearchSuggestions( + JNIEnv* env, jclass clazz, jobject object, + jbyteArray suggestion_spec_bytes) { + icing::lib::IcingSearchEngine* icing = + GetIcingSearchEnginePointer(env, object); + + icing::lib::SuggestionSpecProto 
suggestion_spec_proto; + if (!ParseProtoFromJniByteArray(env, suggestion_spec_bytes, + &suggestion_spec_proto)) { + ICING_LOG(ERROR) << "Failed to parse SuggestionSpecProto in nativeSearch"; + return nullptr; + } + icing::lib::SuggestionResponse suggestionResponse = + icing->SearchSuggestions(suggestion_spec_proto); + + return SerializeProtoToJniByteArray(env, suggestionResponse); +} + } // extern "C" diff --git a/icing/legacy/core/icing-core-types.h b/icing/legacy/core/icing-core-types.h index cc12663..7db8408 100644 --- a/icing/legacy/core/icing-core-types.h +++ b/icing/legacy/core/icing-core-types.h @@ -21,9 +21,8 @@ #ifndef ICING_LEGACY_CORE_ICING_CORE_TYPES_H_ #define ICING_LEGACY_CORE_ICING_CORE_TYPES_H_ -#include <stdint.h> - #include <cstddef> // size_t not defined implicitly for all platforms. +#include <cstdint> #include <vector> #include "icing/legacy/core/icing-compat.h" diff --git a/icing/legacy/core/icing-string-util.cc b/icing/legacy/core/icing-string-util.cc index 2eb64ac..ed06e03 100644 --- a/icing/legacy/core/icing-string-util.cc +++ b/icing/legacy/core/icing-string-util.cc @@ -13,12 +13,11 @@ // limitations under the License. 
#include "icing/legacy/core/icing-string-util.h" -#include <stdarg.h> -#include <stddef.h> -#include <stdint.h> -#include <stdio.h> - #include <algorithm> +#include <cstdarg> +#include <cstddef> +#include <cstdint> +#include <cstdio> #include <string> #include "icing/legacy/portable/icing-zlib.h" diff --git a/icing/legacy/core/icing-string-util.h b/icing/legacy/core/icing-string-util.h index 767e581..e5e4941 100644 --- a/icing/legacy/core/icing-string-util.h +++ b/icing/legacy/core/icing-string-util.h @@ -15,9 +15,8 @@ #ifndef ICING_LEGACY_CORE_ICING_STRING_UTIL_H_ #define ICING_LEGACY_CORE_ICING_STRING_UTIL_H_ -#include <stdarg.h> -#include <stdint.h> - +#include <cstdarg> +#include <cstdint> #include <string> #include "icing/legacy/core/icing-compat.h" diff --git a/icing/legacy/core/icing-timer.h b/icing/legacy/core/icing-timer.h index 49ba9ad..af38912 100644 --- a/icing/legacy/core/icing-timer.h +++ b/icing/legacy/core/icing-timer.h @@ -16,7 +16,8 @@ #define ICING_LEGACY_CORE_ICING_TIMER_H_ #include <sys/time.h> -#include <time.h> + +#include <ctime> namespace icing { namespace lib { diff --git a/icing/legacy/index/icing-array-storage.cc b/icing/legacy/index/icing-array-storage.cc index b462135..4d2ef67 100644 --- a/icing/legacy/index/icing-array-storage.cc +++ b/icing/legacy/index/icing-array-storage.cc @@ -14,10 +14,10 @@ #include "icing/legacy/index/icing-array-storage.h" -#include <inttypes.h> #include <sys/mman.h> #include <algorithm> +#include <cinttypes> #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/core/icing-timer.h" diff --git a/icing/legacy/index/icing-array-storage.h b/icing/legacy/index/icing-array-storage.h index fad0565..0d93172 100644 --- a/icing/legacy/index/icing-array-storage.h +++ b/icing/legacy/index/icing-array-storage.h @@ -20,8 +20,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_ #define ICING_LEGACY_INDEX_ICING_ARRAY_STORAGE_H_ -#include <stdint.h> - +#include <cstdint> #include <string> #include <vector> 
diff --git a/icing/legacy/index/icing-bit-util.h b/icing/legacy/index/icing-bit-util.h index 3273a68..d0c3f50 100644 --- a/icing/legacy/index/icing-bit-util.h +++ b/icing/legacy/index/icing-bit-util.h @@ -20,9 +20,8 @@ #ifndef ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_ #define ICING_LEGACY_INDEX_ICING_BIT_UTIL_H_ -#include <stdint.h> -#include <stdio.h> - +#include <cstdint> +#include <cstdio> #include <limits> #include <vector> diff --git a/icing/legacy/index/icing-dynamic-trie.cc b/icing/legacy/index/icing-dynamic-trie.cc index 29843ba..baa043a 100644 --- a/icing/legacy/index/icing-dynamic-trie.cc +++ b/icing/legacy/index/icing-dynamic-trie.cc @@ -62,15 +62,15 @@ #include "icing/legacy/index/icing-dynamic-trie.h" -#include <errno.h> #include <fcntl.h> -#include <inttypes.h> -#include <string.h> #include <sys/mman.h> #include <sys/stat.h> #include <unistd.h> #include <algorithm> +#include <cerrno> +#include <cinttypes> +#include <cstring> #include <memory> #include <utility> diff --git a/icing/legacy/index/icing-dynamic-trie.h b/icing/legacy/index/icing-dynamic-trie.h index 7fe290b..8821799 100644 --- a/icing/legacy/index/icing-dynamic-trie.h +++ b/icing/legacy/index/icing-dynamic-trie.h @@ -35,8 +35,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_ #define ICING_LEGACY_INDEX_ICING_DYNAMIC_TRIE_H_ -#include <stdint.h> - +#include <cstdint> #include <memory> #include <string> #include <unordered_map> diff --git a/icing/legacy/index/icing-filesystem.cc b/icing/legacy/index/icing-filesystem.cc index 90e9146..4f5e571 100644 --- a/icing/legacy/index/icing-filesystem.cc +++ b/icing/legacy/index/icing-filesystem.cc @@ -16,7 +16,6 @@ #include <dirent.h> #include <dlfcn.h> -#include <errno.h> #include <fcntl.h> #include <fnmatch.h> #include <pthread.h> @@ -27,6 +26,7 @@ #include <unistd.h> #include <algorithm> +#include <cerrno> #include <unordered_set> #include "icing/absl_ports/str_cat.h" diff --git a/icing/legacy/index/icing-flash-bitmap.h 
b/icing/legacy/index/icing-flash-bitmap.h index 3b3521a..e3ba0e2 100644 --- a/icing/legacy/index/icing-flash-bitmap.h +++ b/icing/legacy/index/icing-flash-bitmap.h @@ -37,8 +37,7 @@ #ifndef ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_ #define ICING_LEGACY_INDEX_ICING_FLASH_BITMAP_H_ -#include <stdint.h> - +#include <cstdint> #include <memory> #include <string> diff --git a/icing/legacy/index/icing-mmapper.cc b/icing/legacy/index/icing-mmapper.cc index 737335c..7946c82 100644 --- a/icing/legacy/index/icing-mmapper.cc +++ b/icing/legacy/index/icing-mmapper.cc @@ -17,10 +17,11 @@ // #include "icing/legacy/index/icing-mmapper.h" -#include <errno.h> -#include <string.h> #include <sys/mman.h> +#include <cerrno> +#include <cstring> + #include "icing/legacy/core/icing-string-util.h" #include "icing/legacy/index/icing-filesystem.h" #include "icing/util/logging.h" diff --git a/icing/legacy/index/icing-mock-filesystem.h b/icing/legacy/index/icing-mock-filesystem.h index 75ac62f..122ee7b 100644 --- a/icing/legacy/index/icing-mock-filesystem.h +++ b/icing/legacy/index/icing-mock-filesystem.h @@ -15,16 +15,15 @@ #ifndef ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_ #define ICING_LEGACY_INDEX_ICING_MOCK_FILESYSTEM_H_ -#include <stdint.h> -#include <stdio.h> -#include <string.h> - +#include <cstdint> +#include <cstdio> +#include <cstring> #include <memory> #include <string> #include <vector> -#include "icing/legacy/index/icing-filesystem.h" #include "gmock/gmock.h" +#include "icing/legacy/index/icing-filesystem.h" namespace icing { namespace lib { diff --git a/icing/legacy/index/icing-storage-file.cc b/icing/legacy/index/icing-storage-file.cc index b27ec67..35a4418 100644 --- a/icing/legacy/index/icing-storage-file.cc +++ b/icing/legacy/index/icing-storage-file.cc @@ -14,9 +14,9 @@ #include "icing/legacy/index/icing-storage-file.h" -#include <inttypes.h> #include <unistd.h> +#include <cinttypes> #include <string> #include "icing/legacy/core/icing-compat.h" diff --git 
a/icing/portable/endian.h b/icing/portable/endian.h index 42f6c02..ecebb15 100644 --- a/icing/portable/endian.h +++ b/icing/portable/endian.h @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// Utility functions that depend on bytesex. We define htonll and ntohll, -// as well as "Google" versions of all the standards: ghtonl, ghtons, and -// so on. These functions do exactly the same as their standard variants, -// but don't require including the dangerous netinet/in.h. +// Utility functions that depend on bytesex. We define versions of htonll and +// ntohll (HostToNetworkLL and NetworkToHostLL in our naming), as well as +// "Google" versions of all the standards: ghtonl, ghtons, and so on +// (GHostToNetworkL, GHostToNetworkS, etc in our naming). These functions do +// exactly the same as their standard variants, but don't require including the +// dangerous netinet/in.h. #ifndef ICING_PORTABLE_ENDIAN_H_ #define ICING_PORTABLE_ENDIAN_H_ @@ -75,7 +77,7 @@ // The following guarantees declaration of the byte swap functions #ifdef COMPILER_MSVC -#include <stdlib.h> // NOLINT(build/include) +#include <cstdlib> // NOLINT(build/include) #define bswap_16(x) _byteswap_ushort(x) #define bswap_32(x) _byteswap_ulong(x) @@ -170,37 +172,37 @@ inline uint16 gbswap_16(uint16 host_int) { return bswap_16(host_int); } // correctly handle the (rather involved) definitions of bswap_32. // gcc guarantees that inline functions are as fast as macros, so // this isn't a performance hit. 
-inline uint16_t ghtons(uint16_t x) { return gbswap_16(x); } -inline uint32_t ghtonl(uint32_t x) { return gbswap_32(x); } -inline uint64_t ghtonll(uint64_t x) { return gbswap_64(x); } +inline uint16_t GHostToNetworkS(uint16_t x) { return gbswap_16(x); } +inline uint32_t GHostToNetworkL(uint32_t x) { return gbswap_32(x); } +inline uint64_t GHostToNetworkLL(uint64_t x) { return gbswap_64(x); } #elif defined IS_BIG_ENDIAN // These definitions are simpler on big-endian machines // These are functions instead of macros to avoid self-assignment warnings // on calls such as "i = ghtnol(i);". This also provides type checking. -inline uint16 ghtons(uint16 x) { return x; } -inline uint32 ghtonl(uint32 x) { return x; } -inline uint64 ghtonll(uint64 x) { return x; } +inline uint16 GHostToNetworkS(uint16 x) { return x; } +inline uint32 GHostToNetworkL(uint32 x) { return x; } +inline uint64 GHostToNetworkLL(uint64 x) { return x; } #else // bytesex #error \ "Unsupported bytesex: Either IS_BIG_ENDIAN or IS_LITTLE_ENDIAN must be defined" // NOLINT #endif // bytesex -#ifndef htonll +#ifndef HostToNetworkLL // With the rise of 64-bit, some systems are beginning to define this. -#define htonll(x) ghtonll(x) -#endif // htonll +#define HostToNetworkLL(x) GHostToNetworkLL(x) +#endif // HostToNetworkLL // ntoh* and hton* are the same thing for any size and bytesex, // since the function is an involution, i.e., its own inverse. 
-inline uint16_t gntohs(uint16_t x) { return ghtons(x); } -inline uint32_t gntohl(uint32_t x) { return ghtonl(x); } -inline uint64_t gntohll(uint64_t x) { return ghtonll(x); } +inline uint16_t GNetworkToHostS(uint16_t x) { return GHostToNetworkS(x); } +inline uint32_t GNetworkToHostL(uint32_t x) { return GHostToNetworkL(x); } +inline uint64_t GNetworkToHostLL(uint64_t x) { return GHostToNetworkLL(x); } -#ifndef ntohll -#define ntohll(x) htonll(x) -#endif // ntohll +#ifndef NetworkToHostLL +#define NetworkToHostLL(x) GHostToNetworkLL(x) +#endif // NetworkToHostLL #endif // ICING_PORTABLE_ENDIAN_H_ diff --git a/icing/portable/gzip_stream.cc b/icing/portable/gzip_stream.cc new file mode 100644 index 0000000..f00a993 --- /dev/null +++ b/icing/portable/gzip_stream.cc @@ -0,0 +1,313 @@ +// Copyright (C) 2009 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the implementation of classes GzipInputStream and +// GzipOutputStream. It is forked from protobuf because these classes are only +// provided in libprotobuf-full but we would like to link libicing against the +// smaller libprotobuf-lite instead. 
+ +#include "icing/portable/gzip_stream.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { +namespace protobuf_ports { + +static const int kDefaultBufferSize = 65536; + +GzipInputStream::GzipInputStream(ZeroCopyInputStream* sub_stream, Format format, + int buffer_size) + : format_(format), sub_stream_(sub_stream), zerror_(Z_OK), byte_count_(0) { + zcontext_.state = Z_NULL; + zcontext_.zalloc = Z_NULL; + zcontext_.zfree = Z_NULL; + zcontext_.opaque = Z_NULL; + zcontext_.total_out = 0; + zcontext_.next_in = NULL; + zcontext_.avail_in = 0; + zcontext_.total_in = 0; + zcontext_.msg = NULL; + if (buffer_size == -1) { + output_buffer_length_ = kDefaultBufferSize; + } else { + output_buffer_length_ = buffer_size; + } + output_buffer_ = operator new(output_buffer_length_); + zcontext_.next_out = static_cast<Bytef*>(output_buffer_); + zcontext_.avail_out = output_buffer_length_; + output_position_ = output_buffer_; +} +GzipInputStream::~GzipInputStream() { + operator delete(output_buffer_); + zerror_ = inflateEnd(&zcontext_); +} + +static inline int internalInflateInit2(z_stream* zcontext, + GzipInputStream::Format format) { + int windowBitsFormat = 0; + switch (format) { + case GzipInputStream::GZIP: + windowBitsFormat = 16; + break; + case GzipInputStream::AUTO: + windowBitsFormat = 32; + break; + case GzipInputStream::ZLIB: + windowBitsFormat = 0; + break; + } + return inflateInit2(zcontext, /* windowBits */ 15 | windowBitsFormat); +} + +int GzipInputStream::Inflate(int flush) { + if ((zerror_ == Z_OK) && (zcontext_.avail_out == 0)) { + // previous inflate filled output buffer. don't change input params yet. 
+ } else if (zcontext_.avail_in == 0) { + const void* in; + int in_size; + bool first = zcontext_.next_in == NULL; + bool ok = sub_stream_->Next(&in, &in_size); + if (!ok) { + zcontext_.next_out = NULL; + zcontext_.avail_out = 0; + return Z_STREAM_END; + } + zcontext_.next_in = static_cast<Bytef*>(const_cast<void*>(in)); + zcontext_.avail_in = in_size; + if (first) { + int error = internalInflateInit2(&zcontext_, format_); + if (error != Z_OK) { + return error; + } + } + } + zcontext_.next_out = static_cast<Bytef*>(output_buffer_); + zcontext_.avail_out = output_buffer_length_; + output_position_ = output_buffer_; + int error = inflate(&zcontext_, flush); + return error; +} + +void GzipInputStream::DoNextOutput(const void** data, int* size) { + *data = output_position_; + *size = ((uintptr_t)zcontext_.next_out) - ((uintptr_t)output_position_); + output_position_ = zcontext_.next_out; +} + +// implements ZeroCopyInputStream ---------------------------------- +bool GzipInputStream::Next(const void** data, int* size) { + bool ok = (zerror_ == Z_OK) || (zerror_ == Z_STREAM_END) || + (zerror_ == Z_BUF_ERROR); + if ((!ok) || (zcontext_.next_out == NULL)) { + return false; + } + if (zcontext_.next_out != output_position_) { + DoNextOutput(data, size); + return true; + } + if (zerror_ == Z_STREAM_END) { + if (zcontext_.next_out != NULL) { + // sub_stream_ may have concatenated streams to follow + zerror_ = inflateEnd(&zcontext_); + byte_count_ += zcontext_.total_out; + if (zerror_ != Z_OK) { + return false; + } + zerror_ = internalInflateInit2(&zcontext_, format_); + if (zerror_ != Z_OK) { + return false; + } + } else { + *data = NULL; + *size = 0; + return false; + } + } + zerror_ = Inflate(Z_NO_FLUSH); + if ((zerror_ == Z_STREAM_END) && (zcontext_.next_out == NULL)) { + // The underlying stream's Next returned false inside Inflate. 
+ return false; + } + ok = (zerror_ == Z_OK) || (zerror_ == Z_STREAM_END) || + (zerror_ == Z_BUF_ERROR); + if (!ok) { + return false; + } + DoNextOutput(data, size); + return true; +} +void GzipInputStream::BackUp(int count) { + output_position_ = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(output_position_) - count); +} +bool GzipInputStream::Skip(int count) { + const void* data; + int size = 0; + bool ok = Next(&data, &size); + while (ok && (size < count)) { + count -= size; + ok = Next(&data, &size); + } + if (size > count) { + BackUp(size - count); + } + return ok; +} +int64_t GzipInputStream::ByteCount() const { + int64_t ret = byte_count_ + zcontext_.total_out; + if (zcontext_.next_out != NULL && output_position_ != NULL) { + ret += reinterpret_cast<uintptr_t>(zcontext_.next_out) - + reinterpret_cast<uintptr_t>(output_position_); + } + return ret; +} + +// ========================================================================= + +GzipOutputStream::Options::Options() + : format(GZIP), + buffer_size(kDefaultBufferSize), + compression_level(Z_DEFAULT_COMPRESSION), + compression_strategy(Z_DEFAULT_STRATEGY) {} + +GzipOutputStream::GzipOutputStream(ZeroCopyOutputStream* sub_stream) { + Init(sub_stream, Options()); +} + +GzipOutputStream::GzipOutputStream(ZeroCopyOutputStream* sub_stream, + const Options& options) { + Init(sub_stream, options); +} + +void GzipOutputStream::Init(ZeroCopyOutputStream* sub_stream, + const Options& options) { + sub_stream_ = sub_stream; + sub_data_ = NULL; + sub_data_size_ = 0; + + input_buffer_length_ = options.buffer_size; + input_buffer_ = operator new(input_buffer_length_); + + zcontext_.zalloc = Z_NULL; + zcontext_.zfree = Z_NULL; + zcontext_.opaque = Z_NULL; + zcontext_.next_out = NULL; + zcontext_.avail_out = 0; + zcontext_.total_out = 0; + zcontext_.next_in = NULL; + zcontext_.avail_in = 0; + zcontext_.total_in = 0; + zcontext_.msg = NULL; + // default to GZIP format + int windowBitsFormat = 16; + if 
(options.format == ZLIB) { + windowBitsFormat = 0; + } + zerror_ = + deflateInit2(&zcontext_, options.compression_level, Z_DEFLATED, + /* windowBits */ 15 | windowBitsFormat, + /* memLevel (default) */ 8, options.compression_strategy); +} + +GzipOutputStream::~GzipOutputStream() { + Close(); + operator delete(input_buffer_); +} + +// private +int GzipOutputStream::Deflate(int flush) { + int error = Z_OK; + do { + if ((sub_data_ == NULL) || (zcontext_.avail_out == 0)) { + bool ok = sub_stream_->Next(&sub_data_, &sub_data_size_); + if (!ok) { + sub_data_ = NULL; + sub_data_size_ = 0; + return Z_BUF_ERROR; + } + if (sub_data_size_ <= 0) { + ICING_LOG(FATAL) << "Failed to advance underlying stream"; + } + zcontext_.next_out = static_cast<Bytef*>(sub_data_); + zcontext_.avail_out = sub_data_size_; + } + error = deflate(&zcontext_, flush); + } while (error == Z_OK && zcontext_.avail_out == 0); + if ((flush == Z_FULL_FLUSH) || (flush == Z_FINISH)) { + // Notify lower layer of data. + sub_stream_->BackUp(zcontext_.avail_out); + // We don't own the buffer anymore. + sub_data_ = NULL; + sub_data_size_ = 0; + } + return error; +} + +// implements ZeroCopyOutputStream --------------------------------- +bool GzipOutputStream::Next(void** data, int* size) { + if ((zerror_ != Z_OK) && (zerror_ != Z_BUF_ERROR)) { + return false; + } + if (zcontext_.avail_in != 0) { + zerror_ = Deflate(Z_NO_FLUSH); + if (zerror_ != Z_OK) { + return false; + } + } + if (zcontext_.avail_in == 0) { + // all input was consumed. reset the buffer. 
+ zcontext_.next_in = static_cast<Bytef*>(input_buffer_); + zcontext_.avail_in = input_buffer_length_; + *data = input_buffer_; + *size = input_buffer_length_; + } else { + // The loop in Deflate should consume all avail_in + ICING_LOG(ERROR) << "Deflate left bytes unconsumed"; + } + return true; +} +void GzipOutputStream::BackUp(int count) { + if (zcontext_.avail_in < static_cast<uInt>(count)) { + ICING_LOG(FATAL) << "Not enough data to back up " << count << " bytes"; + } + zcontext_.avail_in -= count; +} +int64_t GzipOutputStream::ByteCount() const { + return zcontext_.total_in + zcontext_.avail_in; +} + +bool GzipOutputStream::Flush() { + zerror_ = Deflate(Z_FULL_FLUSH); + // Return true if the flush succeeded or if it was a no-op. + return (zerror_ == Z_OK) || + (zerror_ == Z_BUF_ERROR && zcontext_.avail_in == 0 && + zcontext_.avail_out != 0); +} + +bool GzipOutputStream::Close() { + if ((zerror_ != Z_OK) && (zerror_ != Z_BUF_ERROR)) { + return false; + } + do { + zerror_ = Deflate(Z_FINISH); + } while (zerror_ == Z_OK); + zerror_ = deflateEnd(&zcontext_); + bool ok = zerror_ == Z_OK; + zerror_ = Z_STREAM_END; + return ok; +} + +} // namespace protobuf_ports +} // namespace lib +} // namespace icing diff --git a/icing/portable/gzip_stream.h b/icing/portable/gzip_stream.h new file mode 100644 index 0000000..602093f --- /dev/null +++ b/icing/portable/gzip_stream.h @@ -0,0 +1,181 @@ +// Copyright (C) 2009 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the definition for classes GzipInputStream and +// GzipOutputStream. It is forked from protobuf because these classes are only +// provided in libprotobuf-full but we would like to link libicing against the +// smaller libprotobuf-lite instead. +// +// GzipInputStream decompresses data from an underlying +// ZeroCopyInputStream and provides the decompressed data as a +// ZeroCopyInputStream. +// +// GzipOutputStream is an ZeroCopyOutputStream that compresses data to +// an underlying ZeroCopyOutputStream. + +#ifndef GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ +#define GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ + +#include <google/protobuf/io/zero_copy_stream_impl_lite.h> +#include "icing/portable/zlib.h" + +namespace icing { +namespace lib { +namespace protobuf_ports { + +// A ZeroCopyInputStream that reads compressed data through zlib +class GzipInputStream : public google::protobuf::io::ZeroCopyInputStream { + public: + // Format key for constructor + enum Format { + // zlib will autodetect gzip header or deflate stream + AUTO = 0, + + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + // buffer_size and format may be -1 for default of 64kB and GZIP format + explicit GzipInputStream( + google::protobuf::io::ZeroCopyInputStream* sub_stream, + Format format = AUTO, int buffer_size = -1); + virtual ~GzipInputStream(); + + // Return last error message or NULL if no error. 
+ inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // implements ZeroCopyInputStream ---------------------------------- + bool Next(const void** data, int* size) override; + void BackUp(int count) override; + bool Skip(int count) override; + int64_t ByteCount() const override; + + private: + Format format_; + + google::protobuf::io::ZeroCopyInputStream* sub_stream_; + + z_stream zcontext_; + int zerror_; + + void* output_buffer_; + void* output_position_; + size_t output_buffer_length_; + int64_t byte_count_; + + int Inflate(int flush); + void DoNextOutput(const void** data, int* size); +}; + +class GzipOutputStream : public google::protobuf::io::ZeroCopyOutputStream { + public: + // Format key for constructor + enum Format { + // GZIP streams have some extra header data for file attributes. + GZIP = 1, + + // Simpler zlib stream format. + ZLIB = 2, + }; + + struct Options { + // Defaults to GZIP. + Format format; + + // What size buffer to use internally. Defaults to 64kB. + int buffer_size; + + // A number between 0 and 9, where 0 is no compression and 9 is best + // compression. Defaults to Z_DEFAULT_COMPRESSION (see zlib.h). + int compression_level; + + // Defaults to Z_DEFAULT_STRATEGY. Can also be set to Z_FILTERED, + // Z_HUFFMAN_ONLY, or Z_RLE. See the documentation for deflateInit2 in + // zlib.h for definitions of these constants. + int compression_strategy; + + Options(); // Initializes with default values. + }; + + // Create a GzipOutputStream with default options. + explicit GzipOutputStream( + google::protobuf::io::ZeroCopyOutputStream* sub_stream); + + // Create a GzipOutputStream with the given options. + GzipOutputStream( + google::protobuf::io::ZeroCopyOutputStream* sub_stream, + const Options& options); + + virtual ~GzipOutputStream(); + + // Return last error message or NULL if no error. 
+ inline const char* ZlibErrorMessage() const { return zcontext_.msg; } + inline int ZlibErrorCode() const { return zerror_; } + + // Flushes data written so far to zipped data in the underlying stream. + // It is the caller's responsibility to flush the underlying stream if + // necessary. + // Compression may be less efficient stopping and starting around flushes. + // Returns true if no error. + // + // Please ensure that block size is > 6. Here is an excerpt from the zlib + // doc that explains why: + // + // In the case of a Z_FULL_FLUSH or Z_SYNC_FLUSH, make sure that avail_out + // is greater than six to avoid repeated flush markers due to + // avail_out == 0 on return. + bool Flush(); + + // Writes out all data and closes the gzip stream. + // It is the caller's responsibility to close the underlying stream if + // necessary. + // Returns true if no error. + bool Close(); + + // implements ZeroCopyOutputStream --------------------------------- + bool Next(void** data, int* size) override; + void BackUp(int count) override; + int64_t ByteCount() const override; + + private: + google::protobuf::io::ZeroCopyOutputStream* sub_stream_; + // Result from calling Next() on sub_stream_ + void* sub_data_; + int sub_data_size_; + + z_stream zcontext_; + int zerror_; + void* input_buffer_; + size_t input_buffer_length_; + + // Shared constructor code. + void Init( + google::protobuf::io::ZeroCopyOutputStream* sub_stream, + const Options& options); + + // Do some compression. + // Takes zlib flush mode. + // Returns zlib error code. 
+ int Deflate(int flush); +}; + +} // namespace protobuf_ports +} // namespace lib +} // namespace icing + +#endif // GOOGLE3_ICING_PORTABLE_GZIP_STREAM_H_ diff --git a/icing/query/query-processor.cc b/icing/query/query-processor.cc index 1f937fd..36c76db 100644 --- a/icing/query/query-processor.cc +++ b/icing/query/query-processor.cc @@ -182,7 +182,7 @@ QueryProcessor::ParseRawQuery(const SearchSpecProto& search_spec) { const Token& token = tokens.at(i); std::unique_ptr<DocHitInfoIterator> result_iterator; - // TODO(cassiewang): Handle negation tokens + // TODO(b/202076890): Handle negation tokens switch (token.type) { case Token::Type::QUERY_LEFT_PARENTHESES: { frames.emplace(ParserStateFrame()); diff --git a/icing/query/suggestion-processor.cc b/icing/query/suggestion-processor.cc new file mode 100644 index 0000000..9c60810 --- /dev/null +++ b/icing/query/suggestion-processor.cc @@ -0,0 +1,93 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/query/suggestion-processor.h" + +#include "icing/tokenization/tokenizer-factory.h" +#include "icing/tokenization/tokenizer.h" +#include "icing/transform/normalizer.h" + +namespace icing { +namespace lib { + +libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> +SuggestionProcessor::Create(Index* index, + const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer) { + ICING_RETURN_ERROR_IF_NULL(index); + ICING_RETURN_ERROR_IF_NULL(language_segmenter); + + return std::unique_ptr<SuggestionProcessor>( + new SuggestionProcessor(index, language_segmenter, normalizer)); +} + +libtextclassifier3::StatusOr<std::vector<TermMetadata>> +SuggestionProcessor::QuerySuggestions( + const icing::lib::SuggestionSpecProto& suggestion_spec, + const std::vector<NamespaceId>& namespace_ids) { + // We use query tokenizer to tokenize the given prefix, and we only use the + // last token to be the suggestion prefix. + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<Tokenizer> tokenizer, + tokenizer_factory::CreateIndexingTokenizer( + StringIndexingConfig::TokenizerType::PLAIN, &language_segmenter_)); + ICING_ASSIGN_OR_RETURN(std::unique_ptr<Tokenizer::Iterator> iterator, + tokenizer->Tokenize(suggestion_spec.prefix())); + + // If there are previous tokens, they are prepended to the suggestion, + // separated by spaces. + std::string last_token; + int token_start_pos; + while (iterator->Advance()) { + Token token = iterator->GetToken(); + last_token = token.text; + token_start_pos = token.text.data() - suggestion_spec.prefix().c_str(); + } + + // If the position of the last token is not the end of the prefix, it means + // there should be some operator tokens after it and are ignored by the + // tokenizer. + bool is_last_token = token_start_pos + last_token.length() >= + suggestion_spec.prefix().length(); + + if (!is_last_token || last_token.empty()) { + // We don't have a valid last token, return early. 
+ return std::vector<TermMetadata>(); + } + + std::string query_prefix = + suggestion_spec.prefix().substr(0, token_start_pos); + // Run suggestion based on given SuggestionSpec. + // Normalize token text to lowercase since all tokens in the lexicon are + // lowercase. + ICING_ASSIGN_OR_RETURN( + std::vector<TermMetadata> terms, + index_.FindTermsByPrefix(normalizer_.NormalizeTerm(last_token), + namespace_ids, suggestion_spec.num_to_return())); + + for (TermMetadata& term : terms) { + term.content = query_prefix + term.content; + } + return terms; +} + +SuggestionProcessor::SuggestionProcessor( + Index* index, const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer) + : index_(*index), + language_segmenter_(*language_segmenter), + normalizer_(*normalizer) {} + +} // namespace lib +} // namespace icing
\ No newline at end of file diff --git a/icing/query/suggestion-processor.h b/icing/query/suggestion-processor.h new file mode 100644 index 0000000..b10dc84 --- /dev/null +++ b/icing/query/suggestion-processor.h @@ -0,0 +1,68 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_QUERY_SUGGESTION_PROCESSOR_H_ +#define ICING_QUERY_SUGGESTION_PROCESSOR_H_ +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/index/index.h" +#include "icing/proto/search.pb.h" +#include "icing/tokenization/language-segmenter.h" +#include "icing/transform/normalizer.h" + +namespace icing { +namespace lib { + +// Processes SuggestionSpecProtos and retrieves the specified TermMetadata that +// satisfies the prefix and its restrictions. This also performs ranking, and +// returns TermMetadata ordered by their hit count. +class SuggestionProcessor { + public: + // Factory function to create a SuggestionProcessor which does not take + // ownership of any input components, and all pointers must refer to valid + // objects that outlive the created SuggestionProcessor instance. + // + // Returns: + // A SuggestionProcessor on success + // FAILED_PRECONDITION if any of the pointers is null. 
+ static libtextclassifier3::StatusOr<std::unique_ptr<SuggestionProcessor>> + Create(Index* index, const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer); + + // Query suggestions based on the given SuggestionSpecProto. + // + // Returns: + // On success, + // - One vector that represents the entire TermMetadata + // INTERNAL_ERROR on all other errors + libtextclassifier3::StatusOr<std::vector<TermMetadata>> QuerySuggestions( + const SuggestionSpecProto& suggestion_spec, + const std::vector<NamespaceId>& namespace_ids); + + private: + explicit SuggestionProcessor(Index* index, + const LanguageSegmenter* language_segmenter, + const Normalizer* normalizer); + + // Not const because we could modify/sort the TermMetaData buffer in the lite + // index. + Index& index_; + const LanguageSegmenter& language_segmenter_; + const Normalizer& normalizer_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_QUERY_SUGGESTION_PROCESSOR_H_ diff --git a/icing/query/suggestion-processor_test.cc b/icing/query/suggestion-processor_test.cc new file mode 100644 index 0000000..5e62277 --- /dev/null +++ b/icing/query/suggestion-processor_test.cc @@ -0,0 +1,324 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/query/suggestion-processor.h" + +#include "gmock/gmock.h" +#include "icing/helpers/icu/icu-data-file-helper.h" +#include "icing/store/document-store.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/jni-test-helpers.h" +#include "icing/testing/test-data.h" +#include "icing/testing/tmp-directory.h" +#include "icing/tokenization/language-segmenter-factory.h" +#include "icing/transform/normalizer-factory.h" +#include "unicode/uloc.h" + +namespace icing { +namespace lib { + +namespace { + +using ::testing::IsEmpty; +using ::testing::Test; + +class SuggestionProcessorTest : public Test { + protected: + SuggestionProcessorTest() + : test_dir_(GetTestTempDir() + "/icing"), + store_dir_(test_dir_ + "/store"), + index_dir_(test_dir_ + "/index") {} + + void SetUp() override { + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(index_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(store_dir_.c_str()); + + if (!IsCfStringTokenization() && !IsReverseJniTokenization()) { + // If we've specified using the reverse-JNI method for segmentation (i.e. + // not ICU), then we won't have the ICU data file included to set up. + // Technically, we could choose to use reverse-JNI for segmentation AND + // include an ICU data file, but that seems unlikely and our current BUILD + // setup doesn't do this. + ICING_ASSERT_OK( + // File generated via icu_data_file rule in //icing/BUILD. 
+ icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + Index::Options options(index_dir_, + /*index_merge_size=*/1024 * 1024); + ICING_ASSERT_OK_AND_ASSIGN( + index_, Index::Create(options, &filesystem_, &icing_filesystem_)); + + language_segmenter_factory::SegmenterOptions segmenter_options( + ULOC_US, jni_cache_.get()); + ICING_ASSERT_OK_AND_ASSIGN( + language_segmenter_, + language_segmenter_factory::Create(segmenter_options)); + + ICING_ASSERT_OK_AND_ASSIGN(normalizer_, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem_, store_dir_, &fake_clock_, + schema_store_.get())); + document_store_ = std::move(create_result.document_store); + } + + libtextclassifier3::Status AddTokenToIndex( + DocumentId document_id, SectionId section_id, + TermMatchType::Code term_match_type, const std::string& token) { + Index::Editor editor = index_->Edit(document_id, section_id, + term_match_type, /*namespace_id=*/0); + auto status = editor.BufferTerm(token.c_str()); + return status.ok() ? 
editor.IndexAllBufferedTerms() : status; + } + + void TearDown() override { + document_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + Filesystem filesystem_; + const std::string test_dir_; + const std::string store_dir_; + std::unique_ptr<Index> index_; + std::unique_ptr<LanguageSegmenter> language_segmenter_; + std::unique_ptr<Normalizer> normalizer_; + std::unique_ptr<DocumentStore> document_store_; + std::unique_ptr<SchemaStore> schema_store_; + std::unique_ptr<const JniCache> jni_cache_ = GetTestJniCache(); + FakeClock fake_clock_; + + private: + IcingFilesystem icing_filesystem_; + const std::string index_dir_; +}; + +constexpr DocumentId kDocumentId0 = 0; +constexpr SectionId kSectionId2 = 2; + +TEST_F(SuggestionProcessorTest, PrependedPrefixTokenTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix( + "prefix token should be prepended to the suggestion f"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, + "prefix token should be prepended to the suggestion foo"); +} + +TEST_F(SuggestionProcessorTest, NonExistentPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("nonExistTerm"); + suggestion_spec.set_num_to_return(10); 
+ + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, PrefixTrailingSpaceTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f "); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, NormalizePrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("F"); + suggestion_spec.set_num_to_return(10); + ICING_ASSERT_OK_AND_ASSIGN( + std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("fO"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("Fo"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); + + suggestion_spec.set_prefix("FO"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, 
suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms.at(0).content, "foo"); +} + +TEST_F(SuggestionProcessorTest, OrOperatorPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "original"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f OR"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + + // Last Operator token will be used to query suggestion + EXPECT_THAT(terms.at(0).content, "f original"); +} + +TEST_F(SuggestionProcessorTest, ParenthesesOperatorPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("{f}"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("[f]"); + ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("(f)"); + ICING_ASSERT_OK_AND_ASSIGN(terms, suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, 
IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, OtherSpecialPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "foo"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("f:"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); + + suggestion_spec.set_prefix("f-"); + ICING_ASSERT_OK_AND_ASSIGN( + terms, suggestion_processor->QuerySuggestions(suggestion_spec, + /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); +} + +TEST_F(SuggestionProcessorTest, InvalidPrefixTest) { + ASSERT_THAT(AddTokenToIndex(kDocumentId0, kSectionId2, + TermMatchType::EXACT_ONLY, "original"), + IsOk()); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SuggestionProcessor> suggestion_processor, + SuggestionProcessor::Create(index_.get(), language_segmenter_.get(), + normalizer_.get())); + + SuggestionSpecProto suggestion_spec; + suggestion_spec.set_prefix("OR OR - :"); + suggestion_spec.set_num_to_return(10); + + ICING_ASSERT_OK_AND_ASSIGN(std::vector<TermMetadata> terms, + suggestion_processor->QuerySuggestions( + suggestion_spec, /*namespace_ids=*/{})); + EXPECT_THAT(terms, IsEmpty()); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/result/snippet-retriever.cc b/icing/result/snippet-retriever.cc index 2a138ec..c46762e 100644 --- a/icing/result/snippet-retriever.cc +++ b/icing/result/snippet-retriever.cc @@ -78,7 +78,17 @@ inline std::string AddIndexToPath(int values_size, int index, class TokenMatcher { public: virtual ~TokenMatcher() = default; - virtual bool Matches(Token token) const = 0; + + // Returns a 
CharacterIterator pointing just past the end of the substring in + // token.text that matches a query term. Note that the utf* indices will be + // in relation to token.text's start. + // + // If there is no match, then it will construct a CharacterIterator with all + // of its indices set to -1. + // + // Ex. With an exact matcher, query terms=["foo","bar"] and token.text="bar", + // Matches will return a CharacterIterator(u8:3, u16:3, u32:3). + virtual CharacterIterator Matches(Token token) const = 0; }; class TokenMatcherExact : public TokenMatcher { @@ -91,10 +101,17 @@ class TokenMatcherExact : public TokenMatcher { restricted_query_terms_(restricted_query_terms), normalizer_(normalizer) {} - bool Matches(Token token) const override { + CharacterIterator Matches(Token token) const override { std::string s = normalizer_.NormalizeTerm(token.text); - return (unrestricted_query_terms_.count(s) > 0) || - (restricted_query_terms_.count(s) > 0); + auto itr = unrestricted_query_terms_.find(s); + if (itr == unrestricted_query_terms_.end()) { + itr = restricted_query_terms_.find(s); + } + if (itr != unrestricted_query_terms_.end() && + itr != restricted_query_terms_.end()) { + return normalizer_.FindNormalizedMatchEndPosition(token.text, *itr); + } + return CharacterIterator(token.text, -1, -1, -1); } private: @@ -113,22 +130,23 @@ class TokenMatcherPrefix : public TokenMatcher { restricted_query_terms_(restricted_query_terms), normalizer_(normalizer) {} - bool Matches(Token token) const override { + CharacterIterator Matches(Token token) const override { std::string s = normalizer_.NormalizeTerm(token.text); - if (std::any_of(unrestricted_query_terms_.begin(), - unrestricted_query_terms_.end(), - [&s](const std::string& term) { - return term.length() <= s.length() && - s.compare(0, term.length(), term) == 0; - })) { - return true; + for (const std::string& query_term : unrestricted_query_terms_) { + if (query_term.length() <= s.length() && + s.compare(0, 
query_term.length(), query_term) == 0) { + return normalizer_.FindNormalizedMatchEndPosition(token.text, + query_term); + } + } + for (const std::string& query_term : restricted_query_terms_) { + if (query_term.length() <= s.length() && + s.compare(0, query_term.length(), query_term) == 0) { + return normalizer_.FindNormalizedMatchEndPosition(token.text, + query_term); + } } - return std::any_of(restricted_query_terms_.begin(), - restricted_query_terms_.end(), - [&s](const std::string& term) { - return term.length() <= s.length() && - s.compare(0, term.length(), term) == 0; - }); + return CharacterIterator(token.text, -1, -1, -1); } private: @@ -364,7 +382,10 @@ void GetEntriesFromProperty(const PropertyProto* current_property, CharacterIterator char_iterator(value); while (iterator->Advance()) { Token token = iterator->GetToken(); - if (matcher->Matches(token)) { + CharacterIterator submatch_end = matcher->Matches(token); + // If the token matched a query term, then submatch_end will point to an + // actual position within token.text. + if (submatch_end.utf8_index() != -1) { if (!char_iterator.AdvanceToUtf8(token.text.data() - value.data())) { // We can't get the char_iterator to a valid position, so there's no // way for us to provide valid utf-16 indices. There's nothing more we @@ -393,7 +414,15 @@ void GetEntriesFromProperty(const PropertyProto* current_property, } } SnippetMatchProto match = std::move(match_or).ValueOrDie(); + // submatch_end refers to a position *within* token.text. + // This, conveniently enough, means that index that submatch_end points + // to is the length of the submatch (because the submatch starts at 0 in + // token.text). + match.set_submatch_byte_length(submatch_end.utf8_index()); + match.set_submatch_utf16_length(submatch_end.utf16_index()); + // Add the values for the submatch. 
snippet_entry.mutable_snippet_matches()->Add(std::move(match)); + if (--match_options->max_matches_remaining <= 0) { *snippet_proto->add_entries() = std::move(snippet_entry); return; diff --git a/icing/result/snippet-retriever_test.cc b/icing/result/snippet-retriever_test.cc index e7988ae..f811941 100644 --- a/icing/result/snippet-retriever_test.cc +++ b/icing/result/snippet-retriever_test.cc @@ -43,6 +43,7 @@ #include "icing/testing/tmp-directory.h" #include "icing/tokenization/language-segmenter-factory.h" #include "icing/tokenization/language-segmenter.h" +#include "icing/transform/map/map-normalizer.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" #include "unicode/uloc.h" @@ -690,6 +691,7 @@ TEST_F(SnippetRetrieverTest, PrefixSnippeting) { EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("subject foo")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("f")); } TEST_F(SnippetRetrieverTest, ExactSnippeting) { @@ -733,6 +735,7 @@ TEST_F(SnippetRetrieverTest, SimpleSnippetingNoWindowing) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("foo")); } TEST_F(SnippetRetrieverTest, SnippetingMultipleMatches) { @@ -779,12 +782,15 @@ TEST_F(SnippetRetrieverTest, SnippetingMultipleMatches) { "we need to begin considering our options regarding body bar.")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("foo", "bar")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), + ElementsAre("foo", "bar")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("subject")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, 
snippet.entries(1)), ElementsAre("subject foo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("foo")); } TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesSectionRestrict) { @@ -834,6 +840,8 @@ TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesSectionRestrict) { "we need to begin considering our options regarding body bar.")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("foo", "bar")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), + ElementsAre("foo", "bar")); } TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesSectionRestrictedTerm) { @@ -884,12 +892,16 @@ TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesSectionRestrictedTerm) { "Concerning the subject of foo, we need to begin considering our")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("subject", "foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), + ElementsAre("subject", "foo")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("subject")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(1)), ElementsAre("subject foo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("subject")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), + ElementsAre("subject")); } TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesOneMatchPerProperty) { @@ -933,12 +945,14 @@ TEST_F(SnippetRetrieverTest, SnippetingMultipleMatchesOneMatchPerProperty) { ElementsAre( "Concerning the subject of foo, we need to begin considering our")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("foo")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("subject")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, 
snippet.entries(1)), ElementsAre("subject foo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("foo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("foo")); } TEST_F(SnippetRetrieverTest, PrefixSnippetingNormalization) { @@ -960,6 +974,7 @@ TEST_F(SnippetRetrieverTest, PrefixSnippetingNormalization) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("MDI team")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("MDI")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("MD")); } TEST_F(SnippetRetrieverTest, ExactSnippetingNormalization) { @@ -983,6 +998,9 @@ TEST_F(SnippetRetrieverTest, ExactSnippetingNormalization) { EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("Some members are in Zürich.")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("Zürich")); + + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), + ElementsAre("Zürich")); } TEST_F(SnippetRetrieverTest, SnippetingTestOneLevel) { @@ -1043,11 +1061,13 @@ TEST_F(SnippetRetrieverTest, SnippetingTestOneLevel) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("X[3]")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetPropertyPaths(snippet), ElementsAre("X[1]", "X[3]", "Y[1]", "Y[3]", "Z[1]", "Z[3]")); @@ -1144,11 +1164,13 @@ TEST_F(SnippetRetrieverTest, 
SnippetingTestMultiLevel) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("A.X[3]")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT( GetPropertyPaths(snippet), @@ -1251,11 +1273,13 @@ TEST_F(SnippetRetrieverTest, SnippetingTestMultiLevelRepeated) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(snippet.entries(1).property_name(), Eq("A[0].X[3]")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetPropertyPaths(snippet), ElementsAre("A[0].X[1]", "A[0].X[3]", "A[1].X[1]", "A[1].X[3]", @@ -1356,11 +1380,13 @@ TEST_F(SnippetRetrieverTest, SnippetingTestMultiLevelSingleValue) { GetString(&document, snippet.entries(0).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(0)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(0)), ElementsAre("polo")); EXPECT_THAT(snippet.entries(1).property_name(), 
Eq("A[1].X")); content = GetString(&document, snippet.entries(1).property_name()); EXPECT_THAT(GetWindows(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT(GetMatches(content, snippet.entries(1)), ElementsAre("polo")); + EXPECT_THAT(GetSubMatches(content, snippet.entries(1)), ElementsAre("polo")); EXPECT_THAT( GetPropertyPaths(snippet), @@ -1404,10 +1430,12 @@ TEST_F(SnippetRetrieverTest, CJKSnippetMatchTest) { // Ensure that the match is correct. EXPECT_THAT(GetMatches(content, *entry), ElementsAre("走路")); + EXPECT_THAT(GetSubMatches(content, *entry), ElementsAre("走")); // Ensure that the utf-16 values are also as expected EXPECT_THAT(match_proto.exact_match_utf16_position(), Eq(3)); EXPECT_THAT(match_proto.exact_match_utf16_length(), Eq(2)); + EXPECT_THAT(match_proto.submatch_utf16_length(), Eq(1)); } TEST_F(SnippetRetrieverTest, CJKSnippetWindowTest) { @@ -1507,10 +1535,12 @@ TEST_F(SnippetRetrieverTest, Utf16MultiCodeUnitSnippetMatchTest) { // Ensure that the match is correct. 
EXPECT_THAT(GetMatches(content, *entry), ElementsAre("𐀂𐀃")); + EXPECT_THAT(GetSubMatches(content, *entry), ElementsAre("𐀂")); // Ensure that the utf-16 values are also as expected EXPECT_THAT(match_proto.exact_match_utf16_position(), Eq(5)); EXPECT_THAT(match_proto.exact_match_utf16_length(), Eq(4)); + EXPECT_THAT(match_proto.submatch_utf16_length(), Eq(2)); } TEST_F(SnippetRetrieverTest, Utf16MultiCodeUnitWindowTest) { diff --git a/icing/schema/schema-store.cc b/icing/schema/schema-store.cc index 3307638..67528ab 100644 --- a/icing/schema/schema-store.cc +++ b/icing/schema/schema-store.cc @@ -491,5 +491,10 @@ SchemaStoreStorageInfoProto SchemaStore::GetStorageInfo() const { return storage_info; } +libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*> +SchemaStore::GetSectionMetadata(const std::string& schema_type) const { + return section_manager_->GetMetadataList(schema_type); +} + } // namespace lib } // namespace icing diff --git a/icing/schema/schema-store.h b/icing/schema/schema-store.h index b9be6c0..6b6528d 100644 --- a/icing/schema/schema-store.h +++ b/icing/schema/schema-store.h @@ -246,6 +246,12 @@ class SchemaStore { // INTERNAL_ERROR on compute error libtextclassifier3::StatusOr<Crc32> ComputeChecksum() const; + // Returns: + // - On success, the section metadata list for the specified schema type + // - NOT_FOUND if the schema type is not present in the schema + libtextclassifier3::StatusOr<const std::vector<SectionMetadata>*> + GetSectionMetadata(const std::string& schema_type) const; + // Calculates the StorageInfo for the Schema Store. 
// // If an IO error occurs while trying to calculate the value for a field, then diff --git a/icing/scoring/bm25f-calculator.cc b/icing/scoring/bm25f-calculator.cc index 4822d7f..28d385e 100644 --- a/icing/scoring/bm25f-calculator.cc +++ b/icing/scoring/bm25f-calculator.cc @@ -26,6 +26,7 @@ #include "icing/store/corpus-associated-scoring-data.h" #include "icing/store/corpus-id.h" #include "icing/store/document-associated-score-data.h" +#include "icing/store/document-filter-data.h" #include "icing/store/document-id.h" namespace icing { @@ -42,8 +43,11 @@ constexpr float k1_ = 1.2f; constexpr float b_ = 0.7f; // TODO(b/158603900): add tests for Bm25fCalculator -Bm25fCalculator::Bm25fCalculator(const DocumentStore* document_store) - : document_store_(document_store) {} +Bm25fCalculator::Bm25fCalculator( + const DocumentStore* document_store, + std::unique_ptr<SectionWeights> section_weights) + : document_store_(document_store), + section_weights_(std::move(section_weights)) {} // During initialization, Bm25fCalculator iterates through // hit-iterators for each query term to pre-compute n(q_i) for each corpus under @@ -121,9 +125,9 @@ float Bm25fCalculator::ComputeScore(const DocHitInfoIterator* query_it, // Compute inverse document frequency (IDF) weight for query term in the given // corpus, and cache it in the map. // -// N - n(q_i) + 0.5 -// IDF(q_i) = log(1 + ------------------) -// n(q_i) + 0.5 +// N - n(q_i) + 0.5 +// IDF(q_i) = ln(1 + ------------------) +// n(q_i) + 0.5 // // where N is the number of documents in the corpus, and n(q_i) is the number // of documents in the corpus containing the query term q_i. @@ -149,7 +153,7 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, uint32_t num_docs = csdata.num_docs(); uint32_t nqi = corpus_nqi_map_[corpus_term_info.value]; float idf = - nqi != 0 ? log(1.0f + (num_docs - nqi + 0.5f) / (nqi - 0.5f)) : 0.0f; + nqi != 0 ? 
log(1.0f + (num_docs - nqi + 0.5f) / (nqi + 0.5f)) : 0.0f; corpus_idf_map_.insert({corpus_term_info.value, idf}); ICING_VLOG(1) << IcingStringUtil::StringPrintf( "corpus_id:%d term:%s N:%d nqi:%d idf:%f", corpus_id, @@ -158,6 +162,11 @@ float Bm25fCalculator::GetCorpusIdfWeightForTerm(std::string_view term, } // Get per corpus average document length and cache the result in the map. +// The average doc length is calculated as: +// +// total_tokens_in_corpus +// Avg Doc Length = ------------------------- +// num_docs_in_corpus + 1 float Bm25fCalculator::GetCorpusAvgDocLength(CorpusId corpus_id) { auto iter = corpus_avgdl_map_.find(corpus_id); if (iter != corpus_avgdl_map_.end()) { @@ -191,8 +200,8 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( const DocumentAssociatedScoreData& data) { uint32_t dl = data.length_in_tokens(); float avgdl = GetCorpusAvgDocLength(data.corpus_id()); - float f_q = - ComputeTermFrequencyForMatchedSections(data.corpus_id(), term_match_info); + float f_q = ComputeTermFrequencyForMatchedSections( + data.corpus_id(), term_match_info, hit_info.document_id()); float normalized_tf = f_q * (k1_ + 1) / (f_q + k1_ * (1 - b_ + b_ * dl / avgdl)); @@ -202,23 +211,41 @@ float Bm25fCalculator::ComputedNormalizedTermFrequency( return normalized_tf; } -// Note: once we support section weights, we should update this function to -// compute the weighted term frequency. 
float Bm25fCalculator::ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo& term_match_info) const { + CorpusId corpus_id, const TermMatchInfo& term_match_info, + DocumentId document_id) const { float sum = 0.0f; SectionIdMask sections = term_match_info.section_ids_mask; + SchemaTypeId schema_type_id = GetSchemaTypeId(document_id); + while (sections != 0) { SectionId section_id = __builtin_ctz(sections); sections &= ~(1u << section_id); Hit::TermFrequency tf = term_match_info.term_frequencies[section_id]; + double weighted_tf = tf * section_weights_->GetNormalizedSectionWeight( + schema_type_id, section_id); if (tf != Hit::kNoTermFrequency) { - sum += tf; + sum += weighted_tf; } } return sum; } +SchemaTypeId Bm25fCalculator::GetSchemaTypeId(DocumentId document_id) const { + auto filter_data_or = document_store_->GetDocumentFilterData(document_id); + if (!filter_data_or.ok()) { + // This should never happen. The only failure case for + // GetDocumentFilterData is if the document_id is outside of the range of + // allocated document_ids, which shouldn't be possible since we're getting + // this document_id from the posting lists. 
+ ICING_LOG(WARNING) << IcingStringUtil::StringPrintf( + "No document filter data for document [%d]", document_id); + return kInvalidSchemaTypeId; + } + DocumentFilterData data = filter_data_or.ValueOrDie(); + return data.schema_type_id(); +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/bm25f-calculator.h b/icing/scoring/bm25f-calculator.h index 91b4f24..05009d8 100644 --- a/icing/scoring/bm25f-calculator.h +++ b/icing/scoring/bm25f-calculator.h @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/legacy/index/icing-bit-util.h" +#include "icing/scoring/section-weights.h" #include "icing/store/corpus-id.h" #include "icing/store/document-store.h" @@ -62,7 +63,8 @@ namespace lib { // see: glossary/bm25 class Bm25fCalculator { public: - explicit Bm25fCalculator(const DocumentStore *document_store_); + explicit Bm25fCalculator(const DocumentStore *document_store_, + std::unique_ptr<SectionWeights> section_weights_); // Precompute and cache statistics relevant to BM25F. // Populates term_id_map_ and corpus_nqi_map_ for use while scoring other @@ -108,18 +110,43 @@ class Bm25fCalculator { } }; + // Returns idf weight for the term and provided corpus. float GetCorpusIdfWeightForTerm(std::string_view term, CorpusId corpus_id); + + // Returns the average document length for the corpus. The average is + // calculated as the sum of tokens in the corpus' documents over the total + // number of documents plus one. float GetCorpusAvgDocLength(CorpusId corpus_id); + + // Returns the normalized term frequency for the term match and document hit. + // This normalizes the term frequency by applying smoothing parameters and + // factoring document length. float ComputedNormalizedTermFrequency( const TermMatchInfo &term_match_info, const DocHitInfo &hit_info, const DocumentAssociatedScoreData &data); + + // Returns the weighted term frequency for the term match and document. 
For + // each section the term is present, we scale the term frequency by its + // section weight. We return the sum of the weighted term frequencies over all + // sections. float ComputeTermFrequencyForMatchedSections( - CorpusId corpus_id, const TermMatchInfo &term_match_info) const; + CorpusId corpus_id, const TermMatchInfo &term_match_info, + DocumentId document_id) const; + // Returns the schema type id for the document by retrieving it from the + // DocumentFilterData. + SchemaTypeId GetSchemaTypeId(DocumentId document_id) const; + + // Clears cached scoring data and prepares the calculator for a new scoring + // run. void Clear(); const DocumentStore *document_store_; // Does not own. + // Used for accessing normalized section weights when computing the weighted + // term frequency. + std::unique_ptr<SectionWeights> section_weights_; + // Map from query term to compact term ID. // Necessary as a key to the other maps. // The use of the string_view as key here means that the query_term_iterators @@ -130,7 +157,6 @@ class Bm25fCalculator { // Necessary to calculate the normalized term frequency. // This information is cached in the DocumentStore::CorpusScoreCache std::unordered_map<CorpusId, float> corpus_avgdl_map_; - // Map from <corpus ID, term ID> to number of documents containing term q_i, // called n(q_i). // Necessary to calculate IDF(q_i) (inverse document frequency). diff --git a/icing/scoring/ranker.cc b/icing/scoring/ranker.cc index fecee82..117f44c 100644 --- a/icing/scoring/ranker.cc +++ b/icing/scoring/ranker.cc @@ -32,6 +32,7 @@ namespace { // Helper function to wrap the heapify algorithm, it heapifies the target // subtree node in place. +// TODO(b/152934343) refactor the heapify function and making it into a class. void Heapify( std::vector<ScoredDocumentHit>* scored_document_hits, int target_subtree_root_index, @@ -71,6 +72,80 @@ void Heapify( } } +// Heapify the given term vector from top to bottom. 
Call it after add or +// replace an element at the front of the vector. +void HeapifyTermDown(std::vector<TermMetadata>& scored_terms, + int target_subtree_root_index) { + int heap_size = scored_terms.size(); + if (target_subtree_root_index >= heap_size) { + return; + } + + // Initializes subtree root as the current minimum node. + int min = target_subtree_root_index; + // If we represent a heap in an array/vector, indices of left and right + // children can be calculated as such. + const int left = target_subtree_root_index * 2 + 1; + const int right = target_subtree_root_index * 2 + 2; + + // If left child is smaller than current minimum. + if (left < heap_size && + scored_terms.at(left).hit_count < scored_terms.at(min).hit_count) { + min = left; + } + + // If right child is smaller than current minimum. + if (right < heap_size && + scored_terms.at(right).hit_count < scored_terms.at(min).hit_count) { + min = right; + } + + // If the minimum is not the subtree root, swap and continue heapifying the + // lower level subtree. + if (min != target_subtree_root_index) { + std::swap(scored_terms.at(min), + scored_terms.at(target_subtree_root_index)); + HeapifyTermDown(scored_terms, min); + } +} + +// Heapify the given term vector from bottom to top. Call it after add an +// element at the end of the vector. +void HeapifyTermUp(std::vector<TermMetadata>& scored_terms, + int target_subtree_child_index) { + // If we represent a heap in an array/vector, indices of root can be + // calculated as such. 
+ const int root = (target_subtree_child_index + 1) / 2 - 1; + + // If the current child is smaller than the root, swap and continue heapifying + // the upper level subtree + if (root >= 0 && scored_terms.at(target_subtree_child_index).hit_count < + scored_terms.at(root).hit_count) { + std::swap(scored_terms.at(root), + scored_terms.at(target_subtree_child_index)); + HeapifyTermUp(scored_terms, root); + } +} + +TermMetadata PopRootTerm(std::vector<TermMetadata>& scored_terms) { + if (scored_terms.empty()) { + // Return an invalid TermMetadata as a sentinel value. + return TermMetadata(/*content_in=*/"", /*hit_count_in=*/-1); + } + + // Steps to extract root from heap: + // 1. copy out root + TermMetadata root = scored_terms.at(0); + const size_t last_node_index = scored_terms.size() - 1; + // 2. swap root and the last node + std::swap(scored_terms.at(0), scored_terms.at(last_node_index)); + // 3. remove last node + scored_terms.pop_back(); + // 4. heapify root + HeapifyTermDown(scored_terms, /*target_subtree_root_index=*/0); + return root; +} + // Helper function to extract the root from the heap. The heap structure will be // maintained. // @@ -115,6 +190,19 @@ void BuildHeapInPlace( } } +void PushToTermHeap(TermMetadata term, int number_to_return, + std::vector<TermMetadata>& scored_terms_heap) { + if (scored_terms_heap.size() < number_to_return) { + scored_terms_heap.push_back(std::move(term)); + // We insert at end, so we should heapify bottom up. + HeapifyTermUp(scored_terms_heap, scored_terms_heap.size() - 1); + } else if (scored_terms_heap.at(0).hit_count < term.hit_count) { + scored_terms_heap.at(0) = std::move(term); + // We insert at root, so we should heapify top down. 
+ HeapifyTermDown(scored_terms_heap, /*target_subtree_root_index=*/0); + } +} + std::vector<ScoredDocumentHit> PopTopResultsFromHeap( std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results, const ScoredDocumentHitComparator& scored_document_hit_comparator) { @@ -134,5 +222,15 @@ std::vector<ScoredDocumentHit> PopTopResultsFromHeap( return scored_document_hit_result; } +std::vector<TermMetadata> PopAllTermsFromHeap( + std::vector<TermMetadata>& scored_terms_heap) { + std::vector<TermMetadata> top_term_result; + top_term_result.reserve(scored_terms_heap.size()); + while (!scored_terms_heap.empty()) { + top_term_result.push_back(PopRootTerm(scored_terms_heap)); + } + return top_term_result; +} + } // namespace lib } // namespace icing diff --git a/icing/scoring/ranker.h b/icing/scoring/ranker.h index 785c133..81838f3 100644 --- a/icing/scoring/ranker.h +++ b/icing/scoring/ranker.h @@ -17,6 +17,7 @@ #include <vector> +#include "icing/index/term-metadata.h" #include "icing/scoring/scored-document-hit.h" // Provides functionality to get the top N results from an unsorted vector. @@ -39,6 +40,18 @@ std::vector<ScoredDocumentHit> PopTopResultsFromHeap( std::vector<ScoredDocumentHit>* scored_document_hits_heap, int num_results, const ScoredDocumentHitComparator& scored_document_hit_comparator); +// The heap is a min-heap. So that we can avoid some push operations by +// comparing to the root term, and only pushing if greater than root. The time +// complexity for a single push is O(lgK) which K is the number_to_return. +// REQUIRED: scored_terms_heap is not null. +void PushToTermHeap(TermMetadata term, int number_to_return, + std::vector<TermMetadata>& scored_terms_heap); + +// Return all terms from the given terms heap. And since the heap is a min-heap, +// the output vector will be increasing order. +// REQUIRED: scored_terms_heap is not null. 
+std::vector<TermMetadata> PopAllTermsFromHeap( + std::vector<TermMetadata>& scored_terms_heap); } // namespace lib } // namespace icing diff --git a/icing/scoring/score-and-rank_benchmark.cc b/icing/scoring/score-and-rank_benchmark.cc index e940e98..cc1d995 100644 --- a/icing/scoring/score-and-rank_benchmark.cc +++ b/icing/scoring/score-and-rank_benchmark.cc @@ -117,7 +117,8 @@ void BM_ScoreAndRankDocumentHitsByDocumentScore(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -220,7 +221,8 @@ void BM_ScoreAndRankDocumentHitsByCreationTime(benchmark::State& state) { ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -322,7 +324,8 @@ void BM_ScoreAndRankDocumentHitsNoScoring(benchmark::State& state) { scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::NONE); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(scoring_spec, document_store.get())); + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); int num_to_score = state.range(0); int num_of_documents = state.range(1); @@ -390,6 +393,122 @@ BENCHMARK(BM_ScoreAndRankDocumentHitsNoScoring) ->ArgPair(10000, 18000) ->ArgPair(10000, 20000); +void BM_ScoreAndRankDocumentHitsByRelevanceScoring(benchmark::State& state) { + const std::string base_dir = 
GetTestTempDir() + "/score_and_rank_benchmark"; + const std::string document_store_dir = base_dir + "/document_store"; + const std::string schema_store_dir = base_dir + "/schema_store"; + + // Creates file directories + Filesystem filesystem; + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); + filesystem.CreateDirectoryRecursively(document_store_dir.c_str()); + filesystem.CreateDirectoryRecursively(schema_store_dir.c_str()); + + Clock clock; + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SchemaStore> schema_store, + SchemaStore::Create(&filesystem, base_dir, &clock)); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentStore::CreateResult create_result, + DocumentStore::Create(&filesystem, document_store_dir, &clock, + schema_store.get())); + std::unique_ptr<DocumentStore> document_store = + std::move(create_result.document_store); + + ICING_ASSERT_OK(schema_store->SetSchema(CreateSchemaWithEmailType())); + + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(scoring_spec, document_store.get(), + schema_store.get())); + + int num_to_score = state.range(0); + int num_of_documents = state.range(1); + + std::mt19937 random_generator; + std::uniform_int_distribution<int> distribution( + 1, std::numeric_limits<int>::max()); + + SectionId section_id = 0; + SectionIdMask section_id_mask = 1U << section_id; + + // Puts documents into document store + std::vector<DocHitInfo> doc_hit_infos; + for (int i = 0; i < num_of_documents; i++) { + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id, + document_store->Put(CreateEmailDocument( + /*id=*/i, /*document_score=*/1, + /*creation_timestamp_ms=*/1), + /*num_tokens=*/10)); + DocHitInfo doc_hit = DocHitInfo(document_id, section_id_mask); + // Set five matches for term "foo" for each document hit. 
+ doc_hit.UpdateSection(section_id, /*hit_term_frequency=*/5); + doc_hit_infos.push_back(doc_hit); + } + + ScoredDocumentHitComparator scored_document_hit_comparator( + /*is_descending=*/true); + + for (auto _ : state) { + // Creates a dummy DocHitInfoIterator with results, we need to pause the + // timer here so that the cost of copying test data is not included. + state.PauseTiming(); + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + // Create a query term iterator that assigns the document hits to term + // "foo". + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + state.ResumeTiming(); + + std::vector<ScoredDocumentHit> scored_document_hits = + scoring_processor->Score(std::move(doc_hit_info_iterator), num_to_score, + &query_term_iterators); + + BuildHeapInPlace(&scored_document_hits, scored_document_hit_comparator); + // Ranks and gets the first page, 20 is a common page size + std::vector<ScoredDocumentHit> results = + PopTopResultsFromHeap(&scored_document_hits, /*num_results=*/20, + scored_document_hit_comparator); + } + + // Clean up + document_store.reset(); + schema_store.reset(); + filesystem.DeleteDirectoryRecursively(base_dir.c_str()); +} +BENCHMARK(BM_ScoreAndRankDocumentHitsByRelevanceScoring) + // num_to_score, num_of_documents in document store + ->ArgPair(1000, 30000) + ->ArgPair(3000, 30000) + ->ArgPair(5000, 30000) + ->ArgPair(7000, 30000) + ->ArgPair(9000, 30000) + ->ArgPair(11000, 30000) + ->ArgPair(13000, 30000) + ->ArgPair(15000, 30000) + ->ArgPair(17000, 30000) + ->ArgPair(19000, 30000) + ->ArgPair(21000, 30000) + ->ArgPair(23000, 30000) + ->ArgPair(25000, 30000) + ->ArgPair(27000, 30000) + ->ArgPair(29000, 30000) + // Starting from this line, we're trying to see if num_of_documents affects + // performance + 
->ArgPair(10000, 10000) + ->ArgPair(10000, 12000) + ->ArgPair(10000, 14000) + ->ArgPair(10000, 16000) + ->ArgPair(10000, 18000) + ->ArgPair(10000, 20000); + } // namespace } // namespace lib diff --git a/icing/scoring/scorer.cc b/icing/scoring/scorer.cc index a4734b4..5f33e66 100644 --- a/icing/scoring/scorer.cc +++ b/icing/scoring/scorer.cc @@ -22,6 +22,7 @@ #include "icing/index/iterator/doc-hit-info-iterator.h" #include "icing/proto/scoring.pb.h" #include "icing/scoring/bm25f-calculator.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/util/status-macros.h" @@ -156,11 +157,12 @@ class NoScorer : public Scorer { }; libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store) { + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); - switch (rank_by) { + switch (scoring_spec.rank_by()) { case ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE: return std::make_unique<DocumentScoreScorer>(document_store, default_score); @@ -168,7 +170,12 @@ libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( return std::make_unique<DocumentCreationTimestampScorer>(document_store, default_score); case ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE: { - auto bm25f_calculator = std::make_unique<Bm25fCalculator>(document_store); + ICING_ASSIGN_OR_RETURN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store, scoring_spec)); + + auto bm25f_calculator = std::make_unique<Bm25fCalculator>( + document_store, std::move(section_weights)); return std::make_unique<RelevanceScoreScorer>(std::move(bm25f_calculator), default_score); } @@ -183,8 +190,8 @@ 
libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Scorer::Create( case ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP: [[fallthrough]]; case ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP: - return std::make_unique<UsageScorer>(document_store, rank_by, - default_score); + return std::make_unique<UsageScorer>( + document_store, scoring_spec.rank_by(), default_score); case ScoringSpecProto::RankingStrategy::NONE: return std::make_unique<NoScorer>(default_score); } diff --git a/icing/scoring/scorer.h b/icing/scoring/scorer.h index a22db0f..abdd5ca 100644 --- a/icing/scoring/scorer.h +++ b/icing/scoring/scorer.h @@ -43,8 +43,8 @@ class Scorer { // FAILED_PRECONDITION on any null pointer input // INVALID_ARGUMENT if fails to create an instance static libtextclassifier3::StatusOr<std::unique_ptr<Scorer>> Create( - ScoringSpecProto::RankingStrategy::Code rank_by, double default_score, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, double default_score, + const DocumentStore* document_store, const SchemaStore* schema_store); // Returns a non-negative score of a document. 
The score can be a // document-associated score which comes from the DocumentProto directly, an diff --git a/icing/scoring/scorer_test.cc b/icing/scoring/scorer_test.cc index 8b89514..f22a31a 100644 --- a/icing/scoring/scorer_test.cc +++ b/icing/scoring/scorer_test.cc @@ -27,6 +27,7 @@ #include "icing/proto/scoring.pb.h" #include "icing/schema-builder.h" #include "icing/schema/schema-store.h" +#include "icing/scoring/section-weights.h" #include "icing/store/document-id.h" #include "icing/store/document-store.h" #include "icing/testing/common-matchers.h" @@ -91,6 +92,8 @@ class ScorerTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + const FakeClock& fake_clock1() { return fake_clock1_; } const FakeClock& fake_clock2() { return fake_clock2_; } @@ -121,17 +124,37 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScorerTest, CreationWithNullPointerShouldFail) { - EXPECT_THAT(Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, /*document_store=*/nullptr), - StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +ScoringSpecProto CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::Code ranking_strategy) { + ScoringSpecProto scoring_spec; + scoring_spec.set_rank_by(ranking_strategy); + return scoring_spec; +} + +TEST_F(ScorerTest, CreationWithNullDocumentStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScorerTest, CreationWithNullSchemaStoreShouldFail) { + EXPECT_THAT( + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, 
document_store(), + /*schema_store=*/nullptr), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentDoesntExist) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); // Non existent document id DocHitInfo docHitInfo = DocHitInfo(/*document_id_in=*/1); @@ -153,8 +176,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsDeleted) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -185,8 +209,9 @@ TEST_F(ScorerTest, ShouldGetDefaultScoreIfDocumentIsExpired) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); @@ -213,8 +238,9 @@ TEST_F(ScorerTest, ShouldGetDefaultDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/10, document_store(), 
schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(0)); @@ -235,8 +261,9 @@ TEST_F(ScorerTest, ShouldGetCorrectDocumentScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(5)); @@ -259,8 +286,9 @@ TEST_F(ScorerTest, QueryIteratorNullRelevanceScoreShouldReturnDefaultScore) { document_store()->Put(test_document)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE, - /*default_score=*/10, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE), + /*default_score=*/10, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer->GetScore(docHitInfo), Eq(10)); @@ -290,8 +318,9 @@ TEST_F(ScorerTest, ShouldGetCorrectCreationTimestampScore) { document_store()->Put(test_document2)); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::CREATION_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(document_id1); DocHitInfo docHitInfo2 = DocHitInfo(document_id2); @@ -316,16 +345,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType1) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -357,16 +389,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType2) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -398,16 +433,19 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageCountScoreForType3) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE1_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE2_COUNT), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create(ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::USAGE_TYPE3_COUNT), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -439,19 +477,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType1) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -499,19 +540,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType2) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -559,19 +603,22 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { // Create 3 scorers for 3 different usage types. 
ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer2, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE2_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE2_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer3, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE3_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE3_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); EXPECT_THAT(scorer1->GetScore(docHitInfo), Eq(0)); EXPECT_THAT(scorer2->GetScore(docHitInfo), Eq(0)); @@ -607,8 +654,9 @@ TEST_F(ScorerTest, ShouldGetCorrectUsageTimestampScoreForType3) { TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer, - Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/3, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/3, document_store(), schema_store())); DocHitInfo docHitInfo1 = DocHitInfo(/*document_id_in=*/0); DocHitInfo docHitInfo2 = DocHitInfo(/*document_id_in=*/1); @@ -618,8 +666,10 @@ TEST_F(ScorerTest, NoScorerShouldAlwaysReturnDefaultScore) { EXPECT_THAT(scorer->GetScore(docHitInfo3), 
Eq(3)); ICING_ASSERT_OK_AND_ASSIGN( - scorer, Scorer::Create(ScoringSpecProto::RankingStrategy::NONE, - /*default_score=*/111, document_store())); + scorer, + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy::NONE), + /*default_score=*/111, document_store(), schema_store())); docHitInfo1 = DocHitInfo(/*document_id_in=*/4); docHitInfo2 = DocHitInfo(/*document_id_in=*/5); @@ -643,9 +693,10 @@ TEST_F(ScorerTest, ShouldScaleUsageTimestampScoreForMaxTimestamp) { ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<Scorer> scorer1, - Scorer::Create( - ScoringSpecProto::RankingStrategy::USAGE_TYPE1_LAST_USED_TIMESTAMP, - /*default_score=*/0, document_store())); + Scorer::Create(CreateScoringSpecForRankingStrategy( + ScoringSpecProto::RankingStrategy:: + USAGE_TYPE1_LAST_USED_TIMESTAMP), + /*default_score=*/0, document_store(), schema_store())); DocHitInfo docHitInfo = DocHitInfo(document_id); // Create usage report for the maximum allowable timestamp. diff --git a/icing/scoring/scoring-processor.cc b/icing/scoring/scoring-processor.cc index 24480ef..e36f3bb 100644 --- a/icing/scoring/scoring-processor.cc +++ b/icing/scoring/scoring-processor.cc @@ -39,19 +39,20 @@ constexpr double kDefaultScoreInAscendingOrder = libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> ScoringProcessor::Create(const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store) { + const DocumentStore* document_store, + const SchemaStore* schema_store) { ICING_RETURN_ERROR_IF_NULL(document_store); + ICING_RETURN_ERROR_IF_NULL(schema_store); bool is_descending_order = scoring_spec.order_by() == ScoringSpecProto::Order::DESC; ICING_ASSIGN_OR_RETURN( std::unique_ptr<Scorer> scorer, - Scorer::Create(scoring_spec.rank_by(), + Scorer::Create(scoring_spec, is_descending_order ? kDefaultScoreInDescendingOrder : kDefaultScoreInAscendingOrder, - document_store)); - + document_store, schema_store)); // Using `new` to access a non-public constructor. 
return std::unique_ptr<ScoringProcessor>( new ScoringProcessor(std::move(scorer))); diff --git a/icing/scoring/scoring-processor.h b/icing/scoring/scoring-processor.h index 2289605..e7d09b1 100644 --- a/icing/scoring/scoring-processor.h +++ b/icing/scoring/scoring-processor.h @@ -40,8 +40,8 @@ class ScoringProcessor { // A ScoringProcessor on success // FAILED_PRECONDITION on any null pointer input static libtextclassifier3::StatusOr<std::unique_ptr<ScoringProcessor>> Create( - const ScoringSpecProto& scoring_spec, - const DocumentStore* document_store); + const ScoringSpecProto& scoring_spec, const DocumentStore* document_store, + const SchemaStore* schema_store); // Assigns scores to DocHitInfos from the given DocHitInfoIterator and returns // a vector of ScoredDocumentHits. The size of results is no more than diff --git a/icing/scoring/scoring-processor_test.cc b/icing/scoring/scoring-processor_test.cc index 125e2a7..7e5cb0f 100644 --- a/icing/scoring/scoring-processor_test.cc +++ b/icing/scoring/scoring-processor_test.cc @@ -69,11 +69,24 @@ class ScoringProcessorTest : public testing::Test { // Creates a simple email schema SchemaProto test_email_schema = SchemaBuilder() - .AddType(SchemaTypeConfigBuilder().SetType("email").AddProperty( - PropertyConfigBuilder() - .SetName("subject") - .SetDataType(TYPE_STRING) - .SetCardinality(CARDINALITY_OPTIONAL))) + .AddType(SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(TYPE_STRING) + .SetCardinality(CARDINALITY_OPTIONAL))) .Build(); ICING_ASSERT_OK(schema_store_->SetSchema(test_email_schema)); } @@ -86,6 +99,8 @@ class 
ScoringProcessorTest : public testing::Test { DocumentStore* document_store() { return document_store_.get(); } + SchemaStore* schema_store() { return schema_store_.get(); } + private: const std::string test_dir_; const std::string doc_store_dir_; @@ -139,16 +154,46 @@ UsageReport CreateUsageReport(std::string name_space, std::string uri, return usage_report; } -TEST_F(ScoringProcessorTest, CreationWithNullPointerShouldFail) { +TypePropertyWeights CreateTypePropertyWeights( + std::string schema_type, std::vector<PropertyWeight> property_weights) { + TypePropertyWeights type_property_weights; + type_property_weights.set_schema_type(std::move(schema_type)); + type_property_weights.mutable_property_weights()->Reserve( + property_weights.size()); + + for (PropertyWeight& property_weight : property_weights) { + *type_property_weights.add_property_weights() = std::move(property_weight); + } + + return type_property_weights; +} + +PropertyWeight CreatePropertyWeight(std::string path, double weight) { + PropertyWeight property_weight; + property_weight.set_path(std::move(path)); + property_weight.set_weight(weight); + return property_weight; +} + +TEST_F(ScoringProcessorTest, CreationWithNullDocumentStoreShouldFail) { + ScoringSpecProto spec_proto; + EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr, + schema_store()), + StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); +} + +TEST_F(ScoringProcessorTest, CreationWithNullSchemaStoreShouldFail) { ScoringSpecProto spec_proto; - EXPECT_THAT(ScoringProcessor::Create(spec_proto, /*document_store=*/nullptr), + EXPECT_THAT(ScoringProcessor::Create(spec_proto, document_store(), + /*schema_store=*/nullptr), StatusIs(libtextclassifier3::StatusCode::FAILED_PRECONDITION)); } TEST_F(ScoringProcessorTest, ShouldCreateInstance) { ScoringSpecProto spec_proto; spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::DOCUMENT_SCORE); - ICING_EXPECT_OK(ScoringProcessor::Create(spec_proto, 
document_store())); + ICING_EXPECT_OK( + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); } TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { @@ -163,7 +208,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleEmptyDocHitIterator) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/5), @@ -189,7 +234,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNonPositiveNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/-1), @@ -219,7 +264,7 @@ TEST_F(ScoringProcessorTest, ShouldRespectNumToScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/2), @@ -251,7 +296,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByDocumentScore) { // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -306,7 +351,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> 
scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -316,11 +361,11 @@ TEST_F(ScoringProcessorTest, // the document's length determines the final score. Document shorter than the // average corpus length are slightly boosted. ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.255482); + /*score=*/0.187114); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.115927); + /*score=*/0.084904); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.166435); + /*score=*/0.121896); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -375,7 +420,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -384,11 +429,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all contain the query term "foo" exactly once // and they have the same length, they will have the same BM25F scoret. 
ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask, - /*score=*/0.16173716); + /*score=*/0.118455); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -448,7 +493,7 @@ TEST_F(ScoringProcessorTest, // Creates a ScoringProcessor ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> query_term_iterators; @@ -457,11 +502,11 @@ TEST_F(ScoringProcessorTest, // Since the three documents all have the same length, the score is decided by // the frequency of the query term "foo". 
ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, - /*score=*/0.309497); + /*score=*/0.226674); ScoredDocumentHit expected_scored_doc_hit2(document_id2, section_id_mask2, - /*score=*/0.16173716); + /*score=*/0.118455); ScoredDocumentHit expected_scored_doc_hit3(document_id3, section_id_mask3, - /*score=*/0.268599); + /*score=*/0.196720); EXPECT_THAT( scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3, &query_term_iterators), @@ -470,6 +515,280 @@ TEST_F(ScoringProcessorTest, EqualsScoredDocumentHit(expected_scored_doc_hit3))); } +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_HitTermWithZeroFrequency) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/10)); + + // Document 1 contains the term "foo" 0 times in the "subject" property + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(/*section_id*/ 0, /*hit_term_frequency=*/0); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask section_id_mask1 = 
0b00000001; + + // Since the document hit has zero frequency, expect a score of zero. + ScoredDocumentHit expected_scored_doc_hit1(document_id1, section_id_mask1, + /*score=*/0.000000); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_SameHitFrequencyDifferentPropertyWeights) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = + CreatePropertyWeight(/*path=*/"body", 
/*weight=*/0.5); + PropertyWeight subject_property_weight = + CreatePropertyWeight(/*path=*/"subject", /*weight=*/2.0); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight, subject_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. Final scores are computed with smoothing applied. 
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.053624); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithImplicitPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id2, + document_store()->Put(document2, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Document 2 contains the term "foo" 1 time in the "subject" property + SectionId subject_section_id = 1; + DocHitInfo doc_hit_info2(document_id2); + doc_hit_info2.UpdateSection(subject_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1, doc_hit_info2}; + + // Creates a dummy DocHitInfoIterator with 2 results for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight 
body_property_weight = + CreatePropertyWeight(/*path=*/"body", /*weight=*/0.5); + *spec_proto.add_type_property_weights() = CreateTypePropertyWeights( + /*schema_type=*/"email", {body_property_weight}); + + // Creates a ScoringProcessor + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + SectionIdMask subject_section_id_mask = 1U << subject_section_id; + + // We expect document 2 to have a higher score than document 1 as it matches + // "foo" in the "subject" property, which is weighed higher than the "body" + // property. This is because the "subject" property is implictly given a + // a weight of 1.0, the default weight value. Final scores are computed with + // smoothing applied. 
+ ScoredDocumentHit expected_scored_doc_hit1(document_id1, body_section_id_mask, + /*score=*/0.094601); + ScoredDocumentHit expected_scored_doc_hit2(document_id2, + subject_section_id_mask, + /*score=*/0.153094); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/2, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit1), + EqualsScoredDocumentHit(expected_scored_doc_hit2))); +} + +TEST_F(ScoringProcessorTest, + ShouldScoreByRelevanceScore_WithDefaultPropertyWeight) { + DocumentProto document1 = + CreateDocument("icing", "email/1", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + DocumentProto document2 = + CreateDocument("icing", "email/2", kDefaultScore, + /*creation_timestamp_ms=*/kDefaultCreationTimestampMs); + + ICING_ASSERT_OK_AND_ASSIGN( + DocumentId document_id1, + document_store()->Put(document1, /*num_tokens=*/1)); + + // Document 1 contains the term "foo" 1 time in the "body" property + SectionId body_section_id = 0; + DocHitInfo doc_hit_info1(document_id1); + doc_hit_info1.UpdateSection(body_section_id, /*hit_term_frequency=*/1); + + // Creates input doc_hit_infos and expected output scored_document_hits + std::vector<DocHitInfo> doc_hit_infos = {doc_hit_info1}; + + // Creates a dummy DocHitInfoIterator with 1 result for the query "foo" + std::unique_ptr<DocHitInfoIterator> doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + ScoringSpecProto spec_proto; + spec_proto.set_rank_by(ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + *spec_proto.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", {}); + + // Creates a ScoringProcessor with no explicit weights set. 
+ ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor, + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); + + ScoringSpecProto spec_proto_with_weights; + spec_proto_with_weights.set_rank_by( + ScoringSpecProto::RankingStrategy::RELEVANCE_SCORE); + + PropertyWeight body_property_weight = CreatePropertyWeight(/*path=*/"body", + /*weight=*/1.0); + *spec_proto_with_weights.add_type_property_weights() = + CreateTypePropertyWeights(/*schema_type=*/"email", + {body_property_weight}); + + // Creates a ScoringProcessor with default weight set for "body" property. + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<ScoringProcessor> scoring_processor_with_weights, + ScoringProcessor::Create(spec_proto_with_weights, document_store(), + schema_store())); + + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators; + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + // Create a doc hit iterator + std::unordered_map<std::string, std::unique_ptr<DocHitInfoIterator>> + query_term_iterators_scoring_with_weights; + query_term_iterators_scoring_with_weights["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + SectionIdMask body_section_id_mask = 1U << body_section_id; + + // We expect document 1 to have the same score whether a weight is explicitly + // set to 1.0 or implictly scored with the default weight. Final scores are + // computed with smoothing applied. + ScoredDocumentHit expected_scored_doc_hit(document_id1, body_section_id_mask, + /*score=*/0.208191); + EXPECT_THAT( + scoring_processor->Score(std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); + + // Restore ownership of doc hit iterator and query term iterator to test. 
+ doc_hit_info_iterator = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + query_term_iterators["foo"] = + std::make_unique<DocHitInfoIteratorDummy>(doc_hit_infos, "foo"); + + EXPECT_THAT(scoring_processor_with_weights->Score( + std::move(doc_hit_info_iterator), + /*num_to_score=*/1, &query_term_iterators), + ElementsAre(EqualsScoredDocumentHit(expected_scored_doc_hit))); +} + TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { DocumentProto document1 = CreateDocument("icing", "email/1", kDefaultScore, @@ -509,7 +828,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByCreationTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -569,7 +888,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageCount) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -629,7 +948,7 @@ TEST_F(ScoringProcessorTest, ShouldScoreByUsageTimestamp) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), @@ -665,7 +984,7 @@ TEST_F(ScoringProcessorTest, ShouldHandleNoScores) { // Creates a ScoringProcessor which ranks 
in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/4), ElementsAre(EqualsScoredDocumentHit(scored_document_hit_default), @@ -714,7 +1033,7 @@ TEST_F(ScoringProcessorTest, ShouldWrapResultsWhenNoScoring) { // Creates a ScoringProcessor which ranks in descending order ICING_ASSERT_OK_AND_ASSIGN( std::unique_ptr<ScoringProcessor> scoring_processor, - ScoringProcessor::Create(spec_proto, document_store())); + ScoringProcessor::Create(spec_proto, document_store(), schema_store())); EXPECT_THAT(scoring_processor->Score(std::move(doc_hit_info_iterator), /*num_to_score=*/3), diff --git a/icing/scoring/section-weights.cc b/icing/scoring/section-weights.cc new file mode 100644 index 0000000..c4afe7f --- /dev/null +++ b/icing/scoring/section-weights.cc @@ -0,0 +1,146 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "icing/scoring/section-weights.h" + +#include <cfloat> +#include <unordered_map> +#include <utility> + +#include "icing/proto/scoring.pb.h" +#include "icing/schema/section.h" +#include "icing/util/logging.h" + +namespace icing { +namespace lib { + +namespace { + +// Normalizes all weights in the map to be in range (0.0, 1.0], where the max +// weight is normalized to 1.0. +inline void NormalizeSectionWeights( + double max_weight, std::unordered_map<SectionId, double>& section_weights) { + for (auto& raw_weight : section_weights) { + raw_weight.second = raw_weight.second / max_weight; + } +} +} // namespace + +libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> +SectionWeights::Create(const SchemaStore* schema_store, + const ScoringSpecProto& scoring_spec) { + ICING_RETURN_ERROR_IF_NULL(schema_store); + + std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_property_weight_map; + for (const TypePropertyWeights& type_property_weights : + scoring_spec.type_property_weights()) { + std::string_view schema_type = type_property_weights.schema_type(); + auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type); + if (!schema_type_id_or.ok()) { + ICING_LOG(WARNING) << "No schema type id found for schema type: " + << schema_type; + continue; + } + SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie(); + auto section_metadata_list_or = + schema_store->GetSectionMetadata(schema_type.data()); + if (!section_metadata_list_or.ok()) { + ICING_LOG(WARNING) << "No metadata found for schema type: " + << schema_type; + continue; + } + + const std::vector<SectionMetadata>* metadata_list = + section_metadata_list_or.ValueOrDie(); + + std::unordered_map<std::string, double> property_paths_weights; + for (const PropertyWeight& property_weight : + type_property_weights.property_weights()) { + double property_path_weight = property_weight.weight(); + + // Return error on negative and zero weights. 
+ if (property_path_weight <= 0.0) { + return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf( + "Property weight for property path \"%s\" is negative or zero. " + "Negative and zero weights are invalid.", + property_weight.path().c_str())); + } + property_paths_weights.insert( + {property_weight.path(), property_path_weight}); + } + NormalizedSectionWeights normalized_section_weights = + ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list); + + schema_property_weight_map.insert( + {schema_type_id, + {/*section_weights*/ std::move( + normalized_section_weights.section_weights), + /*default_weight*/ normalized_section_weights.default_weight}}); + } + // Using `new` to access a non-public constructor. + return std::unique_ptr<SectionWeights>( + new SectionWeights(std::move(schema_property_weight_map))); +} + +double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const { + auto schema_type_map = schema_section_weight_map_.find(schema_type_id); + if (schema_type_map == schema_section_weight_map_.end()) { + // Return default weight if the schema type has no weights specified. + return kDefaultSectionWeight; + } + + auto section_weight = + schema_type_map->second.section_weights.find(section_id); + if (section_weight == schema_type_map->second.section_weights.end()) { + // If there is no entry for SectionId, the weight is implicitly the + // normalized default weight. 
+ return schema_type_map->second.default_weight; + } + return section_weight->second; +} + +inline SectionWeights::NormalizedSectionWeights +SectionWeights::ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list) { + double max_weight = 0.0; + std::unordered_map<SectionId, double> section_weights; + for (const SectionMetadata& section_metadata : metadata_list) { + std::string_view metadata_path = section_metadata.path; + double section_weight = kDefaultSectionWeight; + auto iter = raw_weights.find(metadata_path.data()); + if (iter != raw_weights.end()) { + section_weight = iter->second; + section_weights.insert({section_metadata.id, section_weight}); + } + // Replace max if we see new max weight. + max_weight = std::max(max_weight, section_weight); + } + + NormalizeSectionWeights(max_weight, section_weights); + // Set normalized default weight to 1.0 in case there is no section + // metadata and max_weight is 0.0 (we should not see this case). + double normalized_default_weight = max_weight == 0.0 + ? kDefaultSectionWeight + : kDefaultSectionWeight / max_weight; + SectionWeights::NormalizedSectionWeights normalized_section_weights = + SectionWeights::NormalizedSectionWeights(); + normalized_section_weights.section_weights = std::move(section_weights); + normalized_section_weights.default_weight = normalized_default_weight; + return normalized_section_weights; +} +} // namespace lib +} // namespace icing diff --git a/icing/scoring/section-weights.h b/icing/scoring/section-weights.h new file mode 100644 index 0000000..23a9188 --- /dev/null +++ b/icing/scoring/section-weights.h @@ -0,0 +1,95 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef ICING_SCORING_SECTION_WEIGHTS_H_ +#define ICING_SCORING_SECTION_WEIGHTS_H_ + +#include <unordered_map> + +#include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/schema/schema-store.h" +#include "icing/store/document-store.h" + +namespace icing { +namespace lib { + +inline constexpr double kDefaultSectionWeight = 1.0; + +// Provides functions for setting and retrieving section weights for schema +// type properties. Section weights are used to promote and demote term matches +// in sections when scoring results. Section weights are provided by property +// path, and can range from (0, DBL_MAX]. The SectionId is matched to the +// property path by going over the schema type's section metadata. Weights that +// correspond to a valid property path are then normalized against the maxmium +// section weight, and put into map for quick access for scorers. By default, +// a section is given a raw, pre-normalized weight of 1.0. +class SectionWeights { + public: + // SectionWeights instances should not be copied. + SectionWeights(const SectionWeights&) = delete; + SectionWeights& operator=(const SectionWeights&) = delete; + + // Factory function to create a SectionWeights instance. Raw weights are + // provided through the ScoringSpecProto. Provided property paths for weights + // are validated against the schema type's section metadata. If the property + // path doesn't exist, the property weight is ignored. If a weight is 0 or + // negative, an invalid argument error is returned. 
Raw weights are then + // normalized against the maximum weight for that schema type. + // + // Returns: + // A SectionWeights instance on success + // FAILED_PRECONDITION on any null pointer input + // INVALID_ARGUMENT if a provided weight for a property path is less than or + // equal to 0. + static libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>> Create( + const SchemaStore* schema_store, const ScoringSpecProto& scoring_spec); + + // Returns the normalized section weight by SchemaTypeId and SectionId. If + // the SchemaTypeId, or the SectionId for a SchemaTypeId, is not found in the + // normalized weights map, the default weight is returned instead. + double GetNormalizedSectionWeight(SchemaTypeId schema_type_id, + SectionId section_id) const; + + private: + // Holds the normalized section weights for a schema type, as well as the + // normalized default weight for sections that have no weight set. + struct NormalizedSectionWeights { + std::unordered_map<SectionId, double> section_weights; + double default_weight; + }; + + explicit SectionWeights( + const std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map) + : schema_section_weight_map_(std::move(schema_section_weight_map)) {} + + // Creates a map of section ids to normalized weights from the raw property + // path weight map and section metadata and calculates the normalized default + // section weight. + static inline SectionWeights::NormalizedSectionWeights + ExtractNormalizedSectionWeights( + const std::unordered_map<std::string, double>& raw_weights, + const std::vector<SectionMetadata>& metadata_list); + + // A map of (SchemaTypeId -> SectionId -> Normalized Weight), allows for fast + // look up of normalized weights. This is precomputed when creating a + // SectionWeights instance. 
+ std::unordered_map<SchemaTypeId, NormalizedSectionWeights> + schema_section_weight_map_; +}; + +} // namespace lib +} // namespace icing + +#endif // ICING_SCORING_SECTION_WEIGHTS_H_ diff --git a/icing/scoring/section-weights_test.cc b/icing/scoring/section-weights_test.cc new file mode 100644 index 0000000..b90c3d5 --- /dev/null +++ b/icing/scoring/section-weights_test.cc @@ -0,0 +1,386 @@ +// Copyright (C) 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/scoring/section-weights.h" + +#include <cfloat> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/proto/scoring.pb.h" +#include "icing/schema-builder.h" +#include "icing/testing/common-matchers.h" +#include "icing/testing/fake-clock.h" +#include "icing/testing/tmp-directory.h" + +namespace icing { +namespace lib { + +namespace { +using ::testing::Eq; + +class SectionWeightsTest : public testing::Test { + protected: + SectionWeightsTest() + : test_dir_(GetTestTempDir() + "/icing"), + schema_store_dir_(test_dir_ + "/schema_store") {} + + void SetUp() override { + // Creates file directories + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + filesystem_.CreateDirectoryRecursively(schema_store_dir_.c_str()); + + ICING_ASSERT_OK_AND_ASSIGN( + schema_store_, + SchemaStore::Create(&filesystem_, test_dir_, &fake_clock_)); + + SchemaTypeConfigProto sender_schema = + SchemaTypeConfigBuilder() + .SetType("sender") + 
.AddProperty(PropertyConfigBuilder() + .SetName("name") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaTypeConfigProto email_schema = + SchemaTypeConfigBuilder() + .SetType("email") + .AddProperty( + PropertyConfigBuilder() + .SetName("subject") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty( + PropertyConfigBuilder() + .SetName("body") + .SetDataTypeString( + TermMatchType::PREFIX, + StringIndexingConfig::TokenizerType::PLAIN) + .SetDataType(PropertyConfigProto_DataType_Code_STRING) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .AddProperty(PropertyConfigBuilder() + .SetName("sender") + .SetDataTypeDocument( + "sender", /*index_nested_properties=*/true) + .SetCardinality( + PropertyConfigProto_Cardinality_Code_OPTIONAL)) + .Build(); + SchemaProto schema = + SchemaBuilder().AddType(sender_schema).AddType(email_schema).Build(); + + ICING_ASSERT_OK(schema_store_->SetSchema(schema)); + } + + void TearDown() override { + schema_store_.reset(); + filesystem_.DeleteDirectoryRecursively(test_dir_.c_str()); + } + + SchemaStore *schema_store() { return schema_store_.get(); } + + private: + const std::string test_dir_; + const std::string schema_store_dir_; + Filesystem filesystem_; + FakeClock fake_clock_; + std::unique_ptr<SchemaStore> schema_store_; +}; + +TEST_F(SectionWeightsTest, ShouldNormalizeSinglePropertyWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(5.0); + 
property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". + // We expect 1.0 as there is only one property in the "sender" schema type + // so it should take the max normalized weight of 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldAcceptMaxWeightValue) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(DBL_MAX); + property_weight->set_path("name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // section_id 0 corresponds to property "name". 
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/0), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithNegativeWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_propery_weight = + type_property_weights->add_property_weights(); + body_propery_weight->set_weight(-100.0); + body_propery_weight->set_path("body"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldFailWithZeroWeight) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("sender"); + + PropertyWeight *property_weight = + type_property_weights->add_property_weights(); + property_weight->set_weight(0.0); + property_weight->set_path("name"); + + EXPECT_THAT(SectionWeights::Create(schema_store(), spec_proto).status(), + StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); +} + +TEST_F(SectionWeightsTest, ShouldReturnDefaultIfTypePropertyWeightsNotSet) { + ScoringSpecProto spec_proto; + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(kDefaultSectionWeight)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + 
type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *nested_property_weight = + type_property_weights->add_property_weights(); + nested_property_weight->set_weight(50.0); + nested_property_weight->set_path("sender.name"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.01)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "subject" property. 
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldNormalizeIfAllWeightsBelowOne) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + type_property_weights->add_property_weights(); + body_property_weight->set_weight(0.1); + body_property_weight->set_path("body"); + + PropertyWeight *sender_name_weight = + type_property_weights->add_property_weights(); + sender_name_weight->set_weight(0.2); + sender_name_weight->set_path("sender.name"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(0.4); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(1.0 / 4.0)); + // Normalized weight for "sender.name" property (the nested property). + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(2.0 / 4.0)); + // Normalized weight for "subject" property. 
+ EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSetNestedPropertyWeightSeparatelyForTypes) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *email_type_property_weights = + spec_proto.add_type_property_weights(); + email_type_property_weights->set_schema_type("email"); + + PropertyWeight *body_property_weight = + email_type_property_weights->add_property_weights(); + body_property_weight->set_weight(1.0); + body_property_weight->set_path("body"); + + PropertyWeight *subject_property_weight = + email_type_property_weights->add_property_weights(); + subject_property_weight->set_weight(100.0); + subject_property_weight->set_path("subject"); + + PropertyWeight *sender_name_property_weight = + email_type_property_weights->add_property_weights(); + sender_name_property_weight->set_weight(50.0); + sender_name_property_weight->set_path("sender.name"); + + TypePropertyWeights *sender_type_property_weights = + spec_proto.add_type_property_weights(); + sender_type_property_weights->set_schema_type("sender"); + + PropertyWeight *sender_property_weight = + sender_type_property_weights->add_property_weights(); + sender_property_weight->set_weight(25.0); + sender_property_weight->set_path("sender"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId sender_schema_type_id, + schema_store()->GetSchemaTypeId("sender")); + + // Normalized weight for "sender.name" property (the nested property) + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.5)); + // Normalized weight for "name" property for "sender" schema type. 
As it is + // the only property of the type, it should take the max normalized weight of + // 1.0. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(sender_schema_type_id, + /*section_id=*/2), + Eq(1.0)); +} + +TEST_F(SectionWeightsTest, ShouldSkipNonExistentPathWhenSettingWeights) { + ScoringSpecProto spec_proto; + + TypePropertyWeights *type_property_weights = + spec_proto.add_type_property_weights(); + type_property_weights->set_schema_type("email"); + + // If this property weight isn't skipped, then the max property weight would + // be set to 100.0 and all weights would be normalized against the max. + PropertyWeight *non_valid_property_weight = + type_property_weights->add_property_weights(); + non_valid_property_weight->set_weight(100.0); + non_valid_property_weight->set_path("sender.organization"); + + PropertyWeight *subject_property_weight = + type_property_weights->add_property_weights(); + subject_property_weight->set_weight(10.0); + subject_property_weight->set_path("subject"); + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<SectionWeights> section_weights, + SectionWeights::Create(schema_store(), spec_proto)); + ICING_ASSERT_OK_AND_ASSIGN(SchemaTypeId email_schema_type_id, + schema_store()->GetSchemaTypeId("email")); + + // Normalized weight for "body" property. Because the weight is not explicitly + // set, it is set to the default of 1.0 before being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/0), + Eq(0.1)); + // Normalized weight for "sender.name" property (the nested property). Because + // the weight is not explicitly set, it is set to the default of 1.0 before + // being normalized. + EXPECT_THAT(section_weights->GetNormalizedSectionWeight(email_schema_type_id, + /*section_id=*/1), + Eq(0.1)); + // Normalized weight for "subject" property. 
// Returns a vector containing num_terms unique terms. Terms are created in
// non-random order by incrementing the least-significant (leftmost) character
// first: "a".."z", then "aa", "ba", .., "za", "ab", "bb", etc. A non-positive
// num_terms yields an empty vector.
std::vector<std::string> GenerateUniqueTerms(int num_terms) {
  std::vector<std::string> terms;
  if (num_terms <= 0) {
    return terms;
  }
  terms.reserve(num_terms);
  // Start one character before 'a' so the first increment produces "a".
  std::string term(1, 'a' - 1);
  for (int i = 0; i < num_terms; ++i) {
    if (term[0] != 'z') {
      // Common case: just bump the leftmost character.
      ++term[0];
    } else {
      // The term looks like "zz...z??...": find the first non-'z' character.
      // Use size_t throughout — find_first_not_of returns
      // std::string::size_type, and stuffing npos into an int (as the old
      // code did) only worked by accident of two's-complement conversion.
      size_t first_non_z = term.find_first_not_of('z');
      if (first_non_z != std::string::npos) {
        // Increment that character and reset every 'z' before it to 'a'
        // (a little-endian "carry").
        ++term[first_non_z];
        term.replace(0, first_non_z, first_non_z, 'a');
      } else {
        // Every character is 'z'; grow the term: "zz" -> "aaa".
        term.assign(term.length() + 1, 'a');
      }
    }
    terms.push_back(term);
  }
  return terms;
}
+ +#include "icing/testing/random-string.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using ::testing::ElementsAre; +using ::testing::Eq; +using ::testing::IsEmpty; + +namespace icing { +namespace lib { + +namespace { + +TEST(RandomStringTest, GenerateUniqueTerms) { + EXPECT_THAT(GenerateUniqueTerms(0), IsEmpty()); + EXPECT_THAT(GenerateUniqueTerms(1), ElementsAre("a")); + EXPECT_THAT(GenerateUniqueTerms(4), ElementsAre("a", "b", "c", "d")); + EXPECT_THAT(GenerateUniqueTerms(29), + ElementsAre("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", + "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", + "w", "x", "y", "z", "aa", "ba", "ca")); + EXPECT_THAT(GenerateUniqueTerms(56), + ElementsAre("a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", + "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", + "w", "x", "y", "z", "aa", "ba", "ca", "da", "ea", + "fa", "ga", "ha", "ia", "ja", "ka", "la", "ma", "na", + "oa", "pa", "qa", "ra", "sa", "ta", "ua", "va", "wa", + "xa", "ya", "za", "ab", "bb", "cb", "db")); + EXPECT_THAT(GenerateUniqueTerms(56).at(54), Eq("cb")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26), Eq("aa")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27), Eq("aaa")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27 - 6), Eq("uz")); + EXPECT_THAT(GenerateUniqueTerms(26 * 26 * 26).at(26 * 27 + 5), Eq("faa")); +} + +} // namespace + +} // namespace lib +} // namespace icing diff --git a/icing/testing/snippet-helpers.cc b/icing/testing/snippet-helpers.cc index cfd20c2..7a71987 100644 --- a/icing/testing/snippet-helpers.cc +++ b/icing/testing/snippet-helpers.cc @@ -77,6 +77,16 @@ std::vector<std::string_view> GetMatches( return matches; } +std::vector<std::string_view> GetSubMatches( + std::string_view content, const SnippetProto::EntryProto& snippet_proto) { + std::vector<std::string_view> matches; + for (const SnippetMatchProto& match : snippet_proto.snippet_matches()) { + 
matches.push_back(content.substr(match.exact_match_byte_position(), + match.submatch_byte_length())); + } + return matches; +} + std::string_view GetString(const DocumentProto* document, std::string_view property_path) { std::vector<std::string_view> properties = diff --git a/icing/testing/snippet-helpers.h b/icing/testing/snippet-helpers.h index defadeb..73b2ce2 100644 --- a/icing/testing/snippet-helpers.h +++ b/icing/testing/snippet-helpers.h @@ -40,6 +40,10 @@ std::vector<std::string_view> GetWindows( std::vector<std::string_view> GetMatches( std::string_view content, const SnippetProto::EntryProto& snippet_proto); +// Retrieves all submatches defined by the snippet_proto for the content. +std::vector<std::string_view> GetSubMatches( + std::string_view content, const SnippetProto::EntryProto& snippet_proto); + // Retrieves the string value held in the document corresponding to the // property_path. // Example: diff --git a/icing/tokenization/icu/icu-language-segmenter.cc b/icing/tokenization/icu/icu-language-segmenter.cc index cb31441..598ede7 100644 --- a/icing/tokenization/icu/icu-language-segmenter.cc +++ b/icing/tokenization/icu/icu-language-segmenter.cc @@ -300,9 +300,10 @@ class IcuLanguageSegmenterIterator : public LanguageSegmenter::Iterator { UChar32 uchar32 = i18n_utils::GetUChar32At(text_.data(), text_.length(), term_start_index_); - // Rule 2: for non-ASCII terms, only the alphabetic terms are returned. - // We know it's an alphabetic term by checking the first unicode character. - if (u_isUAlphabetic(uchar32)) { + // Rule 2: for non-ASCII terms, only the alphanumeric terms are returned. + // We know it's an alphanumeric term by checking the first unicode + // character. 
+ if (i18n_utils::IsAlphaNumeric(uchar32)) { return true; } return false; diff --git a/icing/tokenization/icu/icu-language-segmenter_test.cc b/icing/tokenization/icu/icu-language-segmenter_test.cc index 01eb7d8..3090087 100644 --- a/icing/tokenization/icu/icu-language-segmenter_test.cc +++ b/icing/tokenization/icu/icu-language-segmenter_test.cc @@ -372,6 +372,15 @@ TEST_P(IcuLanguageSegmenterAllLocalesTest, Number) { IsOkAndHolds(ElementsAre("-", "123"))); } +TEST_P(IcuLanguageSegmenterAllLocalesTest, FullWidthNumbers) { + ICING_ASSERT_OK_AND_ASSIGN( + auto language_segmenter, + language_segmenter_factory::Create( + GetSegmenterOptions(GetLocale(), jni_cache_.get()))); + EXPECT_THAT(language_segmenter->GetAllTerms("0123456789"), + IsOkAndHolds(ElementsAre("0123456789"))); +} + TEST_P(IcuLanguageSegmenterAllLocalesTest, ContinuousWhitespaces) { ICING_ASSERT_OK_AND_ASSIGN( auto language_segmenter, diff --git a/icing/tokenization/raw-query-tokenizer.cc b/icing/tokenization/raw-query-tokenizer.cc index 205d3a2..2d461ee 100644 --- a/icing/tokenization/raw-query-tokenizer.cc +++ b/icing/tokenization/raw-query-tokenizer.cc @@ -14,9 +14,8 @@ #include "icing/tokenization/raw-query-tokenizer.h" -#include <stddef.h> - #include <cctype> +#include <cstddef> #include <memory> #include <string> #include <string_view> diff --git a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc index 6b1cb3a..8e1e563 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-break-iterator.cc @@ -15,10 +15,10 @@ #include "icing/tokenization/reverse_jni/reverse-jni-break-iterator.h" #include <jni.h> -#include <math.h> #include <cassert> #include <cctype> +#include <cmath> #include <map> #include "icing/jni/jni-cache.h" diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc 
index 76219b5..b936f2b 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter.cc @@ -291,9 +291,12 @@ class ReverseJniLanguageSegmenterIterator : public LanguageSegmenter::Iterator { return true; } - // Rule 2: for non-ASCII terms, only the alphabetic terms are returned. - // We know it's an alphabetic term by checking the first unicode character. - if (i18n_utils::IsAlphabeticAt(text_, term_start_.utf8_index())) { + UChar32 uchar32 = i18n_utils::GetUChar32At(text_.data(), text_.length(), + term_start_.utf8_index()); + // Rule 2: for non-ASCII terms, only the alphanumeric terms are returned. + // We know it's an alphanumeric term by checking the first unicode + // character. + if (i18n_utils::IsAlphaNumeric(uchar32)) { return true; } return false; diff --git a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc index b1a8f72..45d6475 100644 --- a/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc +++ b/icing/tokenization/reverse_jni/reverse-jni-language-segmenter_test.cc @@ -366,6 +366,17 @@ TEST_P(ReverseJniLanguageSegmenterTest, Number) { IsOkAndHolds(ElementsAre("-", "123"))); } +TEST_P(ReverseJniLanguageSegmenterTest, FullWidthNumbers) { + ICING_ASSERT_OK_AND_ASSIGN( + auto language_segmenter, + language_segmenter_factory::Create( + GetSegmenterOptions(GetLocale(), jni_cache_.get()))); + + EXPECT_THAT(language_segmenter->GetAllTerms("0123456789"), + IsOkAndHolds(ElementsAre("0", "1", "2", "3", "4", "5", "6", + "7", "8", "9"))); +} + TEST_P(ReverseJniLanguageSegmenterTest, ContinuousWhitespaces) { ICING_ASSERT_OK_AND_ASSIGN( auto language_segmenter, diff --git a/icing/transform/icu/icu-normalizer.cc b/icing/transform/icu/icu-normalizer.cc index eb0eead..aceb11d 100644 --- a/icing/transform/icu/icu-normalizer.cc +++ b/icing/transform/icu/icu-normalizer.cc @@ -29,6 
+29,7 @@ #include "icing/util/status-macros.h" #include "unicode/umachine.h" #include "unicode/unorm2.h" +#include "unicode/ustring.h" #include "unicode/utrans.h" namespace icing { @@ -157,14 +158,18 @@ std::string IcuNormalizer::NormalizeLatin(const UNormalizer2* normalizer2, const std::string_view term) const { std::string result; result.reserve(term.length()); - for (int i = 0; i < term.length(); i++) { - if (i18n_utils::IsAscii(term[i])) { - result.push_back(std::tolower(term[i])); - } else if (i18n_utils::IsLeadUtf8Byte(term[i])) { - UChar32 uchar32 = i18n_utils::GetUChar32At(term.data(), term.length(), i); + int current_pos = 0; + while (current_pos < term.length()) { + if (i18n_utils::IsAscii(term[current_pos])) { + result.push_back(std::tolower(term[current_pos])); + ++current_pos; + } else { + UChar32 uchar32 = + i18n_utils::GetUChar32At(term.data(), term.length(), current_pos); if (uchar32 == i18n_utils::kInvalidUChar32) { ICING_LOG(WARNING) << "Unable to get uchar32 from " << term - << " at position" << i; + << " at position" << current_pos; + current_pos += i18n_utils::GetUtf8Length(uchar32); continue; } char ascii_char; @@ -177,8 +182,9 @@ std::string IcuNormalizer::NormalizeLatin(const UNormalizer2* normalizer2, // tokenized. We handle it here in case there're something wrong with // the tokenizers. 
int utf8_length = i18n_utils::GetUtf8Length(uchar32); - absl_ports::StrAppend(&result, term.substr(i, utf8_length)); + absl_ports::StrAppend(&result, term.substr(current_pos, utf8_length)); } + current_pos += i18n_utils::GetUtf8Length(uchar32); } } @@ -261,5 +267,106 @@ std::string IcuNormalizer::TermTransformer::Transform( return std::move(utf8_term_or).ValueOrDie(); } +CharacterIterator FindNormalizedLatinMatchEndPosition( + const UNormalizer2* normalizer2, std::string_view term, + CharacterIterator char_itr, std::string_view normalized_term) { + CharacterIterator normalized_char_itr(normalized_term); + char ascii_char; + while (char_itr.utf8_index() < term.length() && + normalized_char_itr.utf8_index() < normalized_term.length()) { + UChar32 c = char_itr.GetCurrentChar(); + if (i18n_utils::IsAscii(c)) { + c = std::tolower(c); + } else if (DiacriticCharToAscii(normalizer2, c, &ascii_char)) { + c = ascii_char; + } + UChar32 normalized_c = normalized_char_itr.GetCurrentChar(); + if (c != normalized_c) { + return char_itr; + } + char_itr.AdvanceToUtf32(char_itr.utf32_index() + 1); + normalized_char_itr.AdvanceToUtf32(normalized_char_itr.utf32_index() + 1); + } + return char_itr; +} + +CharacterIterator +IcuNormalizer::TermTransformer::FindNormalizedNonLatinMatchEndPosition( + std::string_view term, CharacterIterator char_itr, + std::string_view normalized_term) const { + CharacterIterator normalized_char_itr(normalized_term); + UErrorCode status = U_ZERO_ERROR; + + constexpr int kUtf16CharBufferLength = 6; + UChar c16[kUtf16CharBufferLength]; + int32_t c16_length; + int32_t limit; + + constexpr int kCharBufferLength = 3 * 4; + char normalized_buffer[kCharBufferLength]; + int32_t c8_length; + while (char_itr.utf8_index() < term.length() && + normalized_char_itr.utf8_index() < normalized_term.length()) { + UChar32 c = char_itr.GetCurrentChar(); + int c_lenth = i18n_utils::GetUtf8Length(c); + u_strFromUTF8(c16, kUtf16CharBufferLength, &c16_length, + term.data() + 
char_itr.utf8_index(), + /*srcLength=*/c_lenth, &status); + if (U_FAILURE(status)) { + break; + } + + limit = c16_length; + utrans_transUChars(u_transliterator_, c16, &c16_length, + kUtf16CharBufferLength, + /*start=*/0, &limit, &status); + if (U_FAILURE(status)) { + break; + } + + u_strToUTF8(normalized_buffer, kCharBufferLength, &c8_length, c16, + c16_length, &status); + if (U_FAILURE(status)) { + break; + } + + for (int i = 0; i < c8_length; ++i) { + if (normalized_buffer[i] != + normalized_term[normalized_char_itr.utf8_index() + i]) { + return char_itr; + } + } + normalized_char_itr.AdvanceToUtf8(normalized_char_itr.utf8_index() + + c8_length); + char_itr.AdvanceToUtf32(char_itr.utf32_index() + 1); + } + if (U_FAILURE(status)) { + // Failed to transform, return its original form. + ICING_LOG(WARNING) << "Failed to normalize UTF8 term: " << term; + } + return char_itr; +} + +CharacterIterator IcuNormalizer::FindNormalizedMatchEndPosition( + std::string_view term, std::string_view normalized_term) const { + UErrorCode status = U_ZERO_ERROR; + // ICU manages the singleton instance + const UNormalizer2* normalizer2 = unorm2_getNFCInstance(&status); + if (U_FAILURE(status)) { + ICING_LOG(WARNING) << "Failed to create a UNormalizer2 instance"; + } + + CharacterIterator char_itr(term); + UChar32 first_uchar32 = char_itr.GetCurrentChar(); + if (normalizer2 != nullptr && first_uchar32 != i18n_utils::kInvalidUChar32 && + DiacriticCharToAscii(normalizer2, first_uchar32, /*char_out=*/nullptr)) { + return FindNormalizedLatinMatchEndPosition(normalizer2, term, char_itr, + normalized_term); + } else { + return term_transformer_->FindNormalizedNonLatinMatchEndPosition( + term, char_itr, normalized_term); + } +} + } // namespace lib } // namespace icing diff --git a/icing/transform/icu/icu-normalizer.h b/icing/transform/icu/icu-normalizer.h index f20a9fb..d4f1ebd 100644 --- a/icing/transform/icu/icu-normalizer.h +++ b/icing/transform/icu/icu-normalizer.h @@ -21,6 +21,7 @@ 
#include "icing/text_classifier/lib3/utils/base/statusor.h" #include "icing/transform/normalizer.h" +#include "icing/util/character-iterator.h" #include "unicode/unorm2.h" #include "unicode/utrans.h" @@ -56,6 +57,17 @@ class IcuNormalizer : public Normalizer { // result in the non-Latin characters not properly being normalized std::string NormalizeTerm(std::string_view term) const override; + // Returns a CharacterIterator pointing to one past the end of the segment of + // term that (once normalized) matches with normalized_term. + // + // Ex. FindNormalizedMatchEndPosition("YELLOW", "yell") will return + // CharacterIterator(u8:4, u16:4, u32:4). + // + // Ex. FindNormalizedMatchEndPosition("YELLOW", "red") will return + // CharacterIterator(u8:0, u16:0, u32:0). + CharacterIterator FindNormalizedMatchEndPosition( + std::string_view term, std::string_view normalized_term) const override; + private: // A handler class that helps manage the lifecycle of UTransliterator. It's // used in IcuNormalizer to transform terms into the formats we need. @@ -75,6 +87,12 @@ class IcuNormalizer : public Normalizer { // Transforms the text based on our rules described at top of this file std::string Transform(std::string_view term) const; + // Returns a CharacterIterator pointing to one past the end of the segment + // of a non-latin term that (once normalized) matches with normalized_term. 
+ CharacterIterator FindNormalizedNonLatinMatchEndPosition( + std::string_view term, CharacterIterator char_itr, + std::string_view normalized_term) const; + private: explicit TermTransformer(UTransliterator* u_transliterator); diff --git a/icing/transform/icu/icu-normalizer_benchmark.cc b/icing/transform/icu/icu-normalizer_benchmark.cc index b037538..8d09be2 100644 --- a/icing/transform/icu/icu-normalizer_benchmark.cc +++ b/icing/transform/icu/icu-normalizer_benchmark.cc @@ -161,6 +161,124 @@ BENCHMARK(BM_NormalizeHiragana) ->Arg(2048000) ->Arg(4096000); +void BM_UppercaseSubTokenLength(benchmark::State& state) { + bool run_via_adb = absl::GetFlag(FLAGS_adb); + if (!run_via_adb) { + ICING_ASSERT_OK(icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string(state.range(0), 'A'); + std::string normalized_input_string(state.range(0), 'a'); + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_UppercaseSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) + ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + +void BM_AccentSubTokenLength(benchmark::State& state) { + bool run_via_adb = absl::GetFlag(FLAGS_adb); + if (!run_via_adb) { + ICING_ASSERT_OK(icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string; + std::string normalized_input_string; + while (input_string.length() < state.range(0)) { + input_string.append("àáâãā"); + 
normalized_input_string.append("aaaaa"); + } + + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_AccentSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) + ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + +void BM_HiraganaSubTokenLength(benchmark::State& state) { + bool run_via_adb = absl::GetFlag(FLAGS_adb); + if (!run_via_adb) { + ICING_ASSERT_OK(icu_data_file_helper::SetUpICUDataFile( + GetTestFilePath("icing/icu.dat"))); + } + + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string; + std::string normalized_input_string; + while (input_string.length() < state.range(0)) { + input_string.append("あいうえお"); + normalized_input_string.append("アイウエオ"); + } + + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_HiraganaSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) + ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + } // namespace } // namespace lib diff --git a/icing/transform/icu/icu-normalizer_test.cc b/icing/transform/icu/icu-normalizer_test.cc index f5d20ff..a46fcc7 100644 --- a/icing/transform/icu/icu-normalizer_test.cc +++ b/icing/transform/icu/icu-normalizer_test.cc @@ -231,6 +231,104 @@ TEST_F(IcuNormalizerTest, Truncate) { } } +TEST_F(IcuNormalizerTest, PrefixMatchLength) { + // Verify that FindNormalizedMatchEndPosition will properly find the length of + // the prefix match when given a non-normalized term and a normalized term + // is a prefix of the non-normalized one. 
+ ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + // Upper to lower + std::string term = "MDI"; + CharacterIterator match_end = + normalizer->FindNormalizedMatchEndPosition(term, "md"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("MD")); + + term = "Icing"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "icin"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Icin")); + + // Full-width + term = "525600"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "525"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("525")); + + term = "FULLWIDTH"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "full"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("FULL")); + + // Hiragana to Katakana + term = "あいうえお"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "アイ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("あい")); + + term = "かきくけこ"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "カ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("か")); + + // Latin accents + term = "Zürich"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "zur"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Zür")); + + term = "après-midi"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "apre"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("aprè")); + + term = "Buenos días"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "buenos di"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Buenos dí")); +} + +TEST_F(IcuNormalizerTest, SharedPrefixMatchLength) { + // Verify that FindNormalizedMatchEndPosition will properly find the length of + // the prefix match when given a non-normalized term and a normalized term + // that share a common prefix. 
+ ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + // Upper to lower + std::string term = "MDI"; + CharacterIterator match_end = + normalizer->FindNormalizedMatchEndPosition(term, "mgm"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("M")); + + term = "Icing"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "icky"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Ic")); + + // Full-width + term = "525600"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "525788"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("525")); + + term = "FULLWIDTH"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "fully"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("FULL")); + + // Hiragana to Katakana + term = "あいうえお"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "アイエオ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("あい")); + + term = "かきくけこ"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "カケコ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("か")); + + // Latin accents + term = "Zürich"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "zurg"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Zür")); + + term = "après-midi"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "apreciate"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("aprè")); + + term = "días"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "diamond"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("día")); +} + } // namespace } // namespace lib } // namespace icing diff --git a/icing/transform/map/map-normalizer.cc b/icing/transform/map/map-normalizer.cc index c888551..61fce65 100644 --- a/icing/transform/map/map-normalizer.cc +++ b/icing/transform/map/map-normalizer.cc @@ -14,8 +14,7 @@ #include "icing/transform/map/map-normalizer.h" -#include 
<ctype.h> - +#include <cctype> #include <string> #include <string_view> #include <unordered_map> @@ -23,6 +22,7 @@ #include "icing/absl_ports/str_cat.h" #include "icing/transform/map/normalization-map.h" +#include "icing/util/character-iterator.h" #include "icing/util/i18n-utils.h" #include "icing/util/logging.h" #include "unicode/utypes.h" @@ -30,48 +30,70 @@ namespace icing { namespace lib { +namespace { + +UChar32 NormalizeChar(UChar32 c) { + if (i18n_utils::GetUtf16Length(c) > 1) { + // All the characters we need to normalize can be encoded into a + // single char16_t. If this character needs more than 1 char16_t code + // unit, we can skip normalization and append it directly. + return c; + } + + // The original character can be encoded into a single char16_t. + const std::unordered_map<char16_t, char16_t>* normalization_map = + GetNormalizationMap(); + if (normalization_map == nullptr) { + // Normalization map couldn't be properly initialized, append the original + // character. + ICING_LOG(WARNING) << "Unable to get a valid pointer to normalization map!"; + return c; + } + auto iterator = normalization_map->find(static_cast<char16_t>(c)); + if (iterator == normalization_map->end()) { + // Normalization mapping not found, append the original character. + return c; + } + + // Found a normalization mapping. The normalized character (stored in a + // char16_t) can have 1 or 2 bytes. + if (i18n_utils::IsAscii(iterator->second)) { + // The normalized character has 1 byte. It may be an upper-case char. + // Lower-case it before returning it. + return std::tolower(static_cast<char>(iterator->second)); + } else { + return iterator->second; + } +} + +} // namespace + std::string MapNormalizer::NormalizeTerm(std::string_view term) const { std::string normalized_text; normalized_text.reserve(term.length()); - for (int i = 0; i < term.length(); ++i) { - if (i18n_utils::IsAscii(term[i])) { - // The original character has 1 byte. 
- normalized_text.push_back(std::tolower(term[i])); - } else if (i18n_utils::IsLeadUtf8Byte(term[i])) { - UChar32 uchar32 = i18n_utils::GetUChar32At(term.data(), term.length(), i); + int current_pos = 0; + while (current_pos < term.length()) { + if (i18n_utils::IsAscii(term[current_pos])) { + normalized_text.push_back(std::tolower(term[current_pos])); + ++current_pos; + } else { + UChar32 uchar32 = + i18n_utils::GetUChar32At(term.data(), term.length(), current_pos); if (uchar32 == i18n_utils::kInvalidUChar32) { ICING_LOG(WARNING) << "Unable to get uchar32 from " << term - << " at position" << i; + << " at position" << current_pos; + ++current_pos; continue; } - int utf8_length = i18n_utils::GetUtf8Length(uchar32); - if (i18n_utils::GetUtf16Length(uchar32) > 1) { - // All the characters we need to normalize can be encoded into a - // single char16_t. If this character needs more than 1 char16_t code - // unit, we can skip normalization and append it directly. - absl_ports::StrAppend(&normalized_text, term.substr(i, utf8_length)); - continue; - } - // The original character can be encoded into a single char16_t. - const std::unordered_map<char16_t, char16_t>& normalization_map = - GetNormalizationMap(); - auto iterator = normalization_map.find(static_cast<char16_t>(uchar32)); - if (iterator != normalization_map.end()) { - // Found a normalization mapping. The normalized character (stored in a - // char16_t) can have 1 or 2 bytes. - if (i18n_utils::IsAscii(iterator->second)) { - // The normalized character has 1 byte. - normalized_text.push_back( - std::tolower(static_cast<char>(iterator->second))); - } else { - // The normalized character has 2 bytes. - i18n_utils::AppendUchar32ToUtf8(&normalized_text, iterator->second); - } + UChar32 normalized_char32 = NormalizeChar(uchar32); + if (i18n_utils::IsAscii(normalized_char32)) { + normalized_text.push_back(normalized_char32); } else { - // Normalization mapping not found, append the original character. 
- absl_ports::StrAppend(&normalized_text, term.substr(i, utf8_length)); + // The normalized character has 2 bytes. + i18n_utils::AppendUchar32ToUtf8(&normalized_text, normalized_char32); } + current_pos += i18n_utils::GetUtf8Length(uchar32); } } @@ -82,5 +104,27 @@ std::string MapNormalizer::NormalizeTerm(std::string_view term) const { return normalized_text; } +CharacterIterator MapNormalizer::FindNormalizedMatchEndPosition( + std::string_view term, std::string_view normalized_term) const { + CharacterIterator char_itr(term); + CharacterIterator normalized_char_itr(normalized_term); + while (char_itr.utf8_index() < term.length() && + normalized_char_itr.utf8_index() < normalized_term.length()) { + UChar32 c = char_itr.GetCurrentChar(); + if (i18n_utils::IsAscii(c)) { + c = std::tolower(c); + } else { + c = NormalizeChar(c); + } + UChar32 normalized_c = normalized_char_itr.GetCurrentChar(); + if (c != normalized_c) { + return char_itr; + } + char_itr.AdvanceToUtf32(char_itr.utf32_index() + 1); + normalized_char_itr.AdvanceToUtf32(normalized_char_itr.utf32_index() + 1); + } + return char_itr; +} + } // namespace lib } // namespace icing diff --git a/icing/transform/map/map-normalizer.h b/icing/transform/map/map-normalizer.h index f9c0e42..ed996ae 100644 --- a/icing/transform/map/map-normalizer.h +++ b/icing/transform/map/map-normalizer.h @@ -19,6 +19,7 @@ #include <string_view> #include "icing/transform/normalizer.h" +#include "icing/util/character-iterator.h" namespace icing { namespace lib { @@ -39,6 +40,17 @@ class MapNormalizer : public Normalizer { // Read more mapping details in normalization-map.cc std::string NormalizeTerm(std::string_view term) const override; + // Returns a CharacterIterator pointing to one past the end of the segment of + // term that (once normalized) matches with normalized_term. + // + // Ex. FindNormalizedMatchEndPosition("YELLOW", "yell") will return + // CharacterIterator(u8:4, u16:4, u32:4). + // + // Ex. 
FindNormalizedMatchEndPosition("YELLOW", "red") will return + // CharacterIterator(u8:0, u16:0, u32:0). + CharacterIterator FindNormalizedMatchEndPosition( + std::string_view term, std::string_view normalized_term) const override; + private: // The maximum term length allowed after normalization. int max_term_byte_size_; diff --git a/icing/transform/map/map-normalizer_benchmark.cc b/icing/transform/map/map-normalizer_benchmark.cc index 691afc6..8268541 100644 --- a/icing/transform/map/map-normalizer_benchmark.cc +++ b/icing/transform/map/map-normalizer_benchmark.cc @@ -143,6 +143,104 @@ BENCHMARK(BM_NormalizeHiragana) ->Arg(2048000) ->Arg(4096000); +void BM_UppercaseSubTokenLength(benchmark::State& state) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string(state.range(0), 'A'); + std::string normalized_input_string(state.range(0), 'a'); + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_UppercaseSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) + ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + +void BM_AccentSubTokenLength(benchmark::State& state) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string; + std::string normalized_input_string; + while (input_string.length() < state.range(0)) { + input_string.append("àáâãā"); + normalized_input_string.append("aaaaa"); + } + + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_AccentSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) 
+ ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + +void BM_HiraganaSubTokenLength(benchmark::State& state) { + ICING_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<Normalizer> normalizer, + normalizer_factory::Create( + /*max_term_byte_size=*/std::numeric_limits<int>::max())); + + std::string input_string; + std::string normalized_input_string; + while (input_string.length() < state.range(0)) { + input_string.append("あいうえお"); + normalized_input_string.append("アイウエオ"); + } + + for (auto _ : state) { + normalizer->FindNormalizedMatchEndPosition(input_string, + normalized_input_string); + } +} +BENCHMARK(BM_HiraganaSubTokenLength) + ->Arg(1000) + ->Arg(2000) + ->Arg(4000) + ->Arg(8000) + ->Arg(16000) + ->Arg(32000) + ->Arg(64000) + ->Arg(128000) + ->Arg(256000) + ->Arg(384000) + ->Arg(512000) + ->Arg(1024000) + ->Arg(2048000) + ->Arg(4096000); + } // namespace } // namespace lib diff --git a/icing/transform/map/map-normalizer_test.cc b/icing/transform/map/map-normalizer_test.cc index b62ae0e..adc5623 100644 --- a/icing/transform/map/map-normalizer_test.cc +++ b/icing/transform/map/map-normalizer_test.cc @@ -23,6 +23,7 @@ #include "icing/testing/icu-i18n-test-utils.h" #include "icing/transform/normalizer-factory.h" #include "icing/transform/normalizer.h" +#include "icing/util/character-iterator.h" namespace icing { namespace lib { @@ -199,6 +200,104 @@ TEST(MapNormalizerTest, Truncate) { } } +TEST(MapNormalizerTest, PrefixMatchLength) { + // Verify that FindNormalizedMatchEndPosition will properly find the length of + // the prefix match when given a non-normalized term and a normalized term + // is a prefix of the non-normalized one. 
+ ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + // Upper to lower + std::string term = "MDI"; + CharacterIterator match_end = + normalizer->FindNormalizedMatchEndPosition(term, "md"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("MD")); + + term = "Icing"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "icin"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Icin")); + + // Full-width + term = "525600"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "525"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("525")); + + term = "FULLWIDTH"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "full"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("FULL")); + + // Hiragana to Katakana + term = "あいうえお"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "アイ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("あい")); + + term = "かきくけこ"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "カ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("か")); + + // Latin accents + term = "Zürich"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "zur"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Zür")); + + term = "après-midi"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "apre"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("aprè")); + + term = "Buenos días"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "buenos di"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Buenos dí")); +} + +TEST(MapNormalizerTest, SharedPrefixMatchLength) { + // Verify that FindNormalizedMatchEndPosition will properly find the length of + // the prefix match when given a non-normalized term and a normalized term + // that share a common prefix. 
+ ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( + /*max_term_byte_size=*/1000)); + + // Upper to lower + std::string term = "MDI"; + CharacterIterator match_end = + normalizer->FindNormalizedMatchEndPosition(term, "mgm"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("M")); + + term = "Icing"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "icky"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Ic")); + + // Full-width + term = "525600"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "525788"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("525")); + + term = "FULLWIDTH"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "fully"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("FULL")); + + // Hiragana to Katakana + term = "あいうえお"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "アイエオ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("あい")); + + term = "かきくけこ"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "カケコ"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("か")); + + // Latin accents + term = "Zürich"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "zurg"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("Zür")); + + term = "après-midi"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "apreciate"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("aprè")); + + term = "días"; + match_end = normalizer->FindNormalizedMatchEndPosition(term, "diamond"); + EXPECT_THAT(term.substr(0, match_end.utf8_index()), Eq("día")); +} + } // namespace } // namespace lib diff --git a/icing/transform/map/normalization-map.cc b/icing/transform/map/normalization-map.cc index c318036..0994ab8 100644 --- a/icing/transform/map/normalization-map.cc +++ b/icing/transform/map/normalization-map.cc @@ -691,19 +691,21 @@ constexpr NormalizationPair kNormalizationMappings[] = { } // 
namespace -const std::unordered_map<char16_t, char16_t>& GetNormalizationMap() { +const std::unordered_map<char16_t, char16_t> *GetNormalizationMap() { // The map is allocated dynamically the first time this function is executed. - static const std::unordered_map<char16_t, char16_t> normalization_map = [] { - std::unordered_map<char16_t, char16_t> map; - // Size of all the mappings is about 2.5 KiB. - constexpr int numMappings = - sizeof(kNormalizationMappings) / sizeof(NormalizationPair); - map.reserve(numMappings); - for (size_t i = 0; i < numMappings; ++i) { - map.emplace(kNormalizationMappings[i].from, kNormalizationMappings[i].to); - } - return map; - }(); + static const std::unordered_map<char16_t, char16_t> *const normalization_map = + [] { + auto *map = new std::unordered_map<char16_t, char16_t>(); + // Size of all the mappings is about 2.5 KiB. + constexpr int numMappings = + sizeof(kNormalizationMappings) / sizeof(NormalizationPair); + map->reserve(numMappings); + for (size_t i = 0; i < numMappings; ++i) { + map->emplace(kNormalizationMappings[i].from, + kNormalizationMappings[i].to); + } + return map; + }(); return normalization_map; } diff --git a/icing/transform/map/normalization-map.h b/icing/transform/map/normalization-map.h index aea85bd..ac7872b 100644 --- a/icing/transform/map/normalization-map.h +++ b/icing/transform/map/normalization-map.h @@ -23,7 +23,7 @@ namespace lib { // Returns a map containing normalization mappings. A mapping (A -> B) means // that we'll transform every character 'A' into 'B'. See normalization-map.cc // for mapping details. 
-const std::unordered_map<char16_t, char16_t>& GetNormalizationMap(); +const std::unordered_map<char16_t, char16_t>* GetNormalizationMap(); } // namespace lib } // namespace icing diff --git a/icing/transform/normalizer.h b/icing/transform/normalizer.h index 4cbfa63..2110f0f 100644 --- a/icing/transform/normalizer.h +++ b/icing/transform/normalizer.h @@ -20,6 +20,7 @@ #include <string_view> #include "icing/text_classifier/lib3/utils/base/statusor.h" +#include "icing/util/character-iterator.h" namespace icing { namespace lib { @@ -39,6 +40,17 @@ class Normalizer { // Normalizes the input term based on rules. See implementation classes for // specific transformation rules. virtual std::string NormalizeTerm(std::string_view term) const = 0; + + // Returns a CharacterIterator pointing to one past the end of the segment of + // term that (once normalized) matches with normalized_term. + // + // Ex. FindNormalizedMatchEndPosition("YELLOW", "yell") will return + // CharacterIterator(u8:4, u16:4, u32:4). + // + // Ex. FindNormalizedMatchEndPosition("YELLOW", "red") will return + // CharacterIterator(u8:0, u16:0, u32:0). + virtual CharacterIterator FindNormalizedMatchEndPosition( + std::string_view term, std::string_view normalized_term) const = 0; }; } // namespace lib diff --git a/icing/transform/simple/none-normalizer-factory.cc b/icing/transform/simple/none-normalizer-factory.cc deleted file mode 100644 index 6b35270..0000000 --- a/icing/transform/simple/none-normalizer-factory.cc +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_FACTORY_H_ -#define ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_FACTORY_H_ - -#include <memory> -#include <string_view> - -#include "icing/text_classifier/lib3/utils/base/statusor.h" -#include "icing/absl_ports/canonical_errors.h" -#include "icing/transform/normalizer.h" -#include "icing/transform/simple/none-normalizer.h" - -namespace icing { -namespace lib { - -namespace normalizer_factory { - -// Creates a dummy normalizer. The term is not normalized, but -// the text will be truncated to max_term_byte_size if it exceeds the max size. -// -// Returns: -// A normalizer on success -// INVALID_ARGUMENT if max_term_byte_size <= 0 -// INTERNAL_ERROR on errors -libtextclassifier3::StatusOr<std::unique_ptr<Normalizer>> Create( - int max_term_byte_size) { - if (max_term_byte_size <= 0) { - return absl_ports::InvalidArgumentError( - "max_term_byte_size must be greater than zero."); - } - - return std::make_unique<NoneNormalizer>(max_term_byte_size); -} - -} // namespace normalizer_factory - -} // namespace lib -} // namespace icing - -#endif // ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_FACTORY_H_ diff --git a/icing/transform/simple/none-normalizer.h b/icing/transform/simple/none-normalizer.h deleted file mode 100644 index 47085e1..0000000 --- a/icing/transform/simple/none-normalizer.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_H_ -#define ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_H_ - -#include <string> -#include <string_view> - -#include "icing/transform/normalizer.h" - -namespace icing { -namespace lib { - -// This normalizer is not meant for production use. Currently only used to get -// the Icing library to compile in Jetpack. -// -// No normalization is done, but the term is truncated if it exceeds -// max_term_byte_size. -class NoneNormalizer : public Normalizer { - public: - explicit NoneNormalizer(int max_term_byte_size) - : max_term_byte_size_(max_term_byte_size){}; - - std::string NormalizeTerm(std::string_view term) const override { - if (term.length() > max_term_byte_size_) { - return std::string(term.substr(0, max_term_byte_size_)); - } - return std::string(term); - } - - private: - // The maximum term length allowed after normalization. - int max_term_byte_size_; -}; - -} // namespace lib -} // namespace icing - -#endif // ICING_TRANSFORM_SIMPLE_NONE_NORMALIZER_H_ diff --git a/icing/transform/simple/none-normalizer_test.cc b/icing/transform/simple/none-normalizer_test.cc deleted file mode 100644 index e074828..0000000 --- a/icing/transform/simple/none-normalizer_test.cc +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (C) 2019 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <memory> - -#include "gmock/gmock.h" -#include "gtest/gtest.h" -#include "icing/testing/common-matchers.h" -#include "icing/transform/normalizer-factory.h" -#include "icing/transform/normalizer.h" - -namespace icing { -namespace lib { -namespace { - -using ::testing::Eq; - -TEST(NoneNormalizerTest, Creation) { - EXPECT_THAT(normalizer_factory::Create( - /*max_term_byte_size=*/5), - IsOk()); - EXPECT_THAT(normalizer_factory::Create( - /*max_term_byte_size=*/0), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); - EXPECT_THAT(normalizer_factory::Create( - /*max_term_byte_size=*/-1), - StatusIs(libtextclassifier3::StatusCode::INVALID_ARGUMENT)); -} - -TEST(IcuNormalizerTest, NoNormalizationDone) { - ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( - /*max_term_byte_size=*/1000)); - EXPECT_THAT(normalizer->NormalizeTerm(""), Eq("")); - EXPECT_THAT(normalizer->NormalizeTerm("hello world"), Eq("hello world")); - - // Capitalization - EXPECT_THAT(normalizer->NormalizeTerm("MDI"), Eq("MDI")); - - // Accents - EXPECT_THAT(normalizer->NormalizeTerm("Zürich"), Eq("Zürich")); - - // Full-width punctuation to ASCII punctuation - EXPECT_THAT(normalizer->NormalizeTerm("。,!?:”"), Eq("。,!?:”")); - - // Half-width katakana - EXPECT_THAT(normalizer->NormalizeTerm("カ"), Eq("カ")); -} - -TEST(NoneNormalizerTest, Truncate) { - ICING_ASSERT_OK_AND_ASSIGN(auto normalizer, normalizer_factory::Create( - /*max_term_byte_size=*/5)); - - // Won't be truncated - EXPECT_THAT(normalizer->NormalizeTerm("hi"), Eq("hi")); - 
EXPECT_THAT(normalizer->NormalizeTerm("hello"), Eq("hello")); - - // Truncated to length 5. - EXPECT_THAT(normalizer->NormalizeTerm("hello!"), Eq("hello")); -} - -} // namespace -} // namespace lib -} // namespace icing diff --git a/icing/util/character-iterator.cc b/icing/util/character-iterator.cc index 6c5faef..d483031 100644 --- a/icing/util/character-iterator.cc +++ b/icing/util/character-iterator.cc @@ -14,6 +14,8 @@ #include "icing/util/character-iterator.h" +#include "icing/util/i18n-utils.h" + namespace icing { namespace lib { @@ -30,6 +32,17 @@ int GetUTF8StartPosition(std::string_view text, int current_byte_index) { } // namespace +UChar32 CharacterIterator::GetCurrentChar() { + if (cached_current_char_ == i18n_utils::kInvalidUChar32) { + // Our indices point to the right character, we just need to read that + // character. No need to worry about an error. If GetUChar32At fails, then + // current_char will be i18n_utils::kInvalidUChar32. + cached_current_char_ = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); + } + return cached_current_char_; +} + bool CharacterIterator::MoveToUtf8(int desired_utf8_index) { return (desired_utf8_index > utf8_index_) ? AdvanceToUtf8(desired_utf8_index) : RewindToUtf8(desired_utf8_index); @@ -41,11 +54,13 @@ bool CharacterIterator::AdvanceToUtf8(int desired_utf8_index) { return false; } // Need to work forwards. + UChar32 uchar32 = cached_current_char_; while (utf8_index_ < desired_utf8_index) { - UChar32 uchar32 = + uchar32 = i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); if (uchar32 == i18n_utils::kInvalidUChar32) { // Unable to retrieve a valid UTF-32 character at the previous position. 
+ cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } int utf8_length = i18n_utils::GetUtf8Length(uchar32); @@ -57,6 +72,8 @@ bool CharacterIterator::AdvanceToUtf8(int desired_utf8_index) { utf16_index_ += i18n_utils::GetUtf16Length(uchar32); ++utf32_index_; } + cached_current_char_ = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); return true; } @@ -66,21 +83,30 @@ bool CharacterIterator::RewindToUtf8(int desired_utf8_index) { return false; } // Need to work backwards. + UChar32 uchar32 = cached_current_char_; while (utf8_index_ > desired_utf8_index) { - --utf8_index_; - utf8_index_ = GetUTF8StartPosition(text_, utf8_index_); - if (utf8_index_ < 0) { + int utf8_index = utf8_index_ - 1; + utf8_index = GetUTF8StartPosition(text_, utf8_index); + if (utf8_index < 0) { // Somehow, there wasn't a single UTF-8 lead byte at // requested_byte_index or an earlier byte. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } // We've found the start of a unicode char! - UChar32 uchar32 = - i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); - if (uchar32 == i18n_utils::kInvalidUChar32) { - // Unable to retrieve a valid UTF-32 character at the previous position. + uchar32 = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index); + int expected_length = utf8_index_ - utf8_index; + if (uchar32 == i18n_utils::kInvalidUChar32 || + expected_length != i18n_utils::GetUtf8Length(uchar32)) { + // Either unable to retrieve a valid UTF-32 character at the previous + // position or we skipped past an invalid sequence while seeking the + // previous start position. 
+ cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } + cached_current_char_ = uchar32; + utf8_index_ = utf8_index; utf16_index_ -= i18n_utils::GetUtf16Length(uchar32); --utf32_index_; } @@ -94,11 +120,13 @@ bool CharacterIterator::MoveToUtf16(int desired_utf16_index) { } bool CharacterIterator::AdvanceToUtf16(int desired_utf16_index) { + UChar32 uchar32 = cached_current_char_; while (utf16_index_ < desired_utf16_index) { - UChar32 uchar32 = + uchar32 = i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); if (uchar32 == i18n_utils::kInvalidUChar32) { // Unable to retrieve a valid UTF-32 character at the previous position. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } int utf16_length = i18n_utils::GetUtf16Length(uchar32); @@ -109,12 +137,15 @@ bool CharacterIterator::AdvanceToUtf16(int desired_utf16_index) { int utf8_length = i18n_utils::GetUtf8Length(uchar32); if (utf8_index_ + utf8_length > text_.length()) { // Enforce the requirement. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } utf8_index_ += utf8_length; utf16_index_ += utf16_length; ++utf32_index_; } + cached_current_char_ = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); return true; } @@ -122,21 +153,30 @@ bool CharacterIterator::RewindToUtf16(int desired_utf16_index) { if (desired_utf16_index < 0) { return false; } + UChar32 uchar32 = cached_current_char_; while (utf16_index_ > desired_utf16_index) { - --utf8_index_; - utf8_index_ = GetUTF8StartPosition(text_, utf8_index_); - if (utf8_index_ < 0) { + int utf8_index = utf8_index_ - 1; + utf8_index = GetUTF8StartPosition(text_, utf8_index); + if (utf8_index < 0) { // Somehow, there wasn't a single UTF-8 lead byte at // requested_byte_index or an earlier byte. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } // We've found the start of a unicode char! 
- UChar32 uchar32 = - i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); - if (uchar32 == i18n_utils::kInvalidUChar32) { - // Unable to retrieve a valid UTF-32 character at the previous position. + uchar32 = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index); + int expected_length = utf8_index_ - utf8_index; + if (uchar32 == i18n_utils::kInvalidUChar32 || + expected_length != i18n_utils::GetUtf8Length(uchar32)) { + // Either unable to retrieve a valid UTF-32 character at the previous + // position or we skipped past an invalid sequence while seeking the + // previous start position. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } + cached_current_char_ = uchar32; + utf8_index_ = utf8_index; utf16_index_ -= i18n_utils::GetUtf16Length(uchar32); --utf32_index_; } @@ -150,23 +190,28 @@ bool CharacterIterator::MoveToUtf32(int desired_utf32_index) { } bool CharacterIterator::AdvanceToUtf32(int desired_utf32_index) { + UChar32 uchar32 = cached_current_char_; while (utf32_index_ < desired_utf32_index) { - UChar32 uchar32 = + uchar32 = i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); if (uchar32 == i18n_utils::kInvalidUChar32) { // Unable to retrieve a valid UTF-32 character at the previous position. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } int utf16_length = i18n_utils::GetUtf16Length(uchar32); int utf8_length = i18n_utils::GetUtf8Length(uchar32); if (utf8_index_ + utf8_length > text_.length()) { // Enforce the requirement. 
+ cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } utf8_index_ += utf8_length; utf16_index_ += utf16_length; ++utf32_index_; } + cached_current_char_ = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); return true; } @@ -174,21 +219,30 @@ bool CharacterIterator::RewindToUtf32(int desired_utf32_index) { if (desired_utf32_index < 0) { return false; } + UChar32 uchar32 = cached_current_char_; while (utf32_index_ > desired_utf32_index) { - --utf8_index_; - utf8_index_ = GetUTF8StartPosition(text_, utf8_index_); - if (utf8_index_ < 0) { + int utf8_index = utf8_index_ - 1; + utf8_index = GetUTF8StartPosition(text_, utf8_index); + if (utf8_index < 0) { // Somehow, there wasn't a single UTF-8 lead byte at // requested_byte_index or an earlier byte. + cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } // We've found the start of a unicode char! - UChar32 uchar32 = - i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index_); - if (uchar32 == i18n_utils::kInvalidUChar32) { - // Unable to retrieve a valid UTF-32 character at the previous position. + uchar32 = + i18n_utils::GetUChar32At(text_.data(), text_.length(), utf8_index); + int expected_length = utf8_index_ - utf8_index; + if (uchar32 == i18n_utils::kInvalidUChar32 || + expected_length != i18n_utils::GetUtf8Length(uchar32)) { + // Either unable to retrieve a valid UTF-32 character at the previous + // position or we skipped past an invalid sequence while seeking the + // previous start position. 
+ cached_current_char_ = i18n_utils::kInvalidUChar32; return false; } + cached_current_char_ = uchar32; + utf8_index_ = utf8_index; utf16_index_ -= i18n_utils::GetUtf16Length(uchar32); --utf32_index_; } diff --git a/icing/util/character-iterator.h b/icing/util/character-iterator.h index 9df7bee..c7569a7 100644 --- a/icing/util/character-iterator.h +++ b/icing/util/character-iterator.h @@ -29,10 +29,15 @@ class CharacterIterator { CharacterIterator(std::string_view text, int utf8_index, int utf16_index, int utf32_index) : text_(text), + cached_current_char_(i18n_utils::kInvalidUChar32), utf8_index_(utf8_index), utf16_index_(utf16_index), utf32_index_(utf32_index) {} + // Returns the character that the iterator currently points to. + // i18n_utils::kInvalidUChar32 if unable to read that character. + UChar32 GetCurrentChar(); + // Moves current position to desired_utf8_index. // REQUIRES: 0 <= desired_utf8_index <= text_.length() bool MoveToUtf8(int desired_utf8_index); @@ -82,6 +87,8 @@ class CharacterIterator { int utf32_index() const { return utf32_index_; } bool operator==(const CharacterIterator& rhs) const { + // cached_current_char_ is just that: a cached value. As such, it's not + // considered for equality. return text_ == rhs.text_ && utf8_index_ == rhs.utf8_index_ && utf16_index_ == rhs.utf16_index_ && utf32_index_ == rhs.utf32_index_; } @@ -93,6 +100,7 @@ class CharacterIterator { private: std::string_view text_; + UChar32 cached_current_char_; int utf8_index_; int utf16_index_; int utf32_index_; diff --git a/icing/util/character-iterator_test.cc b/icing/util/character-iterator_test.cc new file mode 100644 index 0000000..445f837 --- /dev/null +++ b/icing/util/character-iterator_test.cc @@ -0,0 +1,235 @@ +// Copyright (C) 2019 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "icing/util/character-iterator.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "icing/testing/icu-i18n-test-utils.h" + +namespace icing { +namespace lib { + +using ::testing::Eq; +using ::testing::IsFalse; +using ::testing::IsTrue; + +TEST(CharacterIteratorTest, BasicUtf8) { + constexpr std::string_view kText = "¿Dónde está la biblioteca?"; + CharacterIterator iterator(kText); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + + EXPECT_THAT(iterator.AdvanceToUtf8(4), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.AdvanceToUtf8(18), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.AdvanceToUtf8(28), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.AdvanceToUtf8(29), IsTrue()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(0)); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/29, /*utf16_index=*/26, + /*utf32_index=*/26))); + + EXPECT_THAT(iterator.RewindToUtf8(28), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + 
Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.RewindToUtf8(18), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.RewindToUtf8(4), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.RewindToUtf8(0), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/0, /*utf16_index=*/0, + /*utf32_index=*/0))); +} + +TEST(CharacterIteratorTest, BasicUtf16) { + constexpr std::string_view kText = "¿Dónde está la biblioteca?"; + CharacterIterator iterator(kText); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + + EXPECT_THAT(iterator.AdvanceToUtf16(2), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.AdvanceToUtf16(15), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.AdvanceToUtf16(25), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.AdvanceToUtf16(26), IsTrue()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(0)); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/29, /*utf16_index=*/26, + /*utf32_index=*/26))); + + EXPECT_THAT(iterator.RewindToUtf16(25), 
IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.RewindToUtf16(15), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.RewindToUtf16(2), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.RewindToUtf8(0), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/0, /*utf16_index=*/0, + /*utf32_index=*/0))); +} + +TEST(CharacterIteratorTest, BasicUtf32) { + constexpr std::string_view kText = "¿Dónde está la biblioteca?"; + CharacterIterator iterator(kText); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + + EXPECT_THAT(iterator.AdvanceToUtf32(2), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.AdvanceToUtf32(15), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.AdvanceToUtf32(25), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.AdvanceToUtf32(26), IsTrue()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(0)); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, 
/*utf8_index=*/29, /*utf16_index=*/26, + /*utf32_index=*/26))); + + EXPECT_THAT(iterator.RewindToUtf32(25), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("?")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/28, /*utf16_index=*/25, + /*utf32_index=*/25))); + + EXPECT_THAT(iterator.RewindToUtf32(15), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/18, /*utf16_index=*/15, + /*utf32_index=*/15))); + + EXPECT_THAT(iterator.RewindToUtf32(2), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("ó")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/3, /*utf16_index=*/2, + /*utf32_index=*/2))); + + EXPECT_THAT(iterator.RewindToUtf32(0), IsTrue()); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("¿")); + EXPECT_THAT(iterator, + Eq(CharacterIterator(kText, /*utf8_index=*/0, /*utf16_index=*/0, + /*utf32_index=*/0))); +} + +TEST(CharacterIteratorTest, InvalidUtf) { + // "\255" is an invalid sequence. + constexpr std::string_view kText = "foo \255 bar"; + CharacterIterator iterator(kText); + + // Try to advance to the 'b' in 'bar'. This will fail and leave us pointed at + // the invalid sequence '\255'. Get CurrentChar() should return an invalid + // character. 
+ EXPECT_THAT(iterator.AdvanceToUtf8(6), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(i18n_utils::kInvalidUChar32)); + CharacterIterator exp_iterator(kText, /*utf8_index=*/4, /*utf16_index=*/4, + /*utf32_index=*/4); + EXPECT_THAT(iterator, Eq(exp_iterator)); + + EXPECT_THAT(iterator.AdvanceToUtf16(6), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(i18n_utils::kInvalidUChar32)); + EXPECT_THAT(iterator, Eq(exp_iterator)); + + EXPECT_THAT(iterator.AdvanceToUtf32(6), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(i18n_utils::kInvalidUChar32)); + EXPECT_THAT(iterator, Eq(exp_iterator)); + + // Create the iterator with it pointing at the 'b' in 'bar'. + iterator = CharacterIterator(kText, /*utf8_index=*/6, /*utf16_index=*/6, + /*utf32_index=*/6); + EXPECT_THAT(UCharToString(iterator.GetCurrentChar()), Eq("b")); + + // Try to advance to the last 'o' in 'foo'. This will fail and leave us + // pointed at the ' ' before the invalid sequence '\255'. + exp_iterator = CharacterIterator(kText, /*utf8_index=*/5, /*utf16_index=*/5, + /*utf32_index=*/5); + EXPECT_THAT(iterator.RewindToUtf8(2), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(' ')); + EXPECT_THAT(iterator, Eq(exp_iterator)); + + EXPECT_THAT(iterator.RewindToUtf16(2), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(' ')); + EXPECT_THAT(iterator, Eq(exp_iterator)); + + EXPECT_THAT(iterator.RewindToUtf32(2), IsFalse()); + EXPECT_THAT(iterator.GetCurrentChar(), Eq(' ')); + EXPECT_THAT(iterator, Eq(exp_iterator)); +} + +} // namespace lib +} // namespace icing diff --git a/icing/util/i18n-utils.cc b/icing/util/i18n-utils.cc index cd0a227..ec327ad 100644 --- a/icing/util/i18n-utils.cc +++ b/icing/util/i18n-utils.cc @@ -116,6 +116,8 @@ bool IsAscii(char c) { return U8_IS_SINGLE((uint8_t)c); } bool IsAscii(UChar32 c) { return U8_LENGTH(c) == 1; } +bool IsAlphaNumeric(UChar32 c) { return u_isalnum(c); } + int GetUtf8Length(UChar32 c) { return U8_LENGTH(c); } int 
GetUtf16Length(UChar32 c) { return U16_LENGTH(c); } diff --git a/icing/util/i18n-utils.h b/icing/util/i18n-utils.h index 82ae828..491df6b 100644 --- a/icing/util/i18n-utils.h +++ b/icing/util/i18n-utils.h @@ -67,6 +67,9 @@ bool IsAscii(char c); // Checks if the Unicode char is within ASCII range. bool IsAscii(UChar32 c); +// Checks if the Unicode char is alphanumeric. +bool IsAlphaNumeric(UChar32 c); + // Returns how many code units (char) are used for the UTF-8 encoding of this // Unicode character. Returns 0 if not valid. int GetUtf8Length(UChar32 c); diff --git a/java/src/com/google/android/icing/IcingSearchEngine.java b/java/src/com/google/android/icing/IcingSearchEngine.java index 1f5fb51..95e0c84 100644 --- a/java/src/com/google/android/icing/IcingSearchEngine.java +++ b/java/src/com/google/android/icing/IcingSearchEngine.java @@ -43,6 +43,8 @@ import com.google.android.icing.proto.SearchSpecProto; import com.google.android.icing.proto.SetSchemaResultProto; import com.google.android.icing.proto.StatusProto; import com.google.android.icing.proto.StorageInfoResultProto; +import com.google.android.icing.proto.SuggestionResponse; +import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.UsageReport; import com.google.protobuf.ExtensionRegistryLite; import com.google.protobuf.InvalidProtocolBufferException; @@ -370,6 +372,26 @@ public class IcingSearchEngine implements Closeable { } @NonNull + public SuggestionResponse searchSuggestions(@NonNull SuggestionSpecProto suggestionSpec) { + byte[] suggestionResponseBytes = nativeSearchSuggestions(this, suggestionSpec.toByteArray()); + if (suggestionResponseBytes == null) { + Log.e(TAG, "Received null suggestionResponseBytes from native."); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + + try { + return SuggestionResponse.parseFrom(suggestionResponseBytes, EXTENSION_REGISTRY_LITE); + } catch 
(InvalidProtocolBufferException e) { + Log.e(TAG, "Error parsing suggestionResponseBytes.", e); + return SuggestionResponse.newBuilder() + .setStatus(StatusProto.newBuilder().setCode(StatusProto.Code.INTERNAL)) + .build(); + } + } + + @NonNull public DeleteByNamespaceResultProto deleteByNamespace(@NonNull String namespace) { throwIfClosed(); @@ -604,4 +626,7 @@ public class IcingSearchEngine implements Closeable { private static native byte[] nativeGetStorageInfo(IcingSearchEngine instance); private static native byte[] nativeReset(IcingSearchEngine instance); + + private static native byte[] nativeSearchSuggestions( + IcingSearchEngine instance, byte[] suggestionSpecBytes); } diff --git a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java index 0cee80c..cb28331 100644 --- a/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java +++ b/java/tests/instrumentation/src/com/google/android/icing/IcingSearchEngineTest.java @@ -51,6 +51,8 @@ import com.google.android.icing.proto.StatusProto; import com.google.android.icing.proto.StorageInfoResultProto; import com.google.android.icing.proto.StringIndexingConfig; import com.google.android.icing.proto.StringIndexingConfig.TokenizerType; +import com.google.android.icing.proto.SuggestionResponse; +import com.google.android.icing.proto.SuggestionSpecProto; import com.google.android.icing.proto.TermMatchType; import com.google.android.icing.proto.UsageReport; import com.google.android.icing.IcingSearchEngine; @@ -623,6 +625,40 @@ public final class IcingSearchEngineTest { assertThat(match).isEqualTo("𐀂𐀃"); } + @Test + public void testSearchSuggestions() { + assertStatusOk(icingSearchEngine.initialize().getStatus()); + + SchemaTypeConfigProto emailTypeConfig = createEmailTypeConfig(); + SchemaProto schema = SchemaProto.newBuilder().addTypes(emailTypeConfig).build(); + assertThat( + 
icingSearchEngine + .setSchema(schema, /*ignoreErrorsAndDeleteDocuments=*/ false) + .getStatus() + .getCode()) + .isEqualTo(StatusProto.Code.OK); + + DocumentProto emailDocument1 = + createEmailDocument("namespace", "uri1").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("fo")) + .build(); + DocumentProto emailDocument2 = + createEmailDocument("namespace", "uri2").toBuilder() + .addProperties(PropertyProto.newBuilder().setName("subject").addStringValues("foo")) + .build(); + assertStatusOk(icingSearchEngine.put(emailDocument1).getStatus()); + assertStatusOk(icingSearchEngine.put(emailDocument2).getStatus()); + + SuggestionSpecProto suggestionSpec = + SuggestionSpecProto.newBuilder().setPrefix("f").setNumToReturn(10).build(); + + SuggestionResponse response = icingSearchEngine.searchSuggestions(suggestionSpec); + assertStatusOk(response.getStatus()); + assertThat(response.getSuggestionsList()).hasSize(2); + assertThat(response.getSuggestions(0).getQuery()).isEqualTo("foo"); + assertThat(response.getSuggestions(1).getQuery()).isEqualTo("fo"); + } + private static void assertStatusOk(StatusProto status) { assertWithMessage(status.getMessage()).that(status.getCode()).isEqualTo(StatusProto.Code.OK); } diff --git a/proto/icing/proto/logging.proto b/proto/icing/proto/logging.proto index 4dcfecf..2f1f271 100644 --- a/proto/icing/proto/logging.proto +++ b/proto/icing/proto/logging.proto @@ -23,7 +23,7 @@ option java_multiple_files = true; option objc_class_prefix = "ICNG"; // Stats of the top-level function IcingSearchEngine::Initialize(). -// Next tag: 11 +// Next tag: 12 message InitializeStatsProto { // Overall time used for the function call. optional int32 latency_ms = 1; @@ -92,6 +92,10 @@ message InitializeStatsProto { // Number of schema types currently in schema store. optional int32 num_schema_types = 10; + + // Number of consecutive initialization failures that immediately preceded + // this initialization. 
+ optional int32 num_previous_init_failures = 11; } // Stats of the top-level function IcingSearchEngine::Put(). @@ -114,12 +118,10 @@ message PutDocumentStatsProto { optional int32 document_size = 5; message TokenizationStats { - // Whether the number of tokens to be indexed exceeded the max number of - // tokens per document. - optional bool exceeded_max_token_num = 2; - // Number of tokens added to the index. optional int32 num_tokens_indexed = 1; + + reserved 2; } optional TokenizationStats tokenization_stats = 6; } diff --git a/proto/icing/proto/scoring.proto b/proto/icing/proto/scoring.proto index 6186fde..a3a64df 100644 --- a/proto/icing/proto/scoring.proto +++ b/proto/icing/proto/scoring.proto @@ -23,7 +23,7 @@ option objc_class_prefix = "ICNG"; // Encapsulates the configurations on how Icing should score and rank the search // results. // TODO(b/170347684): Change all timestamps to seconds. -// Next tag: 3 +// Next tag: 4 message ScoringSpecProto { // OPTIONAL: Indicates how the search results will be ranked. message RankingStrategy { @@ -83,4 +83,41 @@ message ScoringSpecProto { } } optional Order.Code order_by = 2; + + // OPTIONAL: Specifies property weights for RELEVANCE_SCORE scoring strategy. + // Property weights are used for promoting or demoting query term matches in a + // document property. When property weights are provided, the term frequency + // is multiplied by the normalized property weight when computing the + // normalized term frequency component of BM25F. To prefer query term matches + // in the "subject" property over the "body" property of "Email" documents, + // set a higher property weight value for "subject" than "body". By default, + // all properties that are not specified are given a raw, pre-normalized + // weight of 1.0 when scoring. + repeated TypePropertyWeights type_property_weights = 3; +} + +// Next tag: 3 +message TypePropertyWeights { + // Schema type to apply property weights to. 
+ optional string schema_type = 1; + + // Property weights to apply to the schema type. + repeated PropertyWeight property_weights = 2; +} + +// Next tag: 3 +message PropertyWeight { + // Property path to assign property weight to. Property paths must be composed + // only of property names and property separators (the '.' character). + // For example, if an "Email" schema type has string property "subject" and + // document property "sender", which has string property "name", the property + // path for the email's subject would just be "subject" and the property path + // for the sender's name would be "sender.name". If an invalid path is + // specified, the property weight is discarded. + optional string path = 1; + + // Property weight, valid values are positive. Zero and negative weights are + // invalid and will result in an error. By default, a property is given a raw, + // pre-normalized weight of 1.0. + optional double weight = 2; } diff --git a/proto/icing/proto/search.proto b/proto/icing/proto/search.proto index 66fdbe6..c712ab2 100644 --- a/proto/icing/proto/search.proto +++ b/proto/icing/proto/search.proto @@ -136,27 +136,57 @@ message ResultSpecProto { } // The representation of a single match within a DocumentProto property. -// Next tag: 10 +// +// Example : A document whose content is "Necesito comprar comida mañana." and a +// query for "mana" with window=15 +// Next tag: 12 message SnippetMatchProto { // The index of the byte in the string at which the match begins and the // length in bytes of the match. + // + // For the example above, the values of these fields would be + // exact_match_byte_position=24, exact_match_byte_length=7 "mañana" optional int32 exact_match_byte_position = 2; optional int32 exact_match_byte_length = 3; + // The length in bytes of the subterm that matches the query. The beginning of + // the submatch is the same as exact_match_byte_position. + // + // For the example above, the value of this field would be 5. 
With + // exact_match_byte_position=24 above, it would produce the substring "maña" + optional int32 submatch_byte_length = 10; + // The index of the UTF-16 code unit in the string at which the match begins // and the length in UTF-16 code units of the match. This is for use with // UTF-16 encoded strings like Java.lang.String. + // + // For the example above, the values of these fields would be + // exact_match_utf16_position=24, exact_match_utf16_length=6 "mañana" optional int32 exact_match_utf16_position = 6; optional int32 exact_match_utf16_length = 7; + // The length in UTF-16 code units of the subterm that matches the query. The + // beginning of the submatch is the same as exact_match_utf16_position. This + // is for use with UTF-16 encoded strings like Java.lang.String. + // + // For the example above, the value of this field would be 4. With + // exact_match_utf16_position=24 above, it would produce the substring "maña" + optional int32 submatch_utf16_length = 11; + // The index of the byte in the string at which the suggested snippet window // begins and the length in bytes of the window. + // + // For the example above, the values of these fields would be + // window_byte_position=17, window_byte_length=15 "comida mañana." optional int32 window_byte_position = 4; optional int32 window_byte_length = 5; // The index of the UTF-16 code unit in the string at which the suggested // snippet window begins and the length in UTF-16 code units of the window. // This is for use with UTF-16 encoded strings like Java.lang.String. + // + // For the example above, the values of these fields would be + // window_utf16_position=17, window_utf16_length=14 "comida mañana." optional int32 window_utf16_position = 8; optional int32 window_utf16_length = 9; @@ -278,3 +308,37 @@ message GetResultSpecProto { // type will be retrieved. 
repeated TypePropertyMask type_property_masks = 1; } + +// Next tag: 4 +message SuggestionSpecProto { + // REQUIRED: The "raw" prefix string that users may type. For example, "f" + // will search for suggested query that start with "f" like "foo", "fool". + optional string prefix = 1; + + // OPTIONAL: Only search for suggestions that under the specified namespaces. + // If unset, the suggestion will search over all namespaces. Note that this + // applies to the entire 'prefix'. To issue different suggestions for + // different namespaces, separate RunSuggestion()'s will need to be made. + repeated string namespace_filters = 2; + + // REQUIRED: The number of suggestions to be returned. + optional int32 num_to_return = 3; +} + +// Next tag: 3 +message SuggestionResponse { + message Suggestion { + // The suggested query string for client to search for. + optional string query = 1; + } + + // Status code can be one of: + // OK + // FAILED_PRECONDITION + // INTERNAL + // + // See status.proto for more details. + optional StatusProto status = 1; + + repeated Suggestion suggestions = 2; +} diff --git a/synced_AOSP_CL_number.txt b/synced_AOSP_CL_number.txt index 69dfc00..7e0431b 100644 --- a/synced_AOSP_CL_number.txt +++ b/synced_AOSP_CL_number.txt @@ -1 +1 @@ -set(synced_AOSP_CL_number=385604495) +set(synced_AOSP_CL_number=404879391) |