diff options
Diffstat (limited to 'icing/index/main/doc-hit-info-iterator-term-main.h')
-rw-r--r-- | icing/index/main/doc-hit-info-iterator-term-main.h | 37 |
1 files changed, 24 insertions, 13 deletions
diff --git a/icing/index/main/doc-hit-info-iterator-term-main.h b/icing/index/main/doc-hit-info-iterator-term-main.h index f3cf701..c1b289f 100644 --- a/icing/index/main/doc-hit-info-iterator-term-main.h +++ b/icing/index/main/doc-hit-info-iterator-term-main.h @@ -33,14 +33,16 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { public: explicit DocHitInfoIteratorTermMain(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) : term_(term), main_index_(main_index), cached_doc_hit_infos_idx_(-1), num_advance_calls_(0), num_blocks_inspected_(0), next_posting_list_id_(PostingListIdentifier::kInvalid), - section_restrict_mask_(section_restrict_mask) {} + section_restrict_mask_(section_restrict_mask), + need_hit_term_frequency_(need_hit_term_frequency) {} libtextclassifier3::Status Advance() override; @@ -52,20 +54,23 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { void PopulateMatchedTermsStats( std::vector<TermMatchInfo>* matched_terms_stats, SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { - if (doc_hit_info_.document_id() == kInvalidDocumentId) { + if (cached_doc_hit_infos_idx_ == -1 || + cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) { // Current hit isn't valid, return. return; } SectionIdMask section_mask = doc_hit_info_.hit_section_ids_mask() & filtering_section_mask; SectionIdMask section_mask_copy = section_mask; - std::array<Hit::TermFrequency, kMaxSectionId> section_term_frequencies = { - Hit::kNoTermFrequency}; + std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies = + {Hit::kNoTermFrequency}; while (section_mask_copy) { - SectionId section_id = __builtin_ctz(section_mask_copy); - section_term_frequencies.at(section_id) = - doc_hit_info_.hit_term_frequency(section_id); - section_mask_copy &= ~(1u << section_id); + SectionId section_id = __builtin_ctzll(section_mask_copy); + if (need_hit_term_frequency_) { + section_term_frequencies.at(section_id) = cached_hit_term_frequency_.at( + cached_doc_hit_infos_idx_)[section_id]; + } + section_mask_copy &= ~(UINT64_C(1) << section_id); } TermMatchInfo term_stats(term_, section_mask, std::move(section_term_frequencies)); @@ -93,6 +98,7 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { // that are present in the index. Current value pointed to by the Iterator is // tracked by cached_doc_hit_infos_idx_. std::vector<DocHitInfo> cached_doc_hit_infos_; + std::vector<Hit::TermFrequencyArray> cached_hit_term_frequency_; int cached_doc_hit_infos_idx_; int num_advance_calls_; int num_blocks_inspected_; @@ -100,14 +106,17 @@ class DocHitInfoIteratorTermMain : public DocHitInfoIterator { // Mask indicating which sections hits should be considered for. // Ex. 0000 0000 0000 0010 means that only hits from section 1 are desired. const SectionIdMask section_restrict_mask_; + const bool need_hit_term_frequency_; }; class DocHitInfoIteratorTermMainExact : public DocHitInfoIteratorTermMain { public: explicit DocHitInfoIteratorTermMainExact(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) - : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask) {} + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) + : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask, + need_hit_term_frequency) {} std::string ToString() const override; @@ -119,8 +128,10 @@ class DocHitInfoIteratorTermMainPrefix : public DocHitInfoIteratorTermMain { public: explicit DocHitInfoIteratorTermMainPrefix(MainIndex* main_index, const std::string& term, - SectionIdMask section_restrict_mask) - : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask) {} + SectionIdMask section_restrict_mask, + bool need_hit_term_frequency) + : DocHitInfoIteratorTermMain(main_index, term, section_restrict_mask, + need_hit_term_frequency) {} std::string ToString() const override; |