summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTony Mak <tonymak@google.com>2018-12-07 18:33:22 +0000
committerTony Mak <tonymak@google.com>2018-12-07 18:33:22 +0000
commit185e1eb6f9294559ab911929d40a7117f8f11c0f (patch)
tree6ffab1a7cebcf42bfeb6dac03ad6ce66f66b8034
parent296b7b68be844bb90e55b83346ce24af8b729164 (diff)
downloadlibtextclassifier-185e1eb6f9294559ab911929d40a7117f8f11c0f.tar.gz
Export libtextclassifier to Android
Change-Id: I4d304cbc5f18394e16f7fc2b0f2f76e9d183ca55
-rw-r--r--actions/actions-suggestions.cc15
-rw-r--r--actions/actions-suggestions.h4
-rw-r--r--actions/actions-suggestions_test.cc4
-rw-r--r--actions/actions_jni.cc10
-rw-r--r--java/com/google/android/textclassifier/ActionsSuggestionsModel.java15
-rw-r--r--utils/java/jni-cache.cc11
6 files changed, 39 insertions, 20 deletions
diff --git a/actions/actions-suggestions.cc b/actions/actions-suggestions.cc
index d7d261f..b8cd895 100644
--- a/actions/actions-suggestions.cc
+++ b/actions/actions-suggestions.cc
@@ -263,12 +263,25 @@ void ActionsSuggestions::SuggestActionsFromModel(
std::vector<float> time_diffs;
// Gather last `num_messages` messages from the conversation.
+ int64 last_message_reference_time_ms_utc = 0;
+ const float second_in_ms = 1000;
for (int i = conversation.messages.size() - num_messages;
i < conversation.messages.size(); i++) {
const ConversationMessage& message = conversation.messages[i];
context.push_back(message.text);
user_ids.push_back(message.user_id);
- time_diffs.push_back(message.time_diff_secs);
+
+ float time_diff_secs = 0;
+ if (message.reference_time_ms_utc != 0 &&
+ last_message_reference_time_ms_utc != 0) {
+ time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
+ last_message_reference_time_ms_utc) /
+ second_in_ms);
+ }
+ if (message.reference_time_ms_utc != 0) {
+ last_message_reference_time_ms_utc = message.reference_time_ms_utc;
+ }
+ time_diffs.push_back(time_diff_secs);
}
SetupModelInput(context, user_ids, time_diffs,
diff --git a/actions/actions-suggestions.h b/actions/actions-suggestions.h
index fa7807e..75b5f81 100644
--- a/actions/actions-suggestions.h
+++ b/actions/actions-suggestions.h
@@ -93,8 +93,8 @@ struct ConversationMessage {
// Text of the message.
std::string text;
- // Relative time to previous message.
- float time_diff_secs;
+ // Reference time of this message.
+ int64 reference_time_ms_utc;
// Annotations on the text.
std::vector<AnnotatedSpan> annotations;
diff --git a/actions/actions-suggestions_test.cc b/actions/actions-suggestions_test.cc
index df8abcd..883560a 100644
--- a/actions/actions-suggestions_test.cc
+++ b/actions/actions-suggestions_test.cc
@@ -172,9 +172,9 @@ TEST(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) {
ClassificationResult(Annotator::kAddressCollection, 1.0)};
const ActionsSuggestionsResponse& response =
actions_suggestions->SuggestActions(
- {{{/*user_id=*/0, "hi, how are you?", /*time_diff_secs=*/0},
+ {{{/*user_id=*/0, "hi, how are you?", /*reference_time=*/10000},
{/*user_id=*/1, "good! are you at home?",
- /*time_diff_secs=*/60,
+ /*reference_time=*/15000,
/*annotations=*/{annotation}}}});
EXPECT_EQ(response.actions.size(), 1);
EXPECT_EQ(response.actions.back().type, "view_map");
diff --git a/actions/actions_jni.cc b/actions/actions_jni.cc
index 17571c3..2924f3c 100644
--- a/actions/actions_jni.cc
+++ b/actions/actions_jni.cc
@@ -116,14 +116,14 @@ ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
const std::pair<bool, int32> status_or_user_id =
CallJniMethod0<int32>(env, jmessage, message_class.get(),
&JNIEnv::CallIntMethod, "getUserId", "I");
- const std::pair<bool, int32> status_or_time_diff = CallJniMethod0<int32>(
- env, jmessage, message_class.get(), &JNIEnv::CallIntMethod,
- "getTimeDiffInSeconds", "I");
+ const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
+ env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
+ "getReferenceTimeMsUtc", "J");
const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
"getLocales", "Ljava/lang/String;");
if (!status_or_text.first || !status_or_user_id.first ||
- !status_or_locales.first || !status_or_time_diff.first) {
+ !status_or_locales.first || !status_or_reference_time.first) {
return {};
}
@@ -131,7 +131,7 @@ ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
message.text =
ToStlString(env, reinterpret_cast<jstring>(status_or_text.second));
message.user_id = status_or_user_id.second;
- message.time_diff_secs = status_or_time_diff.second;
+ message.reference_time_ms_utc = status_or_reference_time.second;
message.locales =
ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
return message;
diff --git a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
index a836e74..661ca3b 100644
--- a/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/java/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -135,13 +135,13 @@ public final class ActionsSuggestionsModel implements AutoCloseable {
public static final class ConversationMessage {
private final int userId;
private final String text;
- private final int timeDiffInSeconds;
+ private final long referenceTimeMsUtc;
private final String locales;
- public ConversationMessage(int userId, String text, int timeDiffInSeconds, String locales) {
+ public ConversationMessage(int userId, String text, long referenceTimeMsUtc, String locales) {
this.userId = userId;
this.text = text;
- this.timeDiffInSeconds = timeDiffInSeconds;
+ this.referenceTimeMsUtc = referenceTimeMsUtc;
this.locales = locales;
}
@@ -155,13 +155,14 @@ public final class ActionsSuggestionsModel implements AutoCloseable {
}
/**
- * The time difference (in seconds) between the first message of the coversation and this
- * message, value {@code 0} means unspecified.
+ * Return the reference time of the message, for example, it could be compose time or send time.
+ * {@code 0} means unspecified.
*/
- public int getTimeDiffInSeconds() {
- return timeDiffInSeconds;
+ public long getReferenceTimeMsUtc() {
+ return referenceTimeMsUtc;
}
+ /** Returns a comma separated list of locales supported by the model as BCP 47 tags. */
public String getLocales() {
return locales;
}
diff --git a/utils/java/jni-cache.cc b/utils/java/jni-cache.cc
index 4bb9523..ce52288 100644
--- a/utils/java/jni-cache.cc
+++ b/utils/java/jni-cache.cc
@@ -46,9 +46,14 @@ JniCache::JniCache(JavaVM* jvm)
result->FIELD##_class = MakeGlobalRef(env->FindClass(NAME), env, jvm); \
TC3_CHECK_JNI_PTR(result->FIELD##_class)
-#define TC3_GET_OPTIONAL_CLASS(FIELD, NAME) \
- result->FIELD##_class = MakeGlobalRef(env->FindClass(NAME), env, jvm); \
- env->ExceptionClear();
+#define TC3_GET_OPTIONAL_CLASS(FIELD, NAME) \
+ { \
+ jclass clazz = env->FindClass(NAME); \
+ if (clazz != nullptr) { \
+ result->FIELD##_class = MakeGlobalRef(clazz, env, jvm); \
+ } \
+ env->ExceptionClear(); \
+ }
#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
result->CLASS##_##FIELD = \