summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorandroid-build-team Robot <android-build-team-robot@google.com>2020-05-29 01:06:48 +0000
committerandroid-build-team Robot <android-build-team-robot@google.com>2020-05-29 01:06:48 +0000
commit0d4718caea1b5825131379266caecb35b0e1817d (patch)
tree923ac7a59ecb6513948acb04c55e982dc6f4da47
parent02ac32ed54cc201c553c3411724eca962df1ffe8 (diff)
parentd1a33389176cea2f14bd877ea983b5569ef0348a (diff)
downloadlibtextclassifier-0d4718caea1b5825131379266caecb35b0e1817d.tar.gz
Snap for 6538275 from d1a33389176cea2f14bd877ea983b5569ef0348a to rvc-d1-release
Change-Id: I4c7c012ed6fa18310d333ed6e5b01ad9f732791d
-rw-r--r--java/tests/instrumentation/AndroidTest.xml33
-rw-r--r--native/actions/actions-suggestions.cc75
-rw-r--r--native/annotator/annotator.cc27
-rw-r--r--native/annotator/datetime/extractor.cc1
-rw-r--r--native/annotator/experimental/experimental-dummy.h3
-rwxr-xr-xnative/annotator/experimental/experimental.fbs16
-rw-r--r--native/annotator/grammar/dates/utils/date-match.cc28
-rw-r--r--native/utils/calendar/calendar-common.h2
-rw-r--r--native/utils/flatbuffers.cc213
-rw-r--r--native/utils/flatbuffers.h313
-rw-r--r--native/utils/grammar/utils/rules.cc7
-rw-r--r--native/utils/grammar/utils/rules.h4
-rw-r--r--native/utils/intents/jni.cc30
-rw-r--r--native/utils/lua-utils.cc49
-rw-r--r--native/utils/lua-utils.h6
-rw-r--r--native/utils/variant.cc12
-rw-r--r--native/utils/variant.h158
-rw-r--r--notification/tests/AndroidTest.xml33
18 files changed, 672 insertions, 338 deletions
diff --git a/java/tests/instrumentation/AndroidTest.xml b/java/tests/instrumentation/AndroidTest.xml
new file mode 100644
index 0000000..e02a338
--- /dev/null
+++ b/java/tests/instrumentation/AndroidTest.xml
@@ -0,0 +1,33 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<!-- This test config file is auto-generated. -->
+<configuration description="Runs TextClassifierServiceTest.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TextClassifierServiceTest.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.tests" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+
+ <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+ <option name="mainline-module-package-name" value="com.google.android.extservices" />
+ </object>
+</configuration>
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index a84f2cd..1fcd35c 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -22,6 +22,7 @@
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
+#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
@@ -50,6 +51,11 @@ const std::string& ActionsSuggestions::kSendEmailType =
*[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
*[]() { return new std::string("share_location"); }();
+
+// Name for a datetime annotation that only includes time but no date.
+const std::string& kTimeAnnotation =
+ *[]() { return new std::string("time"); }();
+
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
@@ -260,6 +266,7 @@ bool ActionsSuggestions::ValidateAndInitialize() {
}
}
+ // Gather annotation entities for the rules.
if (model_->annotation_actions_spec() != nullptr &&
model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
@@ -300,6 +307,18 @@ bool ActionsSuggestions::ValidateAndInitialize() {
grammar_actions_.reset(new GrammarActions(
unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
model_->smart_reply_action_type()->str()));
+
+ // Gather annotation entities for the grammars.
+ if (auto annotation_nt = model_->rules()
+ ->grammar_rules()
+ ->rules()
+ ->nonterminals()
+ ->annotation_nt()) {
+ for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
+ *annotation_nt) {
+ annotation_entity_types_.insert(entry->key()->str());
+ }
+ }
}
std::string actions_script;
@@ -689,47 +708,41 @@ bool ActionsSuggestions::SetupModelInput(
interpreter->tensor(interpreter->inputs()[param_index])->type;
const auto param_value_it = model_parameters.find(param_name);
const bool has_value = param_value_it != model_parameters.end();
- /*
- case kTfLiteInt16:
- *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
- break;
- case kTfLiteInt8:
- */
switch (param_type) {
case kTfLiteFloat32:
model_executor_->SetInput<float>(
param_index,
- has_value ? param_value_it->second.FloatValue() : kDefaultFloat,
+ has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
interpreter);
break;
case kTfLiteInt32:
model_executor_->SetInput<int32_t>(
param_index,
- has_value ? param_value_it->second.IntValue() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int>() : kDefaultInt,
interpreter);
break;
case kTfLiteInt64:
model_executor_->SetInput<int64_t>(
param_index,
- has_value ? param_value_it->second.Int64Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
interpreter);
break;
case kTfLiteUInt8:
model_executor_->SetInput<uint8_t>(
param_index,
- has_value ? param_value_it->second.UInt8Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
interpreter);
break;
case kTfLiteInt8:
model_executor_->SetInput<int8_t>(
param_index,
- has_value ? param_value_it->second.Int8Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
interpreter);
break;
case kTfLiteBool:
model_executor_->SetInput<bool>(
param_index,
- has_value ? param_value_it->second.BoolValue() : kDefaultBool,
+ has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
interpreter);
break;
default:
@@ -1023,6 +1036,30 @@ Conversation ActionsSuggestions::AnnotateConversation(
if (message->annotations.empty()) {
message->annotations = annotator->Annotate(
message->text, AnnotationOptionsForMessage(*message));
+ for (int i = 0; i < message->annotations.size(); i++) {
+ ClassificationResult* classification =
+ &message->annotations[i].classification.front();
+
+ // Specialize datetime annotation to time annotation if no date
+ // component is present.
+ if (classification->collection == Collections::DateTime() &&
+ classification->datetime_parse_result.IsSet()) {
+ bool has_only_time = true;
+ for (const DatetimeComponent& component :
+ classification->datetime_parse_result.datetime_components) {
+ if (component.component_type !=
+ DatetimeComponent::ComponentType::UNSPECIFIED &&
+ component.component_type <
+ DatetimeComponent::ComponentType::HOUR) {
+ has_only_time = false;
+ break;
+ }
+ }
+ if (has_only_time) {
+ classification->collection = kTimeAnnotation;
+ }
+ }
+ }
}
}
return annotated_conversation;
@@ -1224,6 +1261,13 @@ bool ActionsSuggestions::GatherActionsSuggestions(
SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
+ if (grammar_actions_ != nullptr &&
+ !grammar_actions_->SuggestActions(annotated_conversation,
+ &response->actions)) {
+ TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
+ return false;
+ }
+
int input_text_length = 0;
int num_matching_locales = 0;
for (int i = annotated_conversation.messages.size() - num_messages;
@@ -1299,13 +1343,6 @@ bool ActionsSuggestions::GatherActionsSuggestions(
return false;
}
- if (grammar_actions_ != nullptr &&
- !grammar_actions_->SuggestActions(annotated_conversation,
- &response->actions)) {
- TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
- return false;
- }
-
if (preconditions_.suppress_on_low_confidence_input &&
!regex_actions_->FilterConfidenceOutput(post_check_rules,
&response->actions)) {
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 4b11c7e..6ee983f 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -699,8 +699,8 @@ bool Annotator::InitializePersonNameEngineFromFileDescriptor(int fd, int offset,
bool Annotator::InitializeExperimentalAnnotators() {
if (ExperimentalAnnotator::IsEnabled()) {
- experimental_annotator_.reset(
- new ExperimentalAnnotator(*selection_feature_processor_, *unilib_));
+ experimental_annotator_.reset(new ExperimentalAnnotator(
+ model_->experimental_model(), *selection_feature_processor_, *unilib_));
return true;
}
return false;
@@ -2496,13 +2496,22 @@ bool Annotator::ParseAndFillInMoneyAmount(
LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
*serialized_entity_data);
if (data == nullptr) {
- TC3_LOG(ERROR)
- << "Data field is null when trying to parse Money Entity Data";
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models newer than
+ // v706, consequently logging errors only for them (b/156634162).
+ TC3_LOG(ERROR)
+ << "Data field is null when trying to parse Money Entity Data";
+ }
return false;
}
if (data->money->unnormalized_amount.empty()) {
- TC3_LOG(ERROR) << "Data unnormalized_amount is empty when trying to parse "
- "Money Entity Data";
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models newer than
+ // v706, consequently logging errors only for them (b/156634162).
+ TC3_LOG(ERROR)
+ << "Data unnormalized_amount is empty when trying to parse "
+ "Money Entity Data";
+ }
return false;
}
@@ -2593,7 +2602,11 @@ bool Annotator::RegexChunk(const UnicodeText& context_unicode,
if (regex_pattern.config->collection_name()->str() ==
Collections::Money()) {
if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
- TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
+ if (model_->version() >= 706) {
+ // This way of parsing money entity data is enabled for models
+ // newer than v706 => logging errors only for them (b/156634162).
+ TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
+ }
}
}
}
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index ebcf091..b8e1b7a 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -473,6 +473,7 @@ bool DatetimeExtractor::ParseRelationAndConvertToRelativeCount(
{DatetimeExtractorType_NEXT, 1},
{DatetimeExtractorType_NEXT_OR_SAME, 1},
{DatetimeExtractorType_LAST, -1},
+ {DatetimeExtractorType_PAST, -1},
},
relative_count);
}
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
index 0d50bca..389aae1 100644
--- a/native/annotator/experimental/experimental-dummy.h
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -33,7 +33,8 @@ class ExperimentalAnnotator {
// always disabled;
static constexpr bool IsEnabled() { return false; }
- explicit ExperimentalAnnotator(const FeatureProcessor& feature_processor,
+ explicit ExperimentalAnnotator(const ExperimentalModel* model,
+ const FeatureProcessor& feature_processor,
const UniLib& unilib) {}
bool Annotate(const UnicodeText& context,
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
index fff2d9e..6e15d04 100755
--- a/native/annotator/experimental/experimental.fbs
+++ b/native/annotator/experimental/experimental.fbs
@@ -1,3 +1,19 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
namespace libtextclassifier3;
table ExperimentalModel {
}
diff --git a/native/annotator/grammar/dates/utils/date-match.cc b/native/annotator/grammar/dates/utils/date-match.cc
index 1ab1e6a..d9fca52 100644
--- a/native/annotator/grammar/dates/utils/date-match.cc
+++ b/native/annotator/grammar/dates/utils/date-match.cc
@@ -225,6 +225,18 @@ DatetimeComponent::RelativeQualifier DateMatch::GetRelativeQualifier() const {
return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
}
+// Embed RelativeQualifier information of DatetimeComponent as a sign of
+// relative counter field of datetime component i.e. relative counter is
+// negative when relative qualifier RelativeQualifier::PAST.
+int GetAdjustedRelativeCounter(
+ const DatetimeComponent::RelativeQualifier& relative_qualifier,
+ const int relative_counter) {
+ if (DatetimeComponent::RelativeQualifier::PAST == relative_qualifier) {
+ return -relative_counter;
+ }
+ return relative_counter;
+}
+
Optional<DatetimeComponent> CreateDatetimeComponent(
const DatetimeComponent::ComponentType& component_type,
const DatetimeComponent::RelativeQualifier& relative_qualifier,
@@ -232,13 +244,15 @@ Optional<DatetimeComponent> CreateDatetimeComponent(
if (absolute_value == NO_VAL && relative_value == NO_VAL) {
return Optional<DatetimeComponent>();
}
- return Optional<DatetimeComponent>(
- DatetimeComponent(component_type,
- (relative_value != NO_VAL)
- ? relative_qualifier
- : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
- (absolute_value != NO_VAL) ? absolute_value : 0,
- (relative_value != NO_VAL) ? relative_value : 0));
+ return Optional<DatetimeComponent>(DatetimeComponent(
+ component_type,
+ (relative_value != NO_VAL)
+ ? relative_qualifier
+ : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
+ (absolute_value != NO_VAL) ? absolute_value : 0,
+ (relative_value != NO_VAL)
+ ? GetAdjustedRelativeCounter(relative_qualifier, relative_value)
+ : 0));
}
Optional<DatetimeComponent> CreateDayOfWeekComponent(
diff --git a/native/utils/calendar/calendar-common.h b/native/utils/calendar/calendar-common.h
index f842300..e6fd076 100644
--- a/native/utils/calendar/calendar-common.h
+++ b/native/utils/calendar/calendar-common.h
@@ -229,7 +229,7 @@ bool CalendarLibTempl<TCalendar>::ApplyRelationField(
case DatetimeComponent::RelativeQualifier::PAST:
TC3_CALENDAR_CHECK(
AdjustByRelation(relative_date_time_component,
- -relative_date_time_component.relative_count,
+ relative_date_time_component.relative_count,
/*allow_today=*/false, calendar))
return true;
case DatetimeComponent::RelativeQualifier::FUTURE:
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
index 1cf60a9..cf4c97f 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers.cc
@@ -24,49 +24,6 @@
namespace libtextclassifier3 {
namespace {
-bool CreateRepeatedField(const reflection::Schema* schema,
- const reflection::Type* type,
- std::unique_ptr<RepeatedField>* repeated_field) {
- switch (type->element()) {
- case reflection::Bool:
- repeated_field->reset(new TypedRepeatedField<bool>);
- return true;
- case reflection::Byte:
- repeated_field->reset(new TypedRepeatedField<char>);
- return true;
- case reflection::UByte:
- repeated_field->reset(new TypedRepeatedField<unsigned char>);
- return true;
- case reflection::Int:
- repeated_field->reset(new TypedRepeatedField<int>);
- return true;
- case reflection::UInt:
- repeated_field->reset(new TypedRepeatedField<uint>);
- return true;
- case reflection::Long:
- repeated_field->reset(new TypedRepeatedField<int64>);
- return true;
- case reflection::ULong:
- repeated_field->reset(new TypedRepeatedField<uint64>);
- return true;
- case reflection::Float:
- repeated_field->reset(new TypedRepeatedField<float>);
- return true;
- case reflection::Double:
- repeated_field->reset(new TypedRepeatedField<double>);
- return true;
- case reflection::String:
- repeated_field->reset(new TypedRepeatedField<std::string>);
- return true;
- case reflection::Obj:
- repeated_field->reset(
- new TypedRepeatedField<ReflectiveFlatbuffer>(schema, type));
- return true;
- default:
- TC3_LOG(ERROR) << "Unsupported type: " << type->element();
- return false;
- }
-}
// Gets the field information for a field name, returns nullptr if the
// field was not defined.
@@ -76,8 +33,8 @@ const reflection::Field* GetFieldOrNull(const reflection::Object* type,
return type->fields()->LookupByKey(field_name.data());
}
-const reflection::Field* GetFieldByOffsetOrNull(const reflection::Object* type,
- const int field_offset) {
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const int field_offset) {
if (type->fields() == nullptr) {
return nullptr;
}
@@ -97,14 +54,14 @@ const reflection::Field* GetFieldOrNull(const reflection::Object* type,
if (!field_name.empty()) {
return GetFieldOrNull(type, field_name.data());
}
- return GetFieldByOffsetOrNull(type, field_offset);
+ return GetFieldOrNull(type, field_offset);
}
const reflection::Field* GetFieldOrNull(const reflection::Object* type,
const FlatbufferField* field) {
TC3_CHECK(type != nullptr && field != nullptr);
if (field->field_name() == nullptr) {
- return GetFieldByOffsetOrNull(type, field->field_offset());
+ return GetFieldOrNull(type, field->field_offset());
}
return GetFieldOrNull(
type,
@@ -154,7 +111,7 @@ bool ParseAndSetField(const reflection::Field* field,
return false;
}
if (field->type()->base_type() == reflection::Vector) {
- buffer->Repeated<T>(field)->Add(value);
+ buffer->Repeated(field)->Add(value);
return true;
} else {
return buffer->Set<T>(field, value);
@@ -221,9 +178,9 @@ bool ReflectiveFlatbuffer::GetFieldWithParent(
return true;
}
-const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
const int field_offset) const {
- return libtextclassifier3::GetFieldByOffsetOrNull(type_, field_offset);
+ return libtextclassifier3::GetFieldOrNull(type_, field_offset);
}
bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
@@ -257,6 +214,27 @@ bool ReflectiveFlatbuffer::ParseAndSet(const FlatbufferFieldPath* path,
return parent->ParseAndSet(field, value);
}
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(StringPiece field_name) {
+ const reflection::Field* field = GetFieldOrNull(field_name);
+ if (field == nullptr) {
+ return nullptr;
+ }
+
+ if (field->type()->base_type() != reflection::BaseType::Vector) {
+ return nullptr;
+ }
+
+ return Add(field);
+}
+
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(
+ const reflection::Field* field) {
+ if (field == nullptr) {
+ return nullptr;
+ }
+ return Repeated(field)->Add();
+}
+
ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
const StringPiece field_name) {
if (const reflection::Field* field = GetFieldOrNull(field_name)) {
@@ -306,11 +284,8 @@ RepeatedField* ReflectiveFlatbuffer::Repeated(const reflection::Field* field) {
}
// Otherwise, create a new instance and store it.
- std::unique_ptr<RepeatedField> repeated_field;
- if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
- TC3_LOG(ERROR) << "Could not create repeated field.";
- return nullptr;
- }
+ std::unique_ptr<RepeatedField> repeated_field(
+ new RepeatedField(schema_, field));
const auto it = repeated_fields_.insert(
/*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
return it->second.get();
@@ -330,9 +305,10 @@ flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
// Create strings.
for (const auto& it : fields_) {
- if (it.second.HasString()) {
- offsets.push_back({it.first->offset(),
- builder->CreateString(it.second.StringValue()).o});
+ if (it.second.Has<std::string>()) {
+ offsets.push_back(
+ {it.first->offset(),
+ builder->CreateString(it.second.ConstRefValue<std::string>()).o});
}
}
@@ -349,44 +325,46 @@ flatbuffers::uoffset_t ReflectiveFlatbuffer::Serialize(
switch (it.second.GetType()) {
case Variant::TYPE_BOOL_VALUE:
builder->AddElement<uint8_t>(
- it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
+ it.first->offset(), static_cast<uint8_t>(it.second.Value<bool>()),
static_cast<uint8_t>(it.first->default_integer()));
continue;
case Variant::TYPE_INT8_VALUE:
builder->AddElement<int8_t>(
- it.first->offset(), static_cast<int8_t>(it.second.Int8Value()),
+ it.first->offset(), static_cast<int8_t>(it.second.Value<int8>()),
static_cast<int8_t>(it.first->default_integer()));
continue;
case Variant::TYPE_UINT8_VALUE:
builder->AddElement<uint8_t>(
- it.first->offset(), static_cast<uint8_t>(it.second.UInt8Value()),
+ it.first->offset(), static_cast<uint8_t>(it.second.Value<uint8>()),
static_cast<uint8_t>(it.first->default_integer()));
continue;
case Variant::TYPE_INT_VALUE:
builder->AddElement<int32>(
- it.first->offset(), it.second.IntValue(),
+ it.first->offset(), it.second.Value<int>(),
static_cast<int32>(it.first->default_integer()));
continue;
case Variant::TYPE_UINT_VALUE:
builder->AddElement<uint32>(
- it.first->offset(), it.second.UIntValue(),
+ it.first->offset(), it.second.Value<uint>(),
static_cast<uint32>(it.first->default_integer()));
continue;
case Variant::TYPE_INT64_VALUE:
- builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
+ builder->AddElement<int64>(it.first->offset(), it.second.Value<int64>(),
it.first->default_integer());
continue;
case Variant::TYPE_UINT64_VALUE:
- builder->AddElement<uint64>(it.first->offset(), it.second.UInt64Value(),
+ builder->AddElement<uint64>(it.first->offset(),
+ it.second.Value<uint64>(),
it.first->default_integer());
continue;
case Variant::TYPE_FLOAT_VALUE:
builder->AddElement<float>(
- it.first->offset(), it.second.FloatValue(),
+ it.first->offset(), it.second.Value<float>(),
static_cast<float>(it.first->default_real()));
continue;
case Variant::TYPE_DOUBLE_VALUE:
- builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
+ builder->AddElement<double>(it.first->offset(),
+ it.second.Value<double>(),
it.first->default_real());
continue;
default:
@@ -419,7 +397,7 @@ bool ReflectiveFlatbuffer::AppendFromVector<std::string>(
return false;
}
- TypedRepeatedField<std::string>* to_repeated = Repeated<std::string>(field);
+ RepeatedField* to_repeated = Repeated(field);
for (const flatbuffers::String* element : *from_vector) {
to_repeated->Add(element->str());
}
@@ -435,8 +413,7 @@ bool ReflectiveFlatbuffer::AppendFromVector<ReflectiveFlatbuffer>(
return false;
}
- TypedRepeatedField<ReflectiveFlatbuffer>* to_repeated =
- Repeated<ReflectiveFlatbuffer>(field);
+ RepeatedField* to_repeated = Repeated(field);
for (const flatbuffers::Table* const from_element : *from_vector) {
ReflectiveFlatbuffer* to_element = to_repeated->Add();
if (to_element == nullptr) {
@@ -502,7 +479,9 @@ bool ReflectiveFlatbuffer::MergeFrom(const flatbuffers::Table* from) {
->str());
break;
case reflection::Obj:
- if (!Mutable(field)->MergeFrom(
+ if (ReflectiveFlatbuffer* nested_field = Mutable(field);
+ nested_field == nullptr ||
+ !nested_field->MergeFrom(
from->GetPointer<const flatbuffers::Table* const>(
field->offset()))) {
return false;
@@ -635,4 +614,96 @@ bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
return true;
}
+//
+// Repeated field methods.
+//
+
+ReflectiveFlatbuffer* RepeatedField::Add() {
+ if (is_primitive_) {
+ TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field.";
+ return nullptr;
+ }
+
+ object_items_.emplace_back(new ReflectiveFlatbuffer(
+ schema_, schema_->objects()->Get(field_->type()->index())));
+ return object_items_.back().get();
+}
+
+namespace {
+
+template <typename T>
+flatbuffers::uoffset_t TypedSerialize(const std::vector<Variant>& values,
+ flatbuffers::FlatBufferBuilder* builder) {
+ std::vector<T> typed_values;
+ typed_values.reserve(values.size());
+ for (const Variant& item : values) {
+ typed_values.push_back(item.Value<T>());
+ }
+ return builder->CreateVector(typed_values).o;
+}
+
+} // namespace
+
+flatbuffers::uoffset_t RepeatedField::Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const {
+ switch (field_->type()->element()) {
+ case reflection::String:
+ return SerializeString(builder);
+ break;
+ case reflection::Obj:
+ return SerializeObject(builder);
+ break;
+ case reflection::Bool:
+ return TypedSerialize<bool>(items_, builder);
+ break;
+ case reflection::Byte:
+ return TypedSerialize<int8_t>(items_, builder);
+ break;
+ case reflection::UByte:
+ return TypedSerialize<uint8_t>(items_, builder);
+ break;
+ case reflection::Int:
+ return TypedSerialize<int>(items_, builder);
+ break;
+ case reflection::UInt:
+ return TypedSerialize<uint>(items_, builder);
+ break;
+ case reflection::Long:
+ return TypedSerialize<int64>(items_, builder);
+ break;
+ case reflection::ULong:
+ return TypedSerialize<uint64>(items_, builder);
+ break;
+ case reflection::Float:
+ return TypedSerialize<float>(items_, builder);
+ break;
+ case reflection::Double:
+ return TypedSerialize<double>(items_, builder);
+ break;
+ default:
+ TC3_LOG(FATAL) << "Unsupported type: " << field_->type()->element();
+ break;
+ }
+ TC3_LOG(FATAL) << "Invalid state.";
+ return 0;
+}
+
+flatbuffers::uoffset_t RepeatedField::SerializeString(
+ flatbuffers::FlatBufferBuilder* builder) const {
+ std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(items_.size());
+ for (int i = 0; i < items_.size(); i++) {
+ offsets[i] = builder->CreateString(items_[i].ConstRefValue<std::string>());
+ }
+ return builder->CreateVector(offsets).o;
+}
+
+flatbuffers::uoffset_t RepeatedField::SerializeObject(
+ flatbuffers::FlatBufferBuilder* builder) const {
+ std::vector<flatbuffers::Offset<void>> offsets(object_items_.size());
+ for (int i = 0; i < object_items_.size(); i++) {
+ offsets[i] = object_items_[i]->Serialize(builder);
+ }
+ return builder->CreateVector(offsets).o;
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
index 93a4109..aaf248e 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers.h
@@ -19,7 +19,6 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
#define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
-#include <map>
#include <memory>
#include <string>
#include <unordered_map>
@@ -31,13 +30,12 @@
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
namespace libtextclassifier3 {
class ReflectiveFlatBuffer;
class RepeatedField;
-template <typename T>
-class TypedRepeatedField;
// Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
// integrity.
@@ -105,6 +103,41 @@ std::string PackFlatbuffer(
builder.GetSize());
}
+class ReflectiveFlatbuffer;
+
+// Checks whether a variant value type agrees with a field type.
+template <typename T>
+bool IsMatchingType(const reflection::BaseType type) {
+ switch (type) {
+ case reflection::Bool:
+ return std::is_same<T, bool>::value;
+ case reflection::Byte:
+ return std::is_same<T, int8>::value;
+ case reflection::UByte:
+ return std::is_same<T, uint8>::value;
+ case reflection::Int:
+ return std::is_same<T, int32>::value;
+ case reflection::UInt:
+ return std::is_same<T, uint32>::value;
+ case reflection::Long:
+ return std::is_same<T, int64>::value;
+ case reflection::ULong:
+ return std::is_same<T, uint64>::value;
+ case reflection::Float:
+ return std::is_same<T, float>::value;
+ case reflection::Double:
+ return std::is_same<T, double>::value;
+ case reflection::String:
+ return std::is_same<T, std::string>::value ||
+ std::is_same<T, StringPiece>::value ||
+ std::is_same<T, const char*>::value;
+ case reflection::Obj:
+ return std::is_same<T, ReflectiveFlatbuffer>::value;
+ default:
+ return false;
+ }
+}
+
// A flatbuffer that can be built using flatbuffer reflection data of the
// schema.
// Normally, field information is hard-coded in code generated from a flatbuffer
@@ -123,119 +156,58 @@ class ReflectiveFlatbuffer {
// field was not defined.
const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
- const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
+ const reflection::Field* GetFieldOrNull(const int field_offset) const;
// Gets a nested field and the message it is defined on.
bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
ReflectiveFlatbuffer** parent,
reflection::Field const** field);
- // Checks whether a variant value type agrees with a field type.
- template <typename T>
- bool IsMatchingType(const reflection::BaseType type) const {
- switch (type) {
- case reflection::Bool:
- return std::is_same<T, bool>::value;
- case reflection::Byte:
- return std::is_same<T, int8>::value;
- case reflection::UByte:
- return std::is_same<T, uint8>::value;
- case reflection::Int:
- return std::is_same<T, int32>::value;
- case reflection::UInt:
- return std::is_same<T, uint32>::value;
- case reflection::Long:
- return std::is_same<T, int64>::value;
- case reflection::ULong:
- return std::is_same<T, uint64>::value;
- case reflection::Float:
- return std::is_same<T, float>::value;
- case reflection::Double:
- return std::is_same<T, double>::value;
- case reflection::String:
- return std::is_same<T, std::string>::value ||
- std::is_same<T, StringPiece>::value ||
- std::is_same<T, const char*>::value;
- case reflection::Obj:
- return std::is_same<T, ReflectiveFlatbuffer>::value;
- default:
- return false;
- }
- }
-
- // Sets a (primitive) field to a specific value.
+ // Sets a field to a specific value.
// Returns true if successful, and false if the field was not found or the
// expected type doesn't match.
template <typename T>
- bool Set(StringPiece field_name, T value) {
- if (const reflection::Field* field = GetFieldOrNull(field_name)) {
- return Set<T>(field, value);
- }
- return false;
- }
+ bool Set(StringPiece field_name, T value);
- // Sets a (primitive) field to a specific value.
+ // Sets a field to a specific value.
// Returns true if successful, and false if the expected type doesn't match.
// Expects `field` to be non-null.
template <typename T>
- bool Set(const reflection::Field* field, T value) {
- if (field == nullptr) {
- TC3_LOG(ERROR) << "Expected non-null field.";
- return false;
- }
- Variant variant_value(value);
- if (!IsMatchingType<T>(field->type()->base_type())) {
- TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`, expected: " << field->type()->base_type()
- << ", got: " << variant_value.GetType();
- return false;
- }
- fields_[field] = variant_value;
- return true;
- }
+ bool Set(const reflection::Field* field, T value);
+ // Sets a field to a specific value. Field is specified by path.
template <typename T>
- bool Set(const FlatbufferFieldPath* path, T value) {
- ReflectiveFlatbuffer* parent;
- const reflection::Field* field;
- if (!GetFieldWithParent(path, &parent, &field)) {
- return false;
- }
- return parent->Set<T>(field, value);
- }
-
- // Sets a (primitive) field to a specific value.
- // Parses the string value according to the field type.
- bool ParseAndSet(const reflection::Field* field, const std::string& value);
- bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+ bool Set(const FlatbufferFieldPath* path, T value);
- // Gets the reflective flatbuffer for a table field.
+ // Sets sub-message field (if not set yet), and returns a pointer to it.
// Returns nullptr if the field was not found, or the field type was not a
// table.
ReflectiveFlatbuffer* Mutable(StringPiece field_name);
ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
+ // Parses the value (according to the type) and sets a primitive field to the
+ // parsed value.
+ bool ParseAndSet(const reflection::Field* field, const std::string& value);
+ bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+
+ // Adds a primitive value to the repeated field.
+ template <typename T>
+ bool Add(StringPiece field_name, T value);
+
+ // Add a sub-message to the repeated field.
+ ReflectiveFlatbuffer* Add(StringPiece field_name);
+
+ template <typename T>
+ bool Add(const reflection::Field* field, T value);
+
+ ReflectiveFlatbuffer* Add(const reflection::Field* field);
+
// Gets the reflective flatbuffer for a repeated field.
// Returns nullptr if the field was not found, or the field type was not a
// vector.
RepeatedField* Repeated(StringPiece field_name);
RepeatedField* Repeated(const reflection::Field* field);
- template <typename T>
- TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
- if (!IsMatchingType<T>(field->type()->element())) {
- TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
- << "`";
- return nullptr;
- }
- return static_cast<TypedRepeatedField<T>*>(Repeated(field));
- }
-
- template <typename T>
- TypedRepeatedField<T>* Repeated(StringPiece field_name) {
- return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
- }
-
// Serializes the flatbuffer.
flatbuffers::uoffset_t Serialize(
flatbuffers::FlatBufferBuilder* builder) const;
@@ -318,76 +290,131 @@ class ReflectiveFlatbufferBuilder {
// Serves as a common base class for repeated fields.
class RepeatedField {
public:
- virtual ~RepeatedField() {}
+ RepeatedField(const reflection::Schema* const schema,
+ const reflection::Field* field)
+ : schema_(schema),
+ field_(field),
+ is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
- virtual flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const = 0;
-};
+ template <typename T>
+ bool Add(const T value);
-// Represents a repeated field of particular type.
-template <typename T>
-class TypedRepeatedField : public RepeatedField {
- public:
- void Add(const T value) { items_.push_back(value); }
+ ReflectiveFlatbuffer* Add();
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- return builder->CreateVector(items_).o;
+ template <typename T>
+ T Get(int index) const {
+ return items_.at(index).Value<T>();
}
- private:
- std::vector<T> items_;
-};
-
-// Specialization for strings.
-template <>
-class TypedRepeatedField<std::string> : public RepeatedField {
- public:
- void Add(const std::string& value) { items_.push_back(value); }
+ template <>
+ ReflectiveFlatbuffer* Get(int index) const {
+ if (is_primitive_) {
+ TC3_LOG(ERROR) << "Trying to get primitive value out of non-primitive "
+ "repeated field.";
+ return nullptr;
+ }
+ return object_items_.at(index).get();
+ }
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
- items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = builder->CreateString(items_[i]);
+ int Size() const {
+ if (is_primitive_) {
+ return items_.size();
+ } else {
+ return object_items_.size();
}
- return builder->CreateVector(offsets).o;
}
+ flatbuffers::uoffset_t Serialize(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
private:
- std::vector<std::string> items_;
+ flatbuffers::uoffset_t SerializeString(
+ flatbuffers::FlatBufferBuilder* builder) const;
+ flatbuffers::uoffset_t SerializeObject(
+ flatbuffers::FlatBufferBuilder* builder) const;
+
+ const reflection::Schema* const schema_;
+ const reflection::Field* field_;
+ bool is_primitive_;
+
+ std::vector<Variant> items_;
+ std::vector<std::unique_ptr<ReflectiveFlatbuffer>> object_items_;
};
-// Specialization for repeated sub-messages.
-template <>
-class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
- public:
- TypedRepeatedField<ReflectiveFlatbuffer>(
- const reflection::Schema* const schema,
- const reflection::Type* const type)
- : schema_(schema), type_(type) {}
+template <typename T>
+bool ReflectiveFlatbuffer::Set(StringPiece field_name, T value) {
+ if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+ if (field->type()->base_type() == reflection::BaseType::Vector ||
+ field->type()->base_type() == reflection::BaseType::Obj) {
+ TC3_LOG(ERROR)
+ << "Trying to set a primitive value on a non-scalar field.";
+ return false;
+ }
+ return Set<T>(field, value);
+ }
+ TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
+ return false;
+}
- ReflectiveFlatbuffer* Add() {
- items_.emplace_back(new ReflectiveFlatbuffer(
- schema_, schema_->objects()->Get(type_->index())));
- return items_.back().get();
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Expected non-null field.";
+ return false;
}
+ Variant variant_value(value);
+ if (!IsMatchingType<T>(field->type()->base_type())) {
+ TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+ << "`, expected: " << field->type()->base_type()
+ << ", got: " << variant_value.GetType();
+ return false;
+ }
+ fields_[field] = variant_value;
+ return true;
+}
- flatbuffers::uoffset_t Serialize(
- flatbuffers::FlatBufferBuilder* builder) const override {
- std::vector<flatbuffers::Offset<void>> offsets(items_.size());
- for (int i = 0; i < items_.size(); i++) {
- offsets[i] = items_[i]->Serialize(builder);
- }
- return builder->CreateVector(offsets).o;
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
+ ReflectiveFlatbuffer* parent;
+ const reflection::Field* field;
+ if (!GetFieldWithParent(path, &parent, &field)) {
+ return false;
}
+ return parent->Set<T>(field, value);
+}
- private:
- const reflection::Schema* const schema_;
- const reflection::Type* const type_;
- std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
-};
+template <typename T>
+bool ReflectiveFlatbuffer::Add(StringPiece field_name, T value) {
+ const reflection::Field* field = GetFieldOrNull(field_name);
+ if (field == nullptr) {
+ return false;
+ }
+
+ if (field->type()->base_type() != reflection::BaseType::Vector) {
+ return false;
+ }
+
+ return Add<T>(field, value);
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Add(const reflection::Field* field, T value) {
+ if (field == nullptr) {
+ return false;
+ }
+ Repeated(field)->Add(value);
+ return true;
+}
+
+template <typename T>
+bool RepeatedField::Add(const T value) {
+ if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
+ TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
+ return false;
+ }
+ items_.push_back(Variant{value});
+ return true;
+}
// Resolves field lookups by name to the concrete field offsets.
bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
@@ -402,7 +429,7 @@ bool ReflectiveFlatbuffer::AppendFromVector(const flatbuffers::Table* from,
return false;
}
- TypedRepeatedField<T>* to_repeated = Repeated<T>(field);
+ RepeatedField* to_repeated = Repeated(field);
for (const T element : *from_vector) {
to_repeated->Add(element);
}
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index 69a06a8..d6e4b76 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -177,6 +177,13 @@ int Rules::AddAnnotation(const std::string& annotation_name) {
return it->second;
}
+void Rules::BindAnnotation(const std::string& nonterminal_name,
+ const std::string& annotation_name) {
+ auto [_, inserted] = annotation_nonterminals_.insert(
+ {annotation_name, AddNonterminal(nonterminal_name)});
+ TC3_CHECK(inserted);
+}
+
bool Rules::IsNonterminalOfName(const RhsElement& element,
const std::string& nonterminal) const {
if (element.is_terminal) {
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 5cc20d7..5a2cbc2 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -153,6 +153,10 @@ class Rules {
// Defines a nonterminal for an externally provided annotation.
int AddAnnotation(const std::string& annotation_name);
+ // Defines a nonterminal for an externally provided annotation.
+ void BindAnnotation(const std::string& nonterminal_name,
+ const std::string& annotation_name);
+
// Adds an alias for a nonterminal. This is a separate name for the same
// nonterminal.
void AddAlias(const std::string& nonterminal_name, const std::string& alias);
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index 1c6c283..051d078 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -175,40 +175,41 @@ StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
case Variant::TYPE_INT_VALUE:
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_int_, name.get(),
- value.IntValue());
+ value.Value<int>());
case Variant::TYPE_INT64_VALUE:
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_long_, name.get(),
- value.Int64Value());
+ value.Value<int64>());
case Variant::TYPE_FLOAT_VALUE:
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_float_, name.get(),
- value.FloatValue());
+ value.Value<float>());
case Variant::TYPE_DOUBLE_VALUE:
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_double_, name.get(),
- value.DoubleValue());
+ value.Value<double>());
case Variant::TYPE_BOOL_VALUE:
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_bool_, name.get(),
- value.BoolValue());
+ value.Value<bool>());
case Variant::TYPE_STRING_VALUE: {
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jstring> value_jstring,
- jni_cache_->ConvertToJavaString(value.StringValue()));
+ jni_cache_->ConvertToJavaString(value.ConstRefValue<std::string>()));
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_string_, name.get(),
value_jstring.get());
}
case Variant::TYPE_STRING_VECTOR_VALUE: {
- TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> value_jstring_array,
- AsStringArray(value.StringVectorValue()));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> value_jstring_array,
+ AsStringArray(value.ConstRefValue<std::vector<std::string>>()));
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_string_array_, name.get(),
@@ -216,8 +217,9 @@ StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
}
case Variant::TYPE_FLOAT_VECTOR_VALUE: {
- TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jfloatArray> value_jfloat_array,
- AsFloatArray(value.FloatVectorValue()));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jfloatArray> value_jfloat_array,
+ AsFloatArray(value.ConstRefValue<std::vector<float>>()));
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_float_array_, name.get(),
@@ -226,7 +228,7 @@ StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
case Variant::TYPE_INT_VECTOR_VALUE: {
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jintArray> value_jint_array,
- AsIntArray(value.IntVectorValue()));
+ AsIntArray(value.ConstRefValue<std::vector<int>>()));
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_int_array_, name.get(),
@@ -234,8 +236,10 @@ StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
}
case Variant::TYPE_STRING_VARIANT_MAP_VALUE: {
- TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> value_jobect_array,
- AsNamedVariantArray(value.StringVariantMapValue()));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> value_jobect_array,
+ AsNamedVariantArray(
+ value.ConstRefValue<std::map<std::string, Variant>>()));
return JniHelper::NewObject(env, named_variant_class_.get(),
named_variant_from_named_variant_array_,
name.get(), value_jobect_array.get());
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index fa19923..d6fe2c4 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -223,6 +223,11 @@ int LuaEnvironment::GetField(const reflection::Schema* schema,
int LuaEnvironment::ReadFlatbuffer(const int index,
ReflectiveFlatbuffer* buffer) const {
+ if (buffer == nullptr) {
+ TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
+ lua_error(state_);
+ return LUA_ERRRUN;
+ }
if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected table, got: "
<< lua_type(state_, /*idx=*/kIndexStackTop);
@@ -278,48 +283,48 @@ int LuaEnvironment::ReadFlatbuffer(const int index,
// Read repeated field.
switch (field->type()->element()) {
case reflection::Bool:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<bool>(field));
+ ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Byte:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<int8>(field));
+ ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::UByte:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<uint8>(field));
+ ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Int:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<int32>(field));
+ ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::UInt:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<uint32>(field));
+ ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Long:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<int64>(field));
+ ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::ULong:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<uint64>(field));
+ ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Float:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<float>(field));
+ ReadRepeatedField<float>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Double:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<double>(field));
+ ReadRepeatedField<double>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::String:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<std::string>(field));
+ ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
case reflection::Obj:
- ReadRepeatedField(/*index=*/kIndexStackTop,
- buffer->Repeated<ReflectiveFlatbuffer>(field));
+ ReadRepeatedField<ReflectiveFlatbuffer>(/*index=*/kIndexStackTop,
+ buffer->Repeated(field));
break;
default:
TC3_LOG(ERROR) << "Unsupported repeated field type: "
diff --git a/native/utils/lua-utils.h b/native/utils/lua-utils.h
index f602aa0..b01471a 100644
--- a/native/utils/lua-utils.h
+++ b/native/utils/lua-utils.h
@@ -506,15 +506,15 @@ class LuaEnvironment {
// Reads a repeated field from lua.
template <typename T>
- void ReadRepeatedField(const int index, TypedRepeatedField<T>* result) const {
+ void ReadRepeatedField(const int index, RepeatedField* result) const {
for (const auto& element : ReadVector<T>(index)) {
result->Add(element);
}
}
template <>
- void ReadRepeatedField<ReflectiveFlatbuffer>(
- const int index, TypedRepeatedField<ReflectiveFlatbuffer>* result) const {
+ void ReadRepeatedField<ReflectiveFlatbuffer>(const int index,
+ RepeatedField* result) const {
lua_pushnil(state_);
while (Next(index - 1)) {
ReadFlatbuffer(index, result->Add());
diff --git a/native/utils/variant.cc b/native/utils/variant.cc
index 9cdc0b6..0513440 100644
--- a/native/utils/variant.cc
+++ b/native/utils/variant.cc
@@ -21,26 +21,26 @@ namespace libtextclassifier3 {
std::string Variant::ToString() const {
switch (GetType()) {
case Variant::TYPE_BOOL_VALUE:
- if (BoolValue()) {
+ if (Value<bool>()) {
return "true";
} else {
return "false";
}
break;
case Variant::TYPE_INT_VALUE:
- return std::to_string(IntValue());
+ return std::to_string(Value<int>());
break;
case Variant::TYPE_INT64_VALUE:
- return std::to_string(Int64Value());
+ return std::to_string(Value<int64>());
break;
case Variant::TYPE_FLOAT_VALUE:
- return std::to_string(FloatValue());
+ return std::to_string(Value<float>());
break;
case Variant::TYPE_DOUBLE_VALUE:
- return std::to_string(DoubleValue());
+ return std::to_string(Value<double>());
break;
case Variant::TYPE_STRING_VALUE:
- return StringValue();
+ return ConstRefValue<std::string>();
break;
default:
TC3_LOG(FATAL) << "Unsupported variant type: " << GetType();
diff --git a/native/utils/variant.h b/native/utils/variant.h
index 11c361c..551a822 100644
--- a/native/utils/variant.h
+++ b/native/utils/variant.h
@@ -85,110 +85,178 @@ class Variant {
Variant& operator=(const Variant&) = default;
- int Int8Value() const {
- TC3_CHECK(HasInt8());
+ template <class T>
+ struct dependent_false : std::false_type {};
+
+ template <typename T>
+ T Value() const {
+ static_assert(dependent_false<T>::value, "Not supported.");
+ }
+
+ template <>
+ int8 Value() const {
+ TC3_CHECK(Has<int8>());
return int8_value_;
}
- int UInt8Value() const {
- TC3_CHECK(HasUInt8());
+ template <>
+ uint8 Value() const {
+ TC3_CHECK(Has<uint8>());
return uint8_value_;
}
- int IntValue() const {
- TC3_CHECK(HasInt());
+ template <>
+ int Value() const {
+ TC3_CHECK(Has<int>());
return int_value_;
}
- uint UIntValue() const {
- TC3_CHECK(HasUInt());
+ template <>
+ uint Value() const {
+ TC3_CHECK(Has<uint>());
return uint_value_;
}
- int64 Int64Value() const {
- TC3_CHECK(HasInt64());
+ template <>
+ int64 Value() const {
+ TC3_CHECK(Has<int64>());
return long_value_;
}
- uint64 UInt64Value() const {
- TC3_CHECK(HasUInt64());
+ template <>
+ uint64 Value() const {
+ TC3_CHECK(Has<uint64>());
return ulong_value_;
}
- float FloatValue() const {
- TC3_CHECK(HasFloat());
+ template <>
+ float Value() const {
+ TC3_CHECK(Has<float>());
return float_value_;
}
- double DoubleValue() const {
- TC3_CHECK(HasDouble());
+ template <>
+ double Value() const {
+ TC3_CHECK(Has<double>());
return double_value_;
}
- bool BoolValue() const {
- TC3_CHECK(HasBool());
+ template <>
+ bool Value() const {
+ TC3_CHECK(Has<bool>());
return bool_value_;
}
- const std::string& StringValue() const {
- TC3_CHECK(HasString());
+ template <typename T>
+ const T& ConstRefValue() const;
+
+ template <>
+ const std::string& ConstRefValue() const {
+ TC3_CHECK(Has<std::string>());
return string_value_;
}
- const std::vector<std::string>& StringVectorValue() const {
- TC3_CHECK(HasStringVector());
+ template <>
+ const std::vector<std::string>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<std::string>>());
return string_vector_value_;
}
- const std::vector<float>& FloatVectorValue() const {
- TC3_CHECK(HasFloatVector());
+ template <>
+ const std::vector<float>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<float>>());
return float_vector_value_;
}
- const std::vector<int>& IntVectorValue() const {
- TC3_CHECK(HasIntVector());
+ template <>
+ const std::vector<int>& ConstRefValue() const {
+ TC3_CHECK(Has<std::vector<int>>());
return int_vector_value_;
}
- const std::map<std::string, Variant>& StringVariantMapValue() const {
- TC3_CHECK(HasStringVariantMap());
+ template <>
+ const std::map<std::string, Variant>& ConstRefValue() const {
+ TC3_CHECK((Has<std::map<std::string, Variant>>()));
return string_variant_map_value_;
}
- // Converts the value of this variant to its string representation, regardless
- // of the type of the actual value.
- std::string ToString() const;
+ template <typename T>
+ bool Has() const;
- bool HasInt8() const { return type_ == TYPE_INT8_VALUE; }
+ template <>
+ bool Has<int8>() const {
+ return type_ == TYPE_INT8_VALUE;
+ }
- bool HasUInt8() const { return type_ == TYPE_UINT8_VALUE; }
+ template <>
+ bool Has<uint8>() const {
+ return type_ == TYPE_UINT8_VALUE;
+ }
- bool HasInt() const { return type_ == TYPE_INT_VALUE; }
+ template <>
+ bool Has<int>() const {
+ return type_ == TYPE_INT_VALUE;
+ }
- bool HasUInt() const { return type_ == TYPE_UINT_VALUE; }
+ template <>
+ bool Has<uint>() const {
+ return type_ == TYPE_UINT_VALUE;
+ }
- bool HasInt64() const { return type_ == TYPE_INT64_VALUE; }
+ template <>
+ bool Has<int64>() const {
+ return type_ == TYPE_INT64_VALUE;
+ }
- bool HasUInt64() const { return type_ == TYPE_UINT64_VALUE; }
+ template <>
+ bool Has<uint64>() const {
+ return type_ == TYPE_UINT64_VALUE;
+ }
- bool HasFloat() const { return type_ == TYPE_FLOAT_VALUE; }
+ template <>
+ bool Has<float>() const {
+ return type_ == TYPE_FLOAT_VALUE;
+ }
- bool HasDouble() const { return type_ == TYPE_DOUBLE_VALUE; }
+ template <>
+ bool Has<double>() const {
+ return type_ == TYPE_DOUBLE_VALUE;
+ }
- bool HasBool() const { return type_ == TYPE_BOOL_VALUE; }
+ template <>
+ bool Has<bool>() const {
+ return type_ == TYPE_BOOL_VALUE;
+ }
- bool HasString() const { return type_ == TYPE_STRING_VALUE; }
+ template <>
+ bool Has<std::string>() const {
+ return type_ == TYPE_STRING_VALUE;
+ }
- bool HasStringVector() const { return type_ == TYPE_STRING_VECTOR_VALUE; }
+ template <>
+ bool Has<std::vector<std::string>>() const {
+ return type_ == TYPE_STRING_VECTOR_VALUE;
+ }
- bool HasFloatVector() const { return type_ == TYPE_FLOAT_VECTOR_VALUE; }
+ template <>
+ bool Has<std::vector<float>>() const {
+ return type_ == TYPE_FLOAT_VECTOR_VALUE;
+ }
- bool HasIntVector() const { return type_ == TYPE_INT_VECTOR_VALUE; }
+ template <>
+ bool Has<std::vector<int>>() const {
+ return type_ == TYPE_INT_VECTOR_VALUE;
+ }
- bool HasStringVariantMap() const {
+ template <>
+ bool Has<std::map<std::string, Variant>>() const {
return type_ == TYPE_STRING_VARIANT_MAP_VALUE;
}
+ // Converts the value of this variant to its string representation, regardless
+ // of the type of the actual value.
+ std::string ToString() const;
+
Type GetType() const { return type_; }
bool HasValue() const { return type_ != TYPE_EMPTY; }
diff --git a/notification/tests/AndroidTest.xml b/notification/tests/AndroidTest.xml
new file mode 100644
index 0000000..1890e75
--- /dev/null
+++ b/notification/tests/AndroidTest.xml
@@ -0,0 +1,33 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- Copyright (C) 2020 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<!-- This test config file is auto-generated. -->
+<configuration description="Runs TextClassifierNotificationTests.">
+ <option name="test-suite-tag" value="apct" />
+ <option name="test-suite-tag" value="apct-instrumentation" />
+ <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
+ <option name="cleanup-apks" value="true" />
+ <option name="test-file-name" value="TextClassifierNotificationTests.apk" />
+ </target_preparer>
+
+ <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
+ <option name="package" value="com.android.textclassifier.notification" />
+ <option name="runner" value="androidx.test.runner.AndroidJUnitRunner" />
+ </test>
+
+ <object type="module_controller" class="com.android.tradefed.testtype.suite.module.MainlineTestModuleController">
+ <option name="mainline-module-package-name" value="com.google.android.extservices" />
+ </object>
+</configuration>