summaryrefslogtreecommitdiff
path: root/native/annotator/datetime
diff options
context:
space:
mode:
Diffstat (limited to 'native/annotator/datetime')
-rw-r--r--native/annotator/datetime/grammar-parser.cc11
-rw-r--r--native/annotator/datetime/grammar-parser.h5
-rw-r--r--native/annotator/datetime/grammar-parser_test.cc26
3 files changed, 32 insertions, 10 deletions
diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc
index 6d51c19..c49a01a 100644
--- a/native/annotator/datetime/grammar-parser.cc
+++ b/native/annotator/datetime/grammar-parser.cc
@@ -20,6 +20,7 @@
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
@@ -33,11 +34,13 @@ namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
- const float target_classification_score, const float priority_score)
+ const float target_classification_score, const float priority_score,
+ ModeFlag enabled_modes)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
- priority_score_(priority_score) {}
+ priority_score_(priority_score),
+ enabled_modes_(enabled_modes) {}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
@@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
+ if (!(enabled_modes_ & mode)) {
+ return std::vector<DatetimeParseResultSpan>();
+ }
+
std::vector<DatetimeParseResultSpan> results;
UnsafeArena arena(/*block_size=*/16 << 10);
std::vector<Locale> locales = locale_list.GetLocales();
diff --git a/native/annotator/datetime/grammar-parser.h b/native/annotator/datetime/grammar-parser.h
index 6ff4b46..35da843 100644
--- a/native/annotator/datetime/grammar-parser.h
+++ b/native/annotator/datetime/grammar-parser.h
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/base/statusor.h"
#include "utils/grammar/analyzer.h"
@@ -37,7 +38,8 @@ class GrammarDatetimeParser : public DatetimeParser {
explicit GrammarDatetimeParser(const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score,
- const float priority_score);
+ const float priority_score,
+ ModeFlag enabled_modes);
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
@@ -61,6 +63,7 @@ class GrammarDatetimeParser : public DatetimeParser {
const DatetimeGrounder& datetime_grounder_;
const float target_classification_score_;
const float priority_score_;
+ const ModeFlag enabled_modes_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/grammar-parser_test.cc b/native/annotator/datetime/grammar-parser_test.cc
index cf2dffd..b8a270d 100644
--- a/native/annotator/datetime/grammar-parser_test.cc
+++ b/native/annotator/datetime/grammar-parser_test.cc
@@ -22,6 +22,7 @@
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/datetime/testing/base-parser-test.h"
#include "annotator/datetime/testing/datetime-component-builder.h"
+#include "annotator/model_generated.h"
#include "utils/grammar/analyzer.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
@@ -42,7 +43,15 @@ std::string ReadFile(const std::string& file_name) {
class GrammarDatetimeParserTest : public DateTimeParserTest {
public:
- void SetUp() override {
+ void SetUp() override { ResetParser(ModeFlag_ALL); }
+
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return parser_.get();
+ }
+
+ protected:
+ void ResetParser(ModeFlag enabled_modes) {
grammar_buffer_ = ReadFile(GetModelPath() + "datetime.fb");
unilib_ = CreateUniLibForTesting();
calendarlib_ = CreateCalendarLibForTesting();
@@ -51,12 +60,8 @@ class GrammarDatetimeParserTest : public DateTimeParserTest {
datetime_grounder_ = std::make_unique<DatetimeGrounder>(calendarlib_.get());
parser_.reset(new GrammarDatetimeParser(*analyzer_, *datetime_grounder_,
/*target_classification_score=*/1.0,
- /*priority_score=*/1.0));
- }
-
- // Exposes the date time parser for tests and evaluations.
- const DatetimeParser* DatetimeParserForTests() const override {
- return parser_.get();
+ /*priority_score=*/1.0,
+ enabled_modes));
}
private:
@@ -486,6 +491,13 @@ TEST_F(GrammarDatetimeParserTest, Parse) {
.Build()}));
}
+TEST_F(GrammarDatetimeParserTest, NotEnabledModeHasNoResult) {
+ ResetParser(ModeFlag_SELECTION);
+ // `DateTimeParserTest` implementation parses the input under the ANNOTATION
+ // mode.
+ EXPECT_TRUE(HasNoResult("{January 1, 1988}"));
+}
+
TEST_F(GrammarDatetimeParserTest, DateValidation) {
EXPECT_TRUE(ParsesCorrectly(
"{01/02/2020}", 1577919600000, GRANULARITY_DAY,