diff options
Diffstat (limited to 'icing/testing/common-matchers.h')
-rw-r--r-- | icing/testing/common-matchers.h | 88 |
1 files changed, 78 insertions, 10 deletions
diff --git a/icing/testing/common-matchers.h b/icing/testing/common-matchers.h index f2738e3..e090800 100644 --- a/icing/testing/common-matchers.h +++ b/icing/testing/common-matchers.h @@ -15,7 +15,10 @@ #ifndef ICING_TESTING_COMMON_MATCHERS_H_ #define ICING_TESTING_COMMON_MATCHERS_H_ +#include <algorithm> #include <cmath> +#include <string> +#include <vector> #include "icing/text_classifier/lib3/utils/base/status.h" #include "icing/text_classifier/lib3/utils/base/status_macros.h" @@ -31,6 +34,7 @@ #include "icing/proto/status.pb.h" #include "icing/schema/schema-store.h" #include "icing/schema/section.h" +#include "icing/scoring/scored-document-hit.h" #include "icing/util/status-macros.h" namespace icing { @@ -104,19 +108,73 @@ MATCHER_P2(EqualsDocHitInfoWithTermFrequency, document_id, term_frequency_as_expected; } +class ScoredDocumentHitFormatter { + public: + std::string operator()(const ScoredDocumentHit& scored_document_hit) { + return IcingStringUtil::StringPrintf( + "(document_id=%d, hit_section_id_mask=%" PRId64 ", score=%.2f)", + scored_document_hit.document_id(), + scored_document_hit.hit_section_id_mask(), scored_document_hit.score()); + } +}; + +class ScoredDocumentHitEqualComparator { + public: + bool operator()(const ScoredDocumentHit& lhs, + const ScoredDocumentHit& rhs) const { + return lhs.document_id() == rhs.document_id() && + lhs.hit_section_id_mask() == rhs.hit_section_id_mask() && + std::fabs(lhs.score() - rhs.score()) < 1e-6; + } +}; + // Used to match a ScoredDocumentHit MATCHER_P(EqualsScoredDocumentHit, expected_scored_document_hit, "") { - if (arg.document_id() != expected_scored_document_hit.document_id() || - arg.hit_section_id_mask() != - expected_scored_document_hit.hit_section_id_mask() || - std::fabs(arg.score() - expected_scored_document_hit.score()) > 1e-6) { + ScoredDocumentHitEqualComparator equal_comparator; + if (!equal_comparator(arg, expected_scored_document_hit)) { + ScoredDocumentHitFormatter formatter; + *result_listener << "Expected: " << formatter(expected_scored_document_hit) + << ". Actual: " << formatter(arg); + return false; + } + return true; +} + +// Used to match a JoinedScoredDocumentHit +MATCHER_P(EqualsJoinedScoredDocumentHit, expected_joined_scored_document_hit, + "") { + ScoredDocumentHitEqualComparator equal_comparator; + if (std::fabs(arg.final_score() - + expected_joined_scored_document_hit.final_score()) > 1e-6 || + !equal_comparator( + arg.parent_scored_document_hit(), + expected_joined_scored_document_hit.parent_scored_document_hit()) || + arg.child_scored_document_hits().size() != + expected_joined_scored_document_hit.child_scored_document_hits() + .size() || + !std::equal( + arg.child_scored_document_hits().cbegin(), + arg.child_scored_document_hits().cend(), + expected_joined_scored_document_hit.child_scored_document_hits() + .cbegin(), + equal_comparator)) { + ScoredDocumentHitFormatter formatter; + *result_listener << IcingStringUtil::StringPrintf( - "Expected: document_id=%d, hit_section_id_mask=%d, score=%.2f. Actual: " - "document_id=%d, hit_section_id_mask=%d, score=%.2f", - expected_scored_document_hit.document_id(), - expected_scored_document_hit.hit_section_id_mask(), - expected_scored_document_hit.score(), arg.document_id(), - arg.hit_section_id_mask(), arg.score()); + "Expected: final_score=%.2f, parent_scored_document_hit=%s, " + "child_scored_document_hits=[%s]. Actual: final_score=%.2f, " + "parent_scored_document_hit=%s, child_scored_document_hits=[%s]", + expected_joined_scored_document_hit.final_score(), + formatter( + expected_joined_scored_document_hit.parent_scored_document_hit()) + .c_str(), + absl_ports::StrJoin( + expected_joined_scored_document_hit.child_scored_document_hits(), + ",", formatter) + .c_str(), + arg.final_score(), formatter(arg.parent_scored_document_hit()).c_str(), + absl_ports::StrJoin(arg.child_scored_document_hits(), ",", formatter) + .c_str()); return false; } return true; @@ -435,6 +493,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") { actual_copy.clear_debug_info(); for (SearchResultProto::ResultProto& result : *actual_copy.mutable_results()) { + // Joined results + for (SearchResultProto::ResultProto& joined_result : + *result.mutable_joined_results()) { + joined_result.clear_score(); + } result.clear_score(); } @@ -443,6 +506,11 @@ MATCHER_P(EqualsSearchResultIgnoreStatsAndScores, expected, "") { expected_copy.clear_debug_info(); for (SearchResultProto::ResultProto& result : *expected_copy.mutable_results()) { + // Joined results + for (SearchResultProto::ResultProto& joined_result : + *result.mutable_joined_results()) { + joined_result.clear_score(); + } result.clear_score(); } return ExplainMatchResult(portable_equals_proto::EqualsProto(expected_copy), |