1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/datetime/grammar-parser.h"
#include <set>
#include <unordered_set>
#include <utility>
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
#include "utils/grammar/parsing/derivation.h"
using ::libtextclassifier3::grammar::EvaluatedDerivation;
using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score, const float priority_score,
ModeFlag enabled_modes)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
priority_score_(priority_score),
enabled_modes_(enabled_modes) {}
// Convenience overload for UTF-8 encoded std::string input.
//
// The bytes of `input` are wrapped as a UnicodeText without being copied
// (do_copy=false). All remaining arguments are forwarded unchanged to the
// UnicodeText overload, which does the actual parsing.
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
    const std::string& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const LocaleList& locale_list,
    ModeFlag mode, AnnotationUsecase annotation_usecase,
    bool anchor_start_end) const {
  // Bound to a const reference so the temporary lives for the whole call and
  // is never copied.
  const UnicodeText& unicode_input =
      UTF8ToUnicodeText(input, /*do_copy=*/false);
  return Parse(unicode_input, reference_time_ms_utc, reference_timezone,
               locale_list, mode, annotation_usecase, anchor_start_end);
}
// Runs the datetime grammar over `input` and grounds every surviving match.
//
// Pipeline:
//   1. Parse `input` with the grammar analyzer (deduplication disabled).
//   2. Keep derivations whose value is an UngroundedDatetime table accepted by
//      DatetimeGrounder::IsValidUngroundedDatetime.
//   3. Deduplicate the remaining derivations.
//   4. Skip derivations whose `annotation_usecases` bitmask lacks
//      `annotation_usecase`. Ground the rest against the reference time,
//      timezone and locale, and emit one DatetimeParseResultSpan per
//      derivation.
//
// Returns an empty vector when `mode` is not one of the enabled modes. Errors
// from the analyzer or the grounder are returned to the caller unchanged.
// NOTE(review): `anchor_start_end` is accepted for interface compatibility but
// is not consulted here. Callers needing an exact span match must check the
// returned spans themselves.
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
    const UnicodeText& input, const int64 reference_time_ms_utc,
    const std::string& reference_timezone, const LocaleList& locale_list,
    ModeFlag mode, AnnotationUsecase annotation_usecase,
    bool anchor_start_end) const {
  if (!(enabled_modes_ & mode)) {
    return std::vector<DatetimeParseResultSpan>();
  }
  UnsafeArena arena(/*block_size=*/16 << 10);
  std::vector<Locale> locales = locale_list.GetLocales();
  // If the locale list is empty then datetime regex expression will still
  // execute but in grammar based parser the rules are associated with locale
  // and engine will not run if the locale list is empty. In an unlikely
  // scenario when locale is not mentioned fallback to en-*.
  if (locales.empty()) {
    locales.emplace_back(Locale::FromBCP47("en"));
  }
  TC3_ASSIGN_OR_RETURN(
      const std::vector<EvaluatedDerivation> evaluated_derivations,
      analyzer_.Parse(input, locales, &arena,
                      /*deduplicate_derivations=*/false));

  // Returns the UngroundedDatetime carried by `derivation`, or nullptr when
  // the derivation has no flatbuffer table value.
  const auto ungrounded_datetime_of =
      [](const EvaluatedDerivation& derivation) -> const UngroundedDatetime* {
    if (!derivation.value || !derivation.value->Has<flatbuffers::Table>()) {
      return nullptr;
    }
    return derivation.value->Table<UngroundedDatetime>();
  };

  // Validity is checked before deduplication, mirroring the original order.
  std::vector<EvaluatedDerivation> valid_evaluated_derivations;
  for (const EvaluatedDerivation& evaluated_derivation :
       evaluated_derivations) {
    const UngroundedDatetime* ungrounded_datetime =
        ungrounded_datetime_of(evaluated_derivation);
    if (ungrounded_datetime != nullptr &&
        datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
      valid_evaluated_derivations.push_back(evaluated_derivation);
    }
  }
  valid_evaluated_derivations =
      grammar::DeduplicateDerivations(valid_evaluated_derivations);

  // Loop-invariant inputs, computed once.
  const auto& reference_locale = locale_list.GetReferenceLocale();
  const auto usecase_mask = 1 << annotation_usecase;

  std::vector<DatetimeParseResultSpan> results;
  results.reserve(valid_evaluated_derivations.size());
  for (const EvaluatedDerivation& evaluated_derivation :
       valid_evaluated_derivations) {
    // Defensive null check: deduplication is expected to return a subset of
    // the already-validated derivations, but a null table must never be
    // dereferenced below.
    const UngroundedDatetime* ungrounded_datetime =
        ungrounded_datetime_of(evaluated_derivation);
    if (ungrounded_datetime == nullptr ||
        (ungrounded_datetime->annotation_usecases() & usecase_mask) == 0) {
      continue;
    }
    // Pass the grounder's result straight into the macro, as is done for
    // analyzer_.Parse above. Binding it to a named StatusOr first presumably
    // makes the macro copy the whole result vector.
    TC3_ASSIGN_OR_RETURN(
        const std::vector<DatetimeParseResult>& parse_datetime,
        datetime_grounder_.Ground(reference_time_ms_utc, reference_timezone,
                                  reference_locale, ungrounded_datetime));
    DatetimeParseResultSpan datetime_parse_result_span;
    datetime_parse_result_span.span =
        evaluated_derivation.parse_tree->codepoint_span;
    datetime_parse_result_span.data.assign(parse_datetime.begin(),
                                           parse_datetime.end());
    datetime_parse_result_span.target_classification_score =
        target_classification_score_;
    datetime_parse_result_span.priority_score = priority_score_;
    // The span is not used again, so move it instead of copying its data.
    results.push_back(std::move(datetime_parse_result_span));
  }
  return results;
}
} // namespace libtextclassifier3
|