diff options
Diffstat (limited to 'icing/scoring/priority-queue-scored-document-hits-ranker.h')
-rw-r--r-- | icing/scoring/priority-queue-scored-document-hits-ranker.h | 69 |
1 files changed, 59 insertions, 10 deletions
diff --git a/icing/scoring/priority-queue-scored-document-hits-ranker.h b/icing/scoring/priority-queue-scored-document-hits-ranker.h index 3ef2ae5..0798d7d 100644 --- a/icing/scoring/priority-queue-scored-document-hits-ranker.h +++ b/icing/scoring/priority-queue-scored-document-hits-ranker.h @@ -26,21 +26,37 @@ namespace lib { // ScoredDocumentHitsRanker interface implementation, based on // std::priority_queue. We can get next top hit in O(lgN) time. +template <typename ScoredDataType, + typename Converter = typename ScoredDataType::Converter> class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { public: explicit PriorityQueueScoredDocumentHitsRanker( - std::vector<ScoredDocumentHit>&& scored_document_hits, - bool is_descending = true); + std::vector<ScoredDataType>&& scored_data_vec, bool is_descending = true); ~PriorityQueueScoredDocumentHitsRanker() override = default; - ScoredDocumentHit PopNext() override; + // Note: ranker may store ScoredDocumentHit or JoinedScoredDocumentHit, so we + // have template for scored_data_pq_. + // - JoinedScoredDocumentHit is a superset of ScoredDocumentHit, so we unify + // the return type of PopNext to use the superset type + // JoinedScoredDocumentHit in order to make it simple, and rankers storing + // ScoredDocumentHit should convert it to JoinedScoredDocumentHit before + // returning. It makes the implementation simpler, especially for + // ResultRetriever, which now only needs to deal with one single return + // format. + // - JoinedScoredDocumentHit has ~2x size of ScoredDocumentHit. Since we cache + // ranker (which contains a priority queue of data) in ResultState, if we + // store the scored hits in JoinedScoredDocumentHit format directly, then it + // doubles the memory usage. Therefore, we still keep the flexibility to + // store ScoredDocumentHit or any other types of data, but require PopNext + // to convert it to JoinedScoredDocumentHit. + JoinedScoredDocumentHit PopNext() override; void TruncateHitsTo(int new_size) override; - int size() const override { return scored_document_hits_pq_.size(); } + int size() const override { return scored_data_pq_.size(); } - bool empty() const override { return scored_document_hits_pq_.empty(); } + bool empty() const override { return scored_data_pq_.empty(); } private: // Comparator for std::priority_queue. Since std::priority is a max heap @@ -49,8 +65,8 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { public: explicit Comparator(bool is_ascending) : is_ascending_(is_ascending) {} - bool operator()(const ScoredDocumentHit& lhs, - const ScoredDocumentHit& rhs) const { + bool operator()(const ScoredDataType& lhs, + const ScoredDataType& rhs) const { // STL comparator requirement: equal MUST return false. // If writing `return is_ascending_ == !(lhs < rhs)`: // - When lhs == rhs, !(lhs < rhs) is true @@ -68,11 +84,44 @@ class PriorityQueueScoredDocumentHitsRanker : public ScoredDocumentHitsRanker { Comparator comparator_; // Use priority queue to get top K hits in O(KlgN) time. - std::priority_queue<ScoredDocumentHit, std::vector<ScoredDocumentHit>, - Comparator> - scored_document_hits_pq_; + std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator> + scored_data_pq_; + + Converter converter_; }; +template <typename ScoredDataType, typename Converter> +PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>:: + PriorityQueueScoredDocumentHitsRanker( + std::vector<ScoredDataType>&& scored_data_vec, bool is_descending) + : comparator_(/*is_ascending=*/!is_descending), + scored_data_pq_(comparator_, std::move(scored_data_vec)) {} + +template <typename ScoredDataType, typename Converter> +JoinedScoredDocumentHit +PriorityQueueScoredDocumentHitsRanker<ScoredDataType, Converter>::PopNext() { + ScoredDataType next_scored_data = scored_data_pq_.top(); + scored_data_pq_.pop(); + return converter_(std::move(next_scored_data)); +} + +template <typename ScoredDataType, typename Converter> +void PriorityQueueScoredDocumentHitsRanker< + ScoredDataType, Converter>::TruncateHitsTo(int new_size) { + if (new_size < 0 || scored_data_pq_.size() <= new_size) { + return; + } + + // Copying the best new_size results. + std::priority_queue<ScoredDataType, std::vector<ScoredDataType>, Comparator> + new_pq(comparator_); + for (int i = 0; i < new_size; ++i) { + new_pq.push(scored_data_pq_.top()); + scored_data_pq_.pop(); + } + scored_data_pq_ = std::move(new_pq); +} + } // namespace lib } // namespace icing |