summaryrefslogtreecommitdiff
path: root/native/annotator/datetime/grammar-parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'native/annotator/datetime/grammar-parser.cc')
-rw-r--r--native/annotator/datetime/grammar-parser.cc11
1 files changed, 9 insertions, 2 deletions
diff --git a/native/annotator/datetime/grammar-parser.cc b/native/annotator/datetime/grammar-parser.cc
index 6d51c19..c49a01a 100644
--- a/native/annotator/datetime/grammar-parser.cc
+++ b/native/annotator/datetime/grammar-parser.cc
@@ -20,6 +20,7 @@
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
+#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
@@ -33,11 +34,13 @@ namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
- const float target_classification_score, const float priority_score)
+ const float target_classification_score, const float priority_score,
+ ModeFlag enabled_modes)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
- priority_score_(priority_score) {}
+ priority_score_(priority_score),
+ enabled_modes_(enabled_modes) {}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
@@ -54,6 +57,10 @@ StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
+ if (!(enabled_modes_ & mode)) {
+ return std::vector<DatetimeParseResultSpan>();
+ }
+
std::vector<DatetimeParseResultSpan> results;
UnsafeArena arena(/*block_size=*/16 << 10);
std::vector<Locale> locales = locale_list.GetLocales();