aboutsummaryrefslogtreecommitdiff
path: root/icing/join/join-processor.cc
diff options
context:
space:
mode:
Diffstat (limited to 'icing/join/join-processor.cc')
-rw-r--r--icing/join/join-processor.cc47
1 files changed, 24 insertions, 23 deletions
diff --git a/icing/join/join-processor.cc b/icing/join/join-processor.cc
index 7700397..ab32850 100644
--- a/icing/join/join-processor.cc
+++ b/icing/join/join-processor.cc
@@ -34,11 +34,17 @@
namespace icing {
namespace lib {
-libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>>
-JoinProcessor::Join(
+libtextclassifier3::StatusOr<JoinChildrenFetcher>
+JoinProcessor::GetChildrenFetcher(
const JoinSpecProto& join_spec,
- std::vector<ScoredDocumentHit>&& parent_scored_document_hits,
std::vector<ScoredDocumentHit>&& child_scored_document_hits) {
+ if (join_spec.parent_property_expression() != kQualifiedIdExpr) {
+ // TODO(b/256022027): So far we only support kQualifiedIdExpr for
+ // parent_property_expression; support for other expressions could be added.
+ return absl_ports::UnimplementedError(absl_ports::StrCat(
+ "Parent property expression must be ", kQualifiedIdExpr));
+ }
+
std::sort(
child_scored_document_hits.begin(), child_scored_document_hits.end(),
ScoredDocumentHitComparator(
@@ -59,7 +65,7 @@ JoinProcessor::Join(
// ScoredDocumentHits refer to. The values in this map are vectors of child
// ScoredDocumentHits that refer to a parent DocumentId.
std::unordered_map<DocumentId, std::vector<ScoredDocumentHit>>
- parent_id_to_child_map;
+ map_joinable_qualified_id;
for (const ScoredDocumentHit& child : child_scored_document_hits) {
std::string property_content = FetchPropertyExpressionValue(
child.document_id(), join_spec.child_property_expression());
@@ -84,14 +90,21 @@ JoinProcessor::Join(
DocumentId parent_doc_id = std::move(parent_doc_id_or).ValueOrDie();
// Since we've already sorted child_scored_document_hits, just simply omit
- // if the parent_id_to_child_map[parent_doc_id].size() has reached max
+ // if the map_joinable_qualified_id[parent_doc_id].size() has reached max
// joined child count.
- if (parent_id_to_child_map[parent_doc_id].size() <
+ if (map_joinable_qualified_id[parent_doc_id].size() <
join_spec.max_joined_child_count()) {
- parent_id_to_child_map[parent_doc_id].push_back(child);
+ map_joinable_qualified_id[parent_doc_id].push_back(child);
}
}
+ return JoinChildrenFetcher(join_spec, std::move(map_joinable_qualified_id));
+}
+libtextclassifier3::StatusOr<std::vector<JoinedScoredDocumentHit>>
+JoinProcessor::Join(
+ const JoinSpecProto& join_spec,
+ std::vector<ScoredDocumentHit>&& parent_scored_document_hits,
+ const JoinChildrenFetcher& join_children_fetcher) {
std::unique_ptr<AggregationScorer> aggregation_scorer =
AggregationScorer::Create(join_spec);
@@ -100,23 +113,11 @@ JoinProcessor::Join(
// Step 2: iterate through all parent documentIds and construct
// JoinedScoredDocumentHit for each by looking up
- // parent_id_to_child_map.
+ // join_children_fetcher.
for (ScoredDocumentHit& parent : parent_scored_document_hits) {
- DocumentId parent_doc_id = kInvalidDocumentId;
- if (join_spec.parent_property_expression() == kQualifiedIdExpr) {
- parent_doc_id = parent.document_id();
- } else {
- // TODO(b/256022027): So far we only support kQualifiedIdExpr for
- // parent_property_expression, we could support more.
- return absl_ports::UnimplementedError(absl_ports::StrCat(
- "Parent property expression must be ", kQualifiedIdExpr));
- }
-
- std::vector<ScoredDocumentHit> children;
- if (auto iter = parent_id_to_child_map.find(parent_doc_id);
- iter != parent_id_to_child_map.end()) {
- children = std::move(iter->second);
- }
+ ICING_ASSIGN_OR_RETURN(
+ std::vector<ScoredDocumentHit> children,
+ join_children_fetcher.GetChildren(parent.document_id()));
double final_score = aggregation_scorer->GetScore(parent, children);
joined_scored_document_hits.emplace_back(final_score, std::move(parent),