aboutsummaryrefslogtreecommitdiff
path: root/icing/testing/common-matchers.h
diff options
context:
space:
mode:
Diffstat (limited to 'icing/testing/common-matchers.h')
-rw-r--r--icing/testing/common-matchers.h88
1 files changed, 78 insertions, 10 deletions
diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h
index f2738e3..e090800 100644
--- a/icing/testing/common-matchers.h
+++ b/icing/testing/common-matchers.h
@@ -15,7 +15,10 @@
#ifndef ICING_TESTING_COMMON_MATCHERS_H_
#define ICING_TESTING_COMMON_MATCHERS_H_
+#include <algorithm>
#include <cmath>
+#include <string>
+#include <vector>
#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/status_macros.h"
@@ -31,6 +34,7 @@
#include "icing/proto/status.pb.h"
#include "icing/schema/schema-store.h"
#include "icing/schema/section.h"
+#include "icing/scoring/scored-document-hit.h"
#include "icing/util/status-macros.h"
namespace icing {
@@ -104,19 +108,73 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id,
term_frequency_as_expected;
}
+class ScoredDocumentHitFormatter {
+ public:
+ std::string operator()(const ScoredDocumentHit& scored_document_hit) {
+ return IcingStringUtil::StringPrintf(
+ "(document_id=%d, hit_section_id_mask=%" PRId64 ", score=%.2f)",
+ scored_document_hit.document_id(),
+ scored_document_hit.hit_section_id_mask(), scored_document_hit.score());
+ }
+};
+
+class ScoredDocumentHitEqualComparator {
+ public:
+ bool operator()(const ScoredDocumentHit& lhs,
+ const ScoredDocumentHit& rhs) const {
+ return lhs.document_id() == rhs.document_id() &&
+ lhs.hit_section_id_mask() == rhs.hit_section_id_mask() &&
+ std::fabs(lhs.score() - rhs.score()) < 1e-6;
+ }
+};
+
// Used to match a ScoredDocumentHit
MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") {
- if (arg.document_id() != expected_scored_document_hit.document_id() ||
- arg.hit_section_id_mask() !=
- expected_scored_document_hit.hit_section_id_mask() ||
- std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) {
+ ScoredDocumentHitEqualComparator equal_comparator;
+ if (!equal_comparator(arg, expected_scored_document_hit)) {
+ ScoredDocumentHitFormatter formatter;
+ *result_listener << "Expected: " << formatter(expected_scored_document_hit)
+ << ". Actual: " << formatter(arg);
+ return false;
+ }
+ return true;
+}
+
+// Used to match a JoinedScoredDocumentHit
+MATCHER_P(EqualsJoinedScoredDocumentHit, expected_joined_scored_document_hit,
+ "") {
+ ScoredDocumentHitEqualComparator equal_comparator;
+ if (std::fabs(arg.final_score() -
+ expected_joined_scored_document_hit.final_score()) > 1e-6 ||
+ !equal_comparator(
+ arg.parent_scored_document_hit(),
+ expected_joined_scored_document_hit.parent_scored_document_hit()) ||
+ arg.child_scored_document_hits().size() !=
+ expected_joined_scored_document_hit.child_scored_document_hits()
+ .size() ||
+ !std::equal(
+ arg.child_scored_document_hits().cbegin(),
+ arg.child_scored_document_hits().cend(),
+ expected_joined_scored_document_hit.child_scored_document_hits()
+ .cbegin(),
+ equal_comparator)) {
+ ScoredDocumentHitFormatter formatter;
+
*result_listener << IcingStringUtil::StringPrintf(
- "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f. Actual: "
- "document_id=%d, hit_section_id_mask=%d, score=%.2f",
- expected_scored_document_hit.document_id(),
- expected_scored_document_hit.hit_section_id_mask(),
- expected_scored_document_hit.score(), arg.document_id(),
- arg.hit_section_id_mask(), arg.score());
+ "Expected: final_score=%.2f, parent_scored_document_hit=%s, "
+ "child_scored_document_hits=[%s]. Actual: final_score=%.2f, "
+ "parent_scored_document_hit=%s, child_scored_document_hits=[%s]",
+ expected_joined_scored_document_hit.final_score(),
+ formatter(
+ expected_joined_scored_document_hit.parent_scored_document_hit())
+ .c_str(),
+ absl_ports::StrJoin(
+ expected_joined_scored_document_hit.child_scored_document_hits(),
+ ",", formatter)
+ .c_str(),
+ arg.final_score(), formatter(arg.parent_scored_document_hit()).c_str(),
+ absl_ports::StrJoin(arg.child_scored_document_hits(), ",", formatter)
+ .c_str());
return false;
}
return true;
@@ -435,6 +493,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
actual_copy.clear_debug_info();
for (SearchResultProto::ResultProto& result :
*actual_copy.mutable_results()) {
+ // Joined results
+ for (SearchResultProto::ResultProto& joined_result :
+ *result.mutable_joined_results()) {
+ joined_result.clear_score();
+ }
result.clear_score();
}
@@ -443,6 +506,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") {
expected_copy.clear_debug_info();
for (SearchResultProto::ResultProto& result :
*expected_copy.mutable_results()) {
+ // Joined results
+ for (SearchResultProto::ResultProto& joined_result :
+ *result.mutable_joined_results()) {
+ joined_result.clear_score();
+ }
result.clear_score();
}
return ExplainMatchResult(portable_equals_proto::EqualsProto(expected_copy),