diff options
Diffstat (limited to 'native/actions/actions-suggestions.cc')
-rw-r--r-- | native/actions/actions-suggestions.cc | 75 |
1 files changed, 56 insertions, 19 deletions
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc index a84f2cd..1fcd35c 100644 --- a/native/actions/actions-suggestions.cc +++ b/native/actions/actions-suggestions.cc @@ -22,6 +22,7 @@ #include "actions/types.h" #include "actions/utils.h" #include "actions/zlib-utils.h" +#include "annotator/collections.h" #include "utils/base/logging.h" #include "utils/flatbuffers.h" #include "utils/lua-utils.h" @@ -50,6 +51,11 @@ const std::string& ActionsSuggestions::kSendEmailType = *[]() { return new std::string("send_email"); }(); const std::string& ActionsSuggestions::kShareLocation = *[]() { return new std::string("share_location"); }(); + +// Name for a datetime annotation that only includes time but no date. +const std::string& kTimeAnnotation = + *[]() { return new std::string("time"); }(); + constexpr float kDefaultFloat = 0.0; constexpr bool kDefaultBool = false; constexpr int kDefaultInt = 1; @@ -260,6 +266,7 @@ bool ActionsSuggestions::ValidateAndInitialize() { } } + // Gather annotation entities for the rules. if (model_->annotation_actions_spec() != nullptr && model_->annotation_actions_spec()->annotation_mapping() != nullptr) { for (const AnnotationActionsSpec_::AnnotationMapping* mapping : @@ -300,6 +307,18 @@ bool ActionsSuggestions::ValidateAndInitialize() { grammar_actions_.reset(new GrammarActions( unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(), model_->smart_reply_action_type()->str())); + + // Gather annotation entities for the grammars. + if (auto annotation_nt = model_->rules() + ->grammar_rules() + ->rules() + ->nonterminals() + ->annotation_nt()) { + for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry : + *annotation_nt) { + annotation_entity_types_.insert(entry->key()->str()); + } + } } std::string actions_script; @@ -689,47 +708,41 @@ bool ActionsSuggestions::SetupModelInput( interpreter->tensor(interpreter->inputs()[param_index])->type; const auto param_value_it = model_parameters.find(param_name); const bool has_value = param_value_it != model_parameters.end(); - /* - case kTfLiteInt16: - *tflite::GetTensorData<int16_t>(input_tensor) = input_value; - break; - case kTfLiteInt8: - */ switch (param_type) { case kTfLiteFloat32: model_executor_->SetInput<float>( param_index, - has_value ? param_value_it->second.FloatValue() : kDefaultFloat, + has_value ? param_value_it->second.Value<float>() : kDefaultFloat, interpreter); break; case kTfLiteInt32: model_executor_->SetInput<int32_t>( param_index, - has_value ? param_value_it->second.IntValue() : kDefaultInt, + has_value ? param_value_it->second.Value<int>() : kDefaultInt, interpreter); break; case kTfLiteInt64: model_executor_->SetInput<int64_t>( param_index, - has_value ? param_value_it->second.Int64Value() : kDefaultInt, + has_value ? param_value_it->second.Value<int64>() : kDefaultInt, interpreter); break; case kTfLiteUInt8: model_executor_->SetInput<uint8_t>( param_index, - has_value ? param_value_it->second.UInt8Value() : kDefaultInt, + has_value ? param_value_it->second.Value<uint8>() : kDefaultInt, interpreter); break; case kTfLiteInt8: model_executor_->SetInput<int8_t>( param_index, - has_value ? param_value_it->second.Int8Value() : kDefaultInt, + has_value ? param_value_it->second.Value<int8>() : kDefaultInt, interpreter); break; case kTfLiteBool: model_executor_->SetInput<bool>( param_index, - has_value ? param_value_it->second.BoolValue() : kDefaultBool, + has_value ? param_value_it->second.Value<bool>() : kDefaultBool, interpreter); break; default: @@ -1023,6 +1036,30 @@ Conversation ActionsSuggestions::AnnotateConversation( if (message->annotations.empty()) { message->annotations = annotator->Annotate( message->text, AnnotationOptionsForMessage(*message)); + for (int i = 0; i < message->annotations.size(); i++) { + ClassificationResult* classification = + &message->annotations[i].classification.front(); + + // Specialize datetime annotation to time annotation if no date + // component is present. + if (classification->collection == Collections::DateTime() && + classification->datetime_parse_result.IsSet()) { + bool has_only_time = true; + for (const DatetimeComponent& component : + classification->datetime_parse_result.datetime_components) { + if (component.component_type != + DatetimeComponent::ComponentType::UNSPECIFIED && + component.component_type < + DatetimeComponent::ComponentType::HOUR) { + has_only_time = false; + break; + } + } + if (has_only_time) { + classification->collection = kTimeAnnotation; + } + } + } } } return annotated_conversation; @@ -1224,6 +1261,13 @@ bool ActionsSuggestions::GatherActionsSuggestions( SuggestActionsFromAnnotations(annotated_conversation, &response->actions); + if (grammar_actions_ != nullptr && + !grammar_actions_->SuggestActions(annotated_conversation, + &response->actions)) { + TC3_LOG(ERROR) << "Could not suggest actions from grammar rules."; + return false; + } + int input_text_length = 0; int num_matching_locales = 0; for (int i = annotated_conversation.messages.size() - num_messages; @@ -1299,13 +1343,6 @@ bool ActionsSuggestions::GatherActionsSuggestions( return false; } - if (grammar_actions_ != nullptr && - !grammar_actions_->SuggestActions(annotated_conversation, - &response->actions)) { - TC3_LOG(ERROR) << "Could not suggest actions from grammar rules."; - return false; - } - if (preconditions_.suppress_on_low_confidence_input && !regex_actions_->FilterConfidenceOutput(post_check_rules, &response->actions)) { |