aboutsummaryrefslogtreecommitdiff
path: root/icing/index/iterator/doc-hit-info-iterator-test-util.h
diff options
context:
space:
mode:
Diffstat (limited to 'icing/index/iterator/doc-hit-info-iterator-test-util.h')
-rw-r--r--icing/index/iterator/doc-hit-info-iterator-test-util.h77
1 files changed, 61 insertions, 16 deletions
diff --git a/icing/index/iterator/doc-hit-info-iterator-test-util.h b/icing/index/iterator/doc-hit-info-iterator-test-util.h
index 45acc8f..ed6db23 100644
--- a/icing/index/iterator/doc-hit-info-iterator-test-util.h
+++ b/icing/index/iterator/doc-hit-info-iterator-test-util.h
@@ -31,6 +31,40 @@
namespace icing {
namespace lib {
+class DocHitInfoTermFrequencyPair {
+ public:
+ DocHitInfoTermFrequencyPair(
+ const DocHitInfo& doc_hit_info,
+ const Hit::TermFrequencyArray& hit_term_frequency = {})
+ : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {}
+
+ void UpdateSection(SectionId section_id,
+ Hit::TermFrequency hit_term_frequency) {
+ doc_hit_info_.UpdateSection(section_id);
+ hit_term_frequency_[section_id] = hit_term_frequency;
+ }
+
+ void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) {
+ SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask();
+ doc_hit_info_.MergeSectionsFrom(other_mask);
+ while (other_mask) {
+ SectionId section_id = __builtin_ctzll(other_mask);
+ hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id];
+ other_mask &= ~(UINT64_C(1) << section_id);
+ }
+ }
+
+ DocHitInfo doc_hit_info() const { return doc_hit_info_; }
+
+ Hit::TermFrequency hit_term_frequency(SectionId section_id) const {
+ return hit_term_frequency_[section_id];
+ }
+
+ private:
+ DocHitInfo doc_hit_info_;
+ Hit::TermFrequencyArray hit_term_frequency_;
+};
+
// Dummy class to help with testing. It starts with an kInvalidDocumentId doc
// hit info until an Advance is called (like normal DocHitInfoIterators). It
// will then proceed to return the doc_hit_infos in order as Advance's are
@@ -39,14 +73,23 @@ namespace lib {
class DocHitInfoIteratorDummy : public DocHitInfoIterator {
public:
DocHitInfoIteratorDummy() = default;
- explicit DocHitInfoIteratorDummy(std::vector<DocHitInfo> doc_hit_infos,
- std::string term = "")
+ explicit DocHitInfoIteratorDummy(
+ std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos,
+ std::string term = "")
: doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {}
+ explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos,
+ std::string term = "")
+ : term_(std::move(term)) {
+ for (auto& doc_hit_info : doc_hit_infos) {
+ doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info));
+ }
+ }
+
libtextclassifier3::Status Advance() override {
+ ++index_;
if (index_ < doc_hit_infos_.size()) {
- doc_hit_info_ = doc_hit_infos_.at(index_);
- index_++;
+ doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info();
return libtextclassifier3::Status::OK;
}
@@ -58,20 +101,20 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator {
void PopulateMatchedTermsStats(
std::vector<TermMatchInfo>* matched_terms_stats,
SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
- if (doc_hit_info_.document_id() == kInvalidDocumentId) {
+ if (index_ == -1 || index_ >= doc_hit_infos_.size()) {
// Current hit isn't valid, return.
return;
}
SectionIdMask section_mask =
doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
SectionIdMask section_mask_copy = section_mask;
- std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = {
- Hit::kNoTermFrequency};
+ std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
+ {Hit::kNoTermFrequency};
while (section_mask_copy) {
- SectionId section_id = __builtin_ctz(section_mask_copy);
+ SectionId section_id = __builtin_ctzll(section_mask_copy);
section_term_frequencies.at(section_id) =
- doc_hit_info_.hit_term_frequency(section_id);
- section_mask_copy &= ~(1u << section_id);
+ doc_hit_infos_.at(index_).hit_term_frequency(section_id);
+ section_mask_copy &= ~(UINT64_C(1) << section_id);
}
TermMatchInfo term_stats(term_, section_mask,
std::move(section_term_frequencies));
@@ -109,20 +152,22 @@ class DocHitInfoIteratorDummy : public DocHitInfoIterator {
std::string ToString() const override {
std::string ret = "<";
- for (auto& doc_hit_info : doc_hit_infos_) {
- absl_ports::StrAppend(&ret, IcingStringUtil::StringPrintf(
- "[%d,%d]", doc_hit_info.document_id(),
- doc_hit_info.hit_section_ids_mask()));
+ for (auto& doc_hit_info_pair : doc_hit_infos_) {
+ absl_ports::StrAppend(
+ &ret, IcingStringUtil::StringPrintf(
+ "[%d,%" PRIu64 "]",
+ doc_hit_info_pair.doc_hit_info().document_id(),
+ doc_hit_info_pair.doc_hit_info().hit_section_ids_mask()));
}
absl_ports::StrAppend(&ret, ">");
return ret;
}
private:
- int32_t index_ = 0;
+ int32_t index_ = -1;
int32_t num_blocks_inspected_ = 0;
int32_t num_leaf_advance_calls_ = 0;
- std::vector<DocHitInfo> doc_hit_infos_;
+ std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_;
std::string term_;
};