diff options
Diffstat (limited to 'native/annotator/duration/duration_test.cc')
-rw-r--r-- | native/annotator/duration/duration_test.cc | 174 |
1 files changed, 125 insertions, 49 deletions
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc index 7c07a72..f726058 100644 --- a/native/annotator/duration/duration_test.cc +++ b/native/annotator/duration/duration_test.cc @@ -16,6 +16,7 @@ #include "annotator/duration/duration.h" +#include <cstddef> #include <string> #include <vector> @@ -37,41 +38,61 @@ using testing::ElementsAre; using testing::Field; using testing::IsEmpty; -const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() { - static const flatbuffers::DetachedBuffer* options_data = []() { - DurationAnnotatorOptionsT options; - options.enabled = true; +namespace { +const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) { + DurationAnnotatorOptionsT options; + options.enabled = true; + options.enabled_modes = enabled_modes; - options.week_expressions.push_back("week"); - options.week_expressions.push_back("weeks"); + options.week_expressions.push_back("week"); + options.week_expressions.push_back("weeks"); - options.day_expressions.push_back("day"); - options.day_expressions.push_back("days"); + options.day_expressions.push_back("day"); + options.day_expressions.push_back("days"); - options.hour_expressions.push_back("hour"); - options.hour_expressions.push_back("hours"); + options.hour_expressions.push_back("hour"); + options.hour_expressions.push_back("hours"); - options.minute_expressions.push_back("minute"); - options.minute_expressions.push_back("minutes"); + options.minute_expressions.push_back("minute"); + options.minute_expressions.push_back("minutes"); - options.second_expressions.push_back("second"); - options.second_expressions.push_back("seconds"); + options.second_expressions.push_back("second"); + options.second_expressions.push_back("seconds"); - options.filler_expressions.push_back("and"); - options.filler_expressions.push_back("a"); - options.filler_expressions.push_back("an"); - options.filler_expressions.push_back("one"); + options.filler_expressions.push_back("and"); + options.filler_expressions.push_back("a"); + options.filler_expressions.push_back("an"); + options.filler_expressions.push_back("one"); - options.half_expressions.push_back("half"); + options.half_expressions.push_back("half"); - options.sub_token_separator_codepoints.push_back('-'); + options.sub_token_separator_codepoints.push_back('-'); - flatbuffers::FlatBufferBuilder builder; - builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); - return new flatbuffers::DetachedBuffer(builder.Release()); - }(); + flatbuffers::FlatBufferBuilder builder; + builder.Finish(DurationAnnotatorOptions::Pack(builder, &options)); + return new flatbuffers::DetachedBuffer(builder.Release()); +} +} // namespace - return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data()); +const DurationAnnotatorOptions* TestingDurationAnnotatorOptions( + ModeFlag enabled_modes) { + static const flatbuffers::DetachedBuffer* options_data_all = + CreateOptionsData(ModeFlag_ALL); + static const flatbuffers::DetachedBuffer* options_data_selection = + CreateOptionsData(ModeFlag_SELECTION); + static const flatbuffers::DetachedBuffer* options_data_no_selection = + CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION); + + if (enabled_modes == ModeFlag_SELECTION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_selection->data()); + } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_no_selection->data()); + } else { + return flatbuffers::GetRoot<DurationAnnotatorOptions>( + options_data_all->data()); + } } std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { @@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) { class DurationAnnotatorTest : public ::testing::Test { protected: - DurationAnnotatorTest() + explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL) : INIT_UNILIB_FOR_TESTING(unilib_), feature_processor_(BuildFeatureProcessor(&unilib_)), - duration_annotator_(TestingDurationAnnotatorOptions(), + duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes), feature_processor_.get(), &unilib_) {} std::vector<Token> Tokenize(const UnicodeText& text) { @@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test { DurationAnnotator duration_annotator_; }; +class DurationAnnotatorForAnnotationAndClassificationTest + : public DurationAnnotatorTest { + protected: + DurationAnnotatorForAnnotationAndClassificationTest() + : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {} +}; + +class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest { + protected: + DurationAnnotatorForSelectionTest() + : DurationAnnotatorTest(ModeFlag_SELECTION) {} +}; + TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) { Field(&ClassificationResult::duration_ms, 15 * 60 * 1000))); } +TEST_F(DurationAnnotatorForSelectionTest, + ClassifyTextDisabledClassificationReturnsFalse) { + ClassificationResult classification; + EXPECT_FALSE(duration_annotator_.ClassifyText( + UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24}, + AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification)); +} + TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) { ClassificationResult classification; EXPECT_TRUE(duration_annotator_.ClassifyText( @@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) { 15 * 60 * 1000))))))); } +TEST_F(DurationAnnotatorForAnnotationAndClassificationTest, + FindsAllDisabledModeReturnsNoResults) { + const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?"); + std::vector<Token> tokens = Tokenize(text); + std::vector<AnnotatedSpan> result; + EXPECT_TRUE(duration_annotator_.FindAll( + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); + + EXPECT_THAT(result, IsEmpty()); +} + TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) { const UnicodeText text = UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?"); std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, @@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); ASSERT_EQ(result.size(), 0); } @@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest, std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT( result, @@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_ANNOTATION, &result)); EXPECT_THAT(result, IsEmpty()); } @@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) { std::vector<Token> tokens = Tokenize(text); std::vector<AnnotatedSpan> result; EXPECT_TRUE(duration_annotator_.FindAll( - text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result)); + text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, + ModeFlag_SELECTION, &result)); EXPECT_THAT( result, |