summaryrefslogtreecommitdiff
path: root/native/actions/actions-suggestions.cc
diff options
context:
space:
mode:
Diffstat (limited to 'native/actions/actions-suggestions.cc')
-rw-r--r--native/actions/actions-suggestions.cc75
1 files changed, 56 insertions, 19 deletions
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index a84f2cd..1fcd35c 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -22,6 +22,7 @@
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
+#include "annotator/collections.h"
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
@@ -50,6 +51,11 @@ const std::string& ActionsSuggestions::kSendEmailType =
*[]() { return new std::string("send_email"); }();
const std::string& ActionsSuggestions::kShareLocation =
*[]() { return new std::string("share_location"); }();
+
+// Name for a datetime annotation that only includes time but no date.
+const std::string& kTimeAnnotation =
+ *[]() { return new std::string("time"); }();
+
constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;
@@ -260,6 +266,7 @@ bool ActionsSuggestions::ValidateAndInitialize() {
}
}
+ // Gather annotation entities for the rules.
if (model_->annotation_actions_spec() != nullptr &&
model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
@@ -300,6 +307,18 @@ bool ActionsSuggestions::ValidateAndInitialize() {
grammar_actions_.reset(new GrammarActions(
unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
model_->smart_reply_action_type()->str()));
+
+ // Gather annotation entities for the grammars.
+ if (auto annotation_nt = model_->rules()
+ ->grammar_rules()
+ ->rules()
+ ->nonterminals()
+ ->annotation_nt()) {
+ for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
+ *annotation_nt) {
+ annotation_entity_types_.insert(entry->key()->str());
+ }
+ }
}
std::string actions_script;
@@ -689,47 +708,41 @@ bool ActionsSuggestions::SetupModelInput(
interpreter->tensor(interpreter->inputs()[param_index])->type;
const auto param_value_it = model_parameters.find(param_name);
const bool has_value = param_value_it != model_parameters.end();
- /*
- case kTfLiteInt16:
- *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
- break;
- case kTfLiteInt8:
- */
switch (param_type) {
case kTfLiteFloat32:
model_executor_->SetInput<float>(
param_index,
- has_value ? param_value_it->second.FloatValue() : kDefaultFloat,
+ has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
interpreter);
break;
case kTfLiteInt32:
model_executor_->SetInput<int32_t>(
param_index,
- has_value ? param_value_it->second.IntValue() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int>() : kDefaultInt,
interpreter);
break;
case kTfLiteInt64:
model_executor_->SetInput<int64_t>(
param_index,
- has_value ? param_value_it->second.Int64Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
interpreter);
break;
case kTfLiteUInt8:
model_executor_->SetInput<uint8_t>(
param_index,
- has_value ? param_value_it->second.UInt8Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
interpreter);
break;
case kTfLiteInt8:
model_executor_->SetInput<int8_t>(
param_index,
- has_value ? param_value_it->second.Int8Value() : kDefaultInt,
+ has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
interpreter);
break;
case kTfLiteBool:
model_executor_->SetInput<bool>(
param_index,
- has_value ? param_value_it->second.BoolValue() : kDefaultBool,
+ has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
interpreter);
break;
default:
@@ -1023,6 +1036,30 @@ Conversation ActionsSuggestions::AnnotateConversation(
if (message->annotations.empty()) {
message->annotations = annotator->Annotate(
message->text, AnnotationOptionsForMessage(*message));
+ for (int i = 0; i < message->annotations.size(); i++) {
+ ClassificationResult* classification =
+ &message->annotations[i].classification.front();
+
+ // Specialize datetime annotation to time annotation if no date
+ // component is present.
+ if (classification->collection == Collections::DateTime() &&
+ classification->datetime_parse_result.IsSet()) {
+ bool has_only_time = true;
+ for (const DatetimeComponent& component :
+ classification->datetime_parse_result.datetime_components) {
+ if (component.component_type !=
+ DatetimeComponent::ComponentType::UNSPECIFIED &&
+ component.component_type <
+ DatetimeComponent::ComponentType::HOUR) {
+ has_only_time = false;
+ break;
+ }
+ }
+ if (has_only_time) {
+ classification->collection = kTimeAnnotation;
+ }
+ }
+ }
}
}
return annotated_conversation;
@@ -1224,6 +1261,13 @@ bool ActionsSuggestions::GatherActionsSuggestions(
SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
+ if (grammar_actions_ != nullptr &&
+ !grammar_actions_->SuggestActions(annotated_conversation,
+ &response->actions)) {
+ TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
+ return false;
+ }
+
int input_text_length = 0;
int num_matching_locales = 0;
for (int i = annotated_conversation.messages.size() - num_messages;
@@ -1299,13 +1343,6 @@ bool ActionsSuggestions::GatherActionsSuggestions(
return false;
}
- if (grammar_actions_ != nullptr &&
- !grammar_actions_->SuggestActions(annotated_conversation,
- &response->actions)) {
- TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
- return false;
- }
-
if (preconditions_.suppress_on_low_confidence_input &&
!regex_actions_->FilterConfidenceOutput(post_check_rules,
&response->actions)) {