diff options
Diffstat (limited to 'icing/join/join-processor.cc')
-rw-r--r-- | icing/join/join-processor.cc | 47 |
1 files changed, 24 insertions, 23 deletions
diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc index 7700397..ab32850 100644 --- a/icing/join/join-processor.cc +++ b/icing/join/join-processor.cc @@ -34,11 +34,17 @@ namespace icing { namespace lib { -libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> -JoinProcessor::Join( +libtextclassifier3::StatusOr<JoinChildrenFetcher> +JoinProcessor::GetChildrenFetcher( const JoinSpecProto& join_spec, - std::vector<ScoredDocumentHit>&& parent_scored_document_hits, std::vector<ScoredDocumentHit>&& child_scored_document_hits) { + if (join_spec.parent_property_expression() != kQualifiedIdExpr) { + // TODO(b/256022027): So far we only support kQualifiedIdExpr for + // parent_property_expression, we could support more. + return absl_ports::UnimplementedError(absl_ports::StrCat( + "Parent property expression must be ", kQualifiedIdExpr)); + } + std::sort( child_scored_document_hits.begin(), child_scored_document_hits.end(), ScoredDocumentHitComparator( @@ -59,7 +65,7 @@ JoinProcessor::Join( // ScoredDocumentHits refer to. The values in this map are vectors of child // ScoredDocumentHits that refer to a parent DocumentId. std::unordered_map<DocumentId, std::vector<ScoredDocumentHit>> - parent_id_to_child_map; + map_joinable_qualified_id; for (const ScoredDocumentHit& child : child_scored_document_hits) { std::string property_content = FetchPropertyExpressionValue( child.document_id(), join_spec.child_property_expression()); @@ -84,14 +90,21 @@ JoinProcessor::Join( DocumentId parent_doc_id = std::move(parent_doc_id_or).ValueOrDie(); // Since we've already sorted child_scored_document_hits, just simply omit - // if the parent_id_to_child_map[parent_doc_id].size() has reached max + // if the map_joinable_qualified_id[parent_doc_id].size() has reached max // joined child count. - if (parent_id_to_child_map[parent_doc_id].size() < + if (map_joinable_qualified_id[parent_doc_id].size() < join_spec.max_joined_child_count()) { - parent_id_to_child_map[parent_doc_id].push_back(child); + map_joinable_qualified_id[parent_doc_id].push_back(child); } } + return JoinChildrenFetcher(join_spec, std::move(map_joinable_qualified_id)); +} +libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>> +JoinProcessor::Join( + const JoinSpecProto& join_spec, + std::vector<ScoredDocumentHit>&& parent_scored_document_hits, + const JoinChildrenFetcher& join_children_fetcher) { std::unique_ptr<AggregationScorer> aggregation_scorer = AggregationScorer::Create(join_spec); @@ -100,23 +113,11 @@ JoinProcessor::Join( // Step 2: iterate through all parent documentIds and construct // JoinedScoredDocumentHit for each by looking up - // parent_id_to_child_map. + // join_children_fetcher. for (ScoredDocumentHit& parent : parent_scored_document_hits) { - DocumentId parent_doc_id = kInvalidDocumentId; - if (join_spec.parent_property_expression() == kQualifiedIdExpr) { - parent_doc_id = parent.document_id(); - } else { - // TODO(b/256022027): So far we only support kQualifiedIdExpr for - // parent_property_expression, we could support more. - return absl_ports::UnimplementedError(absl_ports::StrCat( - "Parent property expression must be ", kQualifiedIdExpr)); - } - - std::vector<ScoredDocumentHit> children; - if (auto iter = parent_id_to_child_map.find(parent_doc_id); - iter != parent_id_to_child_map.end()) { - children = std::move(iter->second); - } + ICING_ASSIGN_OR_RETURN( + std::vector<ScoredDocumentHit> children, + join_children_fetcher.GetChildren(parent.document_id())); double final_score = aggregation_scorer->GetScore(parent, children); joined_scored_document_hits.emplace_back(final_score, std::move(parent), |