summary refs log tree commit diff
path: root/native/annotator/duration/duration_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'native/annotator/duration/duration_test.cc')
-rw-r--r-- native/annotator/duration/duration_test.cc 174
1 files changed, 125 insertions, 49 deletions
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 7c07a72..f726058 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -16,6 +16,7 @@
#include "annotator/duration/duration.h"
+#include <cstddef>
#include <string>
#include <vector>
@@ -37,41 +38,61 @@ using testing::ElementsAre;
using testing::Field;
using testing::IsEmpty;
-const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
- static const flatbuffers::DetachedBuffer* options_data = []() {
- DurationAnnotatorOptionsT options;
- options.enabled = true;
+namespace {
+const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {  // Builds a packed DurationAnnotatorOptions buffer for the given enabled_modes.
+ DurationAnnotatorOptionsT options;
+ options.enabled = true;
+ options.enabled_modes = enabled_modes;  // The only field that depends on the argument.
- options.week_expressions.push_back("week");
- options.week_expressions.push_back("weeks");
+ options.week_expressions.push_back("week");  // Unit expression lists: week, day, hour, minute, second.
+ options.week_expressions.push_back("weeks");
- options.day_expressions.push_back("day");
- options.day_expressions.push_back("days");
+ options.day_expressions.push_back("day");
+ options.day_expressions.push_back("days");
- options.hour_expressions.push_back("hour");
- options.hour_expressions.push_back("hours");
+ options.hour_expressions.push_back("hour");
+ options.hour_expressions.push_back("hours");
- options.minute_expressions.push_back("minute");
- options.minute_expressions.push_back("minutes");
+ options.minute_expressions.push_back("minute");
+ options.minute_expressions.push_back("minutes");
- options.second_expressions.push_back("second");
- options.second_expressions.push_back("seconds");
+ options.second_expressions.push_back("second");
+ options.second_expressions.push_back("seconds");
- options.filler_expressions.push_back("and");
- options.filler_expressions.push_back("a");
- options.filler_expressions.push_back("an");
- options.filler_expressions.push_back("one");
+ options.filler_expressions.push_back("and");  // Filler expressions.
+ options.filler_expressions.push_back("a");
+ options.filler_expressions.push_back("an");
+ options.filler_expressions.push_back("one");
- options.half_expressions.push_back("half");
+ options.half_expressions.push_back("half");
- options.sub_token_separator_codepoints.push_back('-');
+ options.sub_token_separator_codepoints.push_back('-');
- flatbuffers::FlatBufferBuilder builder;
- builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
- return new flatbuffers::DetachedBuffer(builder.Release());
- }();
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));  // Serialize the object-API struct into the builder.
+ return new flatbuffers::DetachedBuffer(builder.Release());  // Heap-allocated; this helper never frees it.
+}
+} // namespace
- return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
+const DurationAnnotatorOptions* TestingDurationAnnotatorOptions(
+ ModeFlag enabled_modes) {  // Returns the root of the test options buffer matching enabled_modes.
+ static const flatbuffers::DetachedBuffer* options_data_all =  // Function-local statics: all three buffers are built on the first call and never deleted here.
+ CreateOptionsData(ModeFlag_ALL);
+ static const flatbuffers::DetachedBuffer* options_data_selection =
+ CreateOptionsData(ModeFlag_SELECTION);
+ static const flatbuffers::DetachedBuffer* options_data_no_selection =
+ CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
+
+ if (enabled_modes == ModeFlag_SELECTION) {  // Exact-match dispatch on the requested mode.
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_selection->data());
+ } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_no_selection->data());
+ } else {  // Every other value, including ModeFlag_ALL, gets the ALL buffer.
+ return flatbuffers::GetRoot<DurationAnnotatorOptions>(
+ options_data_all->data());
+ }
}
std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
@@ -103,10 +124,10 @@ std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
class DurationAnnotatorTest : public ::testing::Test {
protected:
- DurationAnnotatorTest()
+ explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
- duration_annotator_(TestingDurationAnnotatorOptions(),
+ duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes),
feature_processor_.get(), &unilib_) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
@@ -118,6 +139,19 @@ class DurationAnnotatorTest : public ::testing::Test {
DurationAnnotator duration_annotator_;
};
+class DurationAnnotatorForAnnotationAndClassificationTest  // Fixture whose annotator is configured with ModeFlag_ANNOTATION_AND_CLASSIFICATION.
+ : public DurationAnnotatorTest {
+ protected:
+ DurationAnnotatorForAnnotationAndClassificationTest()
+ : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}  // Forwards the mode to the base fixture constructor.
+};
+
+class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest {  // Fixture whose annotator is configured with ModeFlag_SELECTION only.
+ protected:
+ DurationAnnotatorForSelectionTest()
+ : DurationAnnotatorTest(ModeFlag_SELECTION) {}  // Forwards the mode to the base fixture constructor.
+};
+
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -129,6 +163,14 @@ TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
Field(&ClassificationResult::duration_ms, 15 * 60 * 1000)));
}
+TEST_F(DurationAnnotatorForSelectionTest,
+ ClassifyTextDisabledClassificationReturnsFalse) {  // Fixture enables only ModeFlag_SELECTION, so ClassifyText is expected to return false.
+ ClassificationResult classification;
+ EXPECT_FALSE(duration_annotator_.ClassifyText(
+ UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},  // Same text and span as ClassifiesSimpleDuration's neighbour test inputs — NOTE(review): confirm against the full file.
+ AnnotationUsecase_ANNOTATION_USECASE_RAW, &classification));
+}
+
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
ClassificationResult classification;
EXPECT_TRUE(duration_annotator_.ClassifyText(
@@ -152,7 +194,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -165,13 +208,26 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
15 * 60 * 1000)))))));
}
+TEST_F(DurationAnnotatorForAnnotationAndClassificationTest,
+ FindsAllDisabledModeReturnsNoResults) {  // Fixture uses ModeFlag_ANNOTATION_AND_CLASSIFICATION; FindAll is then called with ModeFlag_SELECTION.
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(  // The call itself is expected to succeed (return true).
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
+
+ EXPECT_THAT(result, IsEmpty());  // ...but no spans are expected for the requested mode.
+}
+
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
const UnicodeText text =
UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -190,7 +246,8 @@ TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -209,7 +266,8 @@ TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -228,7 +286,8 @@ TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -247,7 +306,8 @@ TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,
@@ -266,7 +326,8 @@ TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -286,7 +347,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -305,7 +367,8 @@ TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -323,7 +386,8 @@ TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
ASSERT_EQ(result.size(), 0);
}
@@ -334,7 +398,8 @@ TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -352,7 +417,8 @@ TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -383,7 +449,8 @@ TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -402,7 +469,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -422,7 +490,8 @@ TEST_F(DurationAnnotatorTest,
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -440,7 +509,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -458,7 +528,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -475,7 +546,8 @@ TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -540,7 +612,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -558,7 +631,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(
result,
@@ -576,7 +650,8 @@ TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_ANNOTATION, &result));
EXPECT_THAT(result, IsEmpty());
}
@@ -586,7 +661,8 @@ TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
std::vector<Token> tokens = Tokenize(text);
std::vector<AnnotatedSpan> result;
EXPECT_TRUE(duration_annotator_.FindAll(
- text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
+ ModeFlag_SELECTION, &result));
EXPECT_THAT(
result,