aboutsummaryrefslogtreecommitdiff
path: root/ink_stroke_modeler
diff options
context:
space:
mode:
Diffstat (limited to 'ink_stroke_modeler')
-rw-r--r--ink_stroke_modeler/BUILD.bazel97
-rw-r--r--ink_stroke_modeler/CMakeLists.txt102
-rw-r--r--ink_stroke_modeler/internal/BUILD.bazel143
-rw-r--r--ink_stroke_modeler/internal/CMakeLists.txt156
-rw-r--r--ink_stroke_modeler/internal/internal_types.h73
-rw-r--r--ink_stroke_modeler/internal/position_modeler.h138
-rw-r--r--ink_stroke_modeler/internal/position_modeler_test.cc317
-rw-r--r--ink_stroke_modeler/internal/prediction/BUILD.bazel83
-rw-r--r--ink_stroke_modeler/internal/prediction/CMakeLists.txt86
-rw-r--r--ink_stroke_modeler/internal/prediction/input_predictor.h57
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel58
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt54
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc98
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h62
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc100
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc79
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h96
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h270
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc215
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_predictor.cc265
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_predictor.h103
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc215
-rw-r--r--ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc52
-rw-r--r--ink_stroke_modeler/internal/prediction/stroke_end_predictor.h58
-rw-r--r--ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc137
-rw-r--r--ink_stroke_modeler/internal/stylus_state_modeler.cc101
-rw-r--r--ink_stroke_modeler/internal/stylus_state_modeler.h82
-rw-r--r--ink_stroke_modeler/internal/stylus_state_modeler_test.cc269
-rw-r--r--ink_stroke_modeler/internal/type_matchers.cc69
-rw-r--r--ink_stroke_modeler/internal/type_matchers.h42
-rw-r--r--ink_stroke_modeler/internal/utils.h99
-rw-r--r--ink_stroke_modeler/internal/utils_test.cc87
-rw-r--r--ink_stroke_modeler/internal/validation.h48
-rw-r--r--ink_stroke_modeler/internal/validation_test.cc82
-rw-r--r--ink_stroke_modeler/internal/wobble_smoother.cc70
-rw-r--r--ink_stroke_modeler/internal/wobble_smoother.h57
-rw-r--r--ink_stroke_modeler/internal/wobble_smoother_test.cc104
-rw-r--r--ink_stroke_modeler/params.cc157
-rw-r--r--ink_stroke_modeler/params.h224
-rw-r--r--ink_stroke_modeler/params_test.cc259
-rw-r--r--ink_stroke_modeler/stroke_modeler.cc253
-rw-r--r--ink_stroke_modeler/stroke_modeler.h99
-rw-r--r--ink_stroke_modeler/stroke_modeler_test.cc1395
-rw-r--r--ink_stroke_modeler/types.cc36
-rw-r--r--ink_stroke_modeler/types.h356
-rw-r--r--ink_stroke_modeler/types_test.cc212
46 files changed, 7215 insertions, 0 deletions
diff --git a/ink_stroke_modeler/BUILD.bazel b/ink_stroke_modeler/BUILD.bazel
new file mode 100644
index 0000000..d48f1ea
--- /dev/null
+++ b/ink_stroke_modeler/BUILD.bazel
@@ -0,0 +1,97 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])
+
+cc_library(
+ name = "params",
+ srcs = ["params.cc"],
+ hdrs = ["params.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//ink_stroke_modeler/internal:validation",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:variant",
+ ],
+)
+
+cc_test(
+ name = "params_test",
+ srcs = ["params_test.cc"],
+ deps = [
+ ":params",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "types",
+ srcs = ["types.cc"],
+ hdrs = ["types.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//ink_stroke_modeler/internal:validation",
+ "@com_google_absl//absl/status",
+ ],
+)
+
+cc_test(
+ name = "types_test",
+ srcs = ["types_test.cc"],
+ deps = [
+ ":types",
+ "//ink_stroke_modeler/internal:type_matchers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "stroke_modeler",
+ srcs = ["stroke_modeler.cc"],
+ hdrs = ["stroke_modeler.h"],
+ deps = [
+ ":params",
+ ":types",
+ "//ink_stroke_modeler/internal:internal_types",
+ "//ink_stroke_modeler/internal:position_modeler",
+ "//ink_stroke_modeler/internal:stylus_state_modeler",
+ "//ink_stroke_modeler/internal:wobble_smoother",
+ "//ink_stroke_modeler/internal/prediction:input_predictor",
+ "//ink_stroke_modeler/internal/prediction:kalman_predictor",
+ "//ink_stroke_modeler/internal/prediction:stroke_end_predictor",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/status:statusor",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_absl//absl/types:variant",
+ ],
+)
+
+cc_test(
+ name = "stroke_modeler_test",
+ srcs = ["stroke_modeler_test.cc"],
+ deps = [
+ ":params",
+ ":stroke_modeler",
+ "//ink_stroke_modeler/internal:type_matchers",
+ "@com_google_absl//absl/status",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/ink_stroke_modeler/CMakeLists.txt b/ink_stroke_modeler/CMakeLists.txt
new file mode 100644
index 0000000..736663a
--- /dev/null
+++ b/ink_stroke_modeler/CMakeLists.txt
@@ -0,0 +1,102 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_subdirectory(internal)
+
+ink_cc_library(
+ NAME
+ params
+ SRCS
+ params.cc
+ HDRS
+ params.h
+ DEPS
+ InkStrokeModeler::types
+ absl::status
+ absl::strings
+ absl::variant
+ InkStrokeModeler::validation
+)
+
+ink_cc_test(
+ NAME
+ params_test
+ SRCS
+ params_test.cc
+ DEPS
+ InkStrokeModeler::params
+ absl::status
+ GTest::gtest_main
+)
+
+ink_cc_library(
+ NAME
+ types
+ SRCS
+ types.cc
+ HDRS
+ types.h
+ DEPS
+ absl::status
+ InkStrokeModeler::validation
+)
+
+ink_cc_test(
+ NAME
+ types_test
+ SRCS
+ types_test.cc
+ DEPS
+ InkStrokeModeler::types
+ InkStrokeModeler::type_matchers
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ stroke_modeler
+ SRCS
+ stroke_modeler.cc
+ HDRS
+ stroke_modeler.h
+ DEPS
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::position_modeler
+ InkStrokeModeler::stylus_state_modeler
+ InkStrokeModeler::wobble_smoother
+ InkStrokeModeler::input_predictor
+ InkStrokeModeler::kalman_predictor
+ InkStrokeModeler::stroke_end_predictor
+ absl::core_headers
+ absl::memory
+ absl::status
+ absl::statusor
+ absl::optional
+ absl::variant
+)
+
+ink_cc_test(
+ NAME
+ stroke_modeler_test
+ SRCS
+ stroke_modeler_test.cc
+ DEPS
+ InkStrokeModeler::params
+ InkStrokeModeler::stroke_modeler
+ InkStrokeModeler::type_matchers
+ absl::status
+ GTest::gmock_main
+)
diff --git a/ink_stroke_modeler/internal/BUILD.bazel b/ink_stroke_modeler/internal/BUILD.bazel
new file mode 100644
index 0000000..f4e1d98
--- /dev/null
+++ b/ink_stroke_modeler/internal/BUILD.bazel
@@ -0,0 +1,143 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//ink_stroke_modeler:__subpackages__"],
+)
+
+licenses(["notice"])
+
+cc_library(
+ name = "internal_types",
+ hdrs = ["internal_types.h"],
+ deps = ["//ink_stroke_modeler:types"],
+)
+
+cc_library(
+ name = "type_matchers",
+ testonly = 1,
+ srcs = ["type_matchers.cc"],
+ hdrs = ["type_matchers.h"],
+ deps = [
+ ":internal_types",
+ "//:gtest_for_library_testonly",
+ "//ink_stroke_modeler:types",
+ ],
+)
+
+cc_library(
+ name = "utils",
+ hdrs = ["utils.h"],
+ deps = ["//ink_stroke_modeler:types"],
+)
+
+cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ deps = [
+ ":type_matchers",
+ ":utils",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "wobble_smoother",
+ srcs = ["wobble_smoother.cc"],
+ hdrs = ["wobble_smoother.h"],
+ deps = [
+ ":utils",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ ],
+)
+
+cc_test(
+ name = "wobble_smoother_test",
+ srcs = ["wobble_smoother_test.cc"],
+ deps = [
+ ":type_matchers",
+ ":wobble_smoother",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "position_modeler",
+ hdrs = ["position_modeler.h"],
+ deps = [
+ ":internal_types",
+ ":utils",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ ],
+)
+
+cc_test(
+ name = "position_modeler_test",
+ srcs = ["position_modeler_test.cc"],
+ deps = [
+ ":internal_types",
+ ":position_modeler",
+ ":type_matchers",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "stylus_state_modeler",
+ srcs = ["stylus_state_modeler.cc"],
+ hdrs = ["stylus_state_modeler.h"],
+ deps = [
+ ":internal_types",
+ ":utils",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ ],
+)
+
+cc_test(
+ name = "stylus_state_modeler_test",
+ srcs = ["stylus_state_modeler_test.cc"],
+ deps = [
+ ":internal_types",
+ ":stylus_state_modeler",
+ ":type_matchers",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "validation",
+ hdrs = ["validation.h"],
+ deps = [
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_test(
+ name = "validaiton_test",
+ srcs = ["validation_test.cc"],
+ deps = [
+ ":validation",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/ink_stroke_modeler/internal/CMakeLists.txt b/ink_stroke_modeler/internal/CMakeLists.txt
new file mode 100644
index 0000000..b1886df
--- /dev/null
+++ b/ink_stroke_modeler/internal/CMakeLists.txt
@@ -0,0 +1,156 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_subdirectory(prediction)
+
+ink_cc_library(
+ NAME
+ internal_types
+ HDRS
+ internal_types.h
+ DEPS
+ InkStrokeModeler::types
+)
+
+ink_cc_library(
+ NAME
+ type_matchers
+ SRCS
+ type_matchers.cc
+ HDRS
+ type_matchers.h
+ DEPS
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::types
+ GTest::gtest_main
+)
+
+ink_cc_library(
+ NAME
+ utils
+ HDRS
+ utils.h
+ DEPS
+ InkStrokeModeler::types
+)
+
+ink_cc_test(
+ NAME
+ utils_test
+ SRCS
+ utils_test.cc
+ DEPS
+ InkStrokeModeler::type_matchers
+ InkStrokeModeler::utils
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ wobble_smoother
+ SRCS
+ wobble_smoother.cc
+ HDRS
+ wobble_smoother.h
+ DEPS
+ InkStrokeModeler::utils
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+)
+
+ink_cc_test(
+ NAME
+ wobble_smoother_test
+ SRCS
+ wobble_smoother_test.cc
+ DEPS
+ InkStrokeModeler::type_matchers
+ InkStrokeModeler::wobble_smoother
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ position_modeler
+ HDRS
+ position_modeler.h
+ DEPS
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::utils
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+)
+
+ink_cc_test(
+ NAME
+ position_modeler_test
+ SRCS
+ position_modeler_test.cc
+ DEPS
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::position_modeler
+ InkStrokeModeler::type_matchers
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ stylus_state_modeler
+ SRCS
+ stylus_state_modeler.cc
+ HDRS
+ stylus_state_modeler.h
+ DEPS
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::utils
+ InkStrokeModeler::params
+)
+
+ink_cc_test(
+ NAME
+ stylus_state_modeler_test
+ SRCS
+ stylus_state_modeler_test.cc
+ DEPS
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::stylus_state_modeler
+ InkStrokeModeler::type_matchers
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ validation
+ HDRS
+ validation.h
+ DEPS
+ absl::status
+ absl::strings
+)
+
+ink_cc_test(
+ NAME
+ validation_test
+ SRCS
+ validation_test.cc
+ DEPS
+ InkStrokeModeler::validation
+ GTest::gtest_main
+)
diff --git a/ink_stroke_modeler/internal/internal_types.h b/ink_stroke_modeler/internal/internal_types.h
new file mode 100644
index 0000000..1c6b671
--- /dev/null
+++ b/ink_stroke_modeler/internal/internal_types.h
@@ -0,0 +1,73 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_
+#define INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_
+
+#include <ostream>
+
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This struct contains the position and velocity of the modeled pen tip at
+// the indicated time.
+struct TipState {
+ Vec2 position{0};
+ Vec2 velocity{0};
+ Time time{0};
+};
+
+std::ostream &operator<<(std::ostream &s, const TipState &tip_state);
+
+// This struct contains information about the state of the stylus. See the
+// corresponding fields on the Input struct for more info.
+struct StylusState {
+ float pressure = -1;
+ float tilt = -1;
+ float orientation = -1;
+};
+
+bool operator==(const StylusState &lhs, const StylusState &rhs);
+std::ostream &operator<<(std::ostream &s, const StylusState &stylus_state);
+
+////////////////////////////////////////////////////////////////////////////////
+// Inline function definitions
+////////////////////////////////////////////////////////////////////////////////
+
+inline bool operator==(const StylusState &lhs, const StylusState &rhs) {
+ return lhs.pressure == rhs.pressure && lhs.tilt == rhs.tilt &&
+ lhs.orientation == rhs.orientation;
+}
+
+inline std::ostream &operator<<(std::ostream &s, const TipState &tip_state) {
+ return s << "TipState: pos: " << tip_state.position
+ << ", velocity: " << tip_state.velocity
+ << ", time: " << tip_state.time << ">";
+}
+
+inline std::ostream &operator<<(std::ostream &s,
+ const StylusState &stylus_state) {
+ return s << "<Result: pressure: " << stylus_state.pressure
+ << ", tilt: " << stylus_state.tilt
+ << ", orientation: " << stylus_state.orientation << ">";
+}
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_
diff --git a/ink_stroke_modeler/internal/position_modeler.h b/ink_stroke_modeler/internal/position_modeler.h
new file mode 100644
index 0000000..c3276e8
--- /dev/null
+++ b/ink_stroke_modeler/internal/position_modeler.h
@@ -0,0 +1,138 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_
+#define INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/utils.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class models the movement of the pen tip based on the laws of motion.
+// The pen tip is represented as a mass, connected by a spring to a moving
+// anchor; as the anchor moves, it drags the pen tip along behind it.
+class PositionModeler {
+ public:
+ void Reset(const TipState& state, PositionModelerParams params) {
+ state_ = state;
+ params_ = params;
+ }
+
+ // Given the position of the anchor and the time, updates the model and
+ // returns the state of the pen tip.
+ TipState Update(Vec2 anchor_position, Time time) {
+ Duration delta_time = time - state_.time;
+ float float_delta = delta_time.Value();
+ auto acceleration =
+ (anchor_position - state_.position) / params_.spring_mass_constant -
+ params_.drag_constant * state_.velocity;
+ state_.velocity += float_delta * acceleration;
+ state_.position += float_delta * state_.velocity;
+ state_.time = time;
+
+ return state_;
+ }
+
+ const TipState& CurrentState() const { return state_; }
+ const PositionModelerParams& Params() const { return params_; }
+
+ // This helper function linearly interpolates between the between the start
+ // and end anchor position and time, updating the model at each step and
+ // storing the result in the given output iterator.
+ //
+ // NOTE: Because the expected use case is to repeatedly call this function on
+ // a sequence of anchor positions/times, the start position/time is not sent
+ // to the model. This prevents us from duplicating those inputs, but it does
+ // mean that the first input must be provided on its own, via either Reset()
+ // or Update(). This also means that the interpolation values are
+ // (1 ... n) / n, as opposed to (0 ... (n - 1)) / (n - 1).
+ //
+ // Template parameter OutputIt is expected to be an output iterator over
+ // TipState.
+ template <typename OutputIt>
+ void UpdateAlongLinearPath(Vec2 start_anchor_position, Time start_time,
+ Vec2 end_anchor_position, Time end_time,
+ int n_samples, OutputIt output) {
+ for (int i = 1; i <= n_samples; ++i) {
+ auto interp_value = static_cast<float>(i) / n_samples;
+ auto position =
+ Interp(start_anchor_position, end_anchor_position, interp_value);
+ auto time = Interp(start_time, end_time, interp_value);
+ *output++ = Update(position, time);
+ }
+ }
+
+ // This helper function models the end of the stroke, by repeatedly updating
+ // with the final anchor position. It attempts to stop at the closest point to
+ // the anchor, by checking if it has overshot, and retrying with successively
+ // smaller time steps.
+ //
+ // It halts when any of these three conditions is met:
+ // - It has taken more than max_iterations steps (including discarded steps)
+ // - The distance between the current state and the anchor is less than
+ // stop_distance
+ // - The distance between the previous state and the current state is less
+ // than stop_distance
+ //
+ // Template parameter OutputIt is expected to be an output iterator over
+ // TipState.
+ template <typename OutputIt>
+ void ModelEndOfStroke(Vec2 anchor_position, Duration delta_time,
+ int max_iterations, float stop_distance,
+ OutputIt output) {
+ for (int i = 0; i < max_iterations; ++i) {
+ // The call to Update modifies the state, so we store a copy of the
+ // previous state so we can retry with a smaller step if necessary.
+ const TipState previous_state = state_;
+ TipState candidate =
+ Update(anchor_position, previous_state.time + delta_time);
+ if (Distance(previous_state.position, candidate.position) <
+ stop_distance) {
+ // We're no longer making any significant progress, which means that
+ // we're about as close as we can get without looping around.
+ return;
+ }
+
+ float closest_t = NearestPointOnSegment(
+ previous_state.position, candidate.position, anchor_position);
+ if (closest_t < 1) {
+ // We're overshot the anchor, retry with a smaller step.
+ delta_time *= .5;
+ state_ = previous_state;
+ continue;
+ }
+ *output++ = candidate;
+
+ if (Distance(candidate.position, anchor_position) < stop_distance) {
+ // We're within tolerance of the anchor.
+ return;
+ }
+ }
+ }
+
+ private:
+ PositionModelerParams params_;
+ TipState state_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_
diff --git a/ink_stroke_modeler/internal/position_modeler_test.cc b/ink_stroke_modeler/internal/position_modeler_test.cc
new file mode 100644
index 0000000..7898ea1
--- /dev/null
+++ b/ink_stroke_modeler/internal/position_modeler_test.cc
@@ -0,0 +1,317 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/position_modeler.h"
+
+#include <cmath>
+#include <iterator>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::ElementsAre;
+
+const Duration kDefaultTimeStep(1. / 180);
+constexpr float kTol = .00005;
+
+// The expected position values are taken directly from results the old
+// TipDynamics class. The expected velocity values are from the same source, but
+// a multiplier of 300 (i.e. dt / (1 - drag)) had to be applied to account for
+// the fact that PositionModeler uses the time step correctly.
+
+TEST(PositionModelerTest, StraightLine) {
+ PositionModeler modeler;
+ Time current_time(0);
+ modeler.Reset({{0, 0}, {0, 0}, current_time}, PositionModelerParams());
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update({1, 0}, current_time),
+ TipStateNear({{.0909, 0}, {16.3636, 0}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update({2, 0}, current_time),
+ TipStateNear({{.319, 0}, {41.0579, 0}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update({3, 0}, current_time),
+ TipStateNear({{.6996, 0}, {68.5055, 0}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update({4, 0}, current_time),
+ TipStateNear({{1.228, 0}, {95.1099, 0}, current_time}, kTol));
+}
+
+TEST(PositionModelerTest, ZigZag) {
+ PositionModeler modeler;
+ Time current_time(3);
+ modeler.Reset({{-1, -1}, {0, 0}, current_time}, PositionModelerParams());
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update({-.5, -1}, current_time),
+ TipStateNear({{-.9545, -1}, {8.1818, 0}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({-.5, -.5}, current_time),
+ TipStateNear({{-.886, -.9545}, {12.3471, 8.1818}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({0, -.5}, current_time),
+ TipStateNear({{-.7643, -.886}, {21.9056, 12.3471}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({0, 0}, current_time),
+ TipStateNear({{-.6218, -.7643}, {25.6493, 21.9056}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({.5, 0}, current_time),
+ TipStateNear({{-.4343, -.6218}, {33.7456, 25.6493}, current_time}, kTol));
+}
+
+TEST(PositionModelerTest, SharpTurn) {
+ PositionModeler modeler;
+ Time current_time(1.6);
+ modeler.Reset({{0, 0}, {0, 0}, current_time}, PositionModelerParams());
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({.25, .25}, current_time),
+ TipStateNear({{.0227, .0227}, {4.0909, 4.0909}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({.5, .5}, current_time),
+ TipStateNear({{.0798, .0798}, {10.2645, 10.2645}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({.75, .75}, current_time),
+ TipStateNear({{.1749, .1749}, {17.1264, 17.1264}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({1, 1}, current_time),
+ TipStateNear({{.307, .307}, {23.7775, 23.7775}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({1.25, .75}, current_time),
+ TipStateNear({{.472, .4265}, {29.6975, 21.5157}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({1.5, .5}, current_time),
+ TipStateNear({{.6644, .5049}, {34.6406, 14.1117}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({1.75, .25}, current_time),
+ TipStateNear({{.8786, .5288}, {38.5482, 4.2955}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update({2, 0}, current_time),
+ TipStateNear({{1.109, .495}, {41.4794, -6.0756}, current_time}, kTol));
+}
+
+TEST(PositionModelerTest, SmoothTurn) {
+ auto point_on_circle = [](float theta) {
+ return Vec2{std::cos(theta), std::sin(theta)};
+ };
+
+ PositionModeler modeler;
+ Time current_time(10.1);
+ modeler.Reset({point_on_circle(0), {0, 0}, current_time},
+ PositionModelerParams());
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update(point_on_circle(M_PI * .125), current_time),
+ TipStateNear({{.9931, .0348}, {-1.2456, 6.2621}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update(point_on_circle(M_PI * .25), current_time),
+ TipStateNear({{0.9629, 0.1168}, {-5.4269, 14.7588}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .375), current_time),
+ TipStateNear(
+ {{0.8921, 0.2394}, {-12.7511, 22.0623}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .5), current_time),
+ TipStateNear(
+ {{0.7685, 0.3820}, {-22.2485, 25.6844}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .625), current_time),
+ TipStateNear(
+ {{0.5897, 0.5169}, {-32.1865, 24.2771}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .75), current_time),
+ TipStateNear(
+ {{0.3645, 0.6151}, {-40.5319, 17.6785}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update(point_on_circle(M_PI * .875), current_time),
+ TipStateNear({{0.1123, 0.6529}, {-45.4017, 6.8034}, current_time}, kTol));
+
+ current_time += kDefaultTimeStep;
+ EXPECT_THAT(
+ modeler.Update(point_on_circle(M_PI), current_time),
+ TipStateNear({{-0.1402, 0.6162}, {-45.4417, -6.6022}, current_time},
+ kTol));
+}
+
+TEST(PositionModelerTest, UpdateAlongLinearPath) {
+ PositionModeler modeler;
+ modeler.Reset({{5, 10}, {0, 0}, Time{3}}, PositionModelerParams());
+
+ std::vector<TipState> result;
+ modeler.UpdateAlongLinearPath({5, 10}, Time{3}, {15, 10}, Time{3.05}, 5,
+ std::back_inserter(result));
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ TipStateNear({{5.5891, 10}, {58.9091, 0}, Time{3.01}}, kTol),
+ TipStateNear({{6.7587, 10}, {116.9613, 0}, Time{3.02}}, kTol),
+ TipStateNear({{8.3355, 10}, {157.6746, 0}, Time{3.03}}, kTol),
+ TipStateNear({{10.1509, 10}, {181.5411, 0}, Time{3.04}}, kTol),
+ TipStateNear({{12.0875, 10}, {193.6607, 0}, Time{3.05}}, kTol)));
+
+ result.clear();
+ modeler.UpdateAlongLinearPath({15, 10}, Time{3.05}, {15, 16}, Time{3.08}, 3,
+ std::back_inserter(result));
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ TipStateNear({{13.4876, 10.5891}, {140.0123, 58.9091}, Time{3.06}},
+ kTol),
+ TipStateNear({{14.3251, 11.7587}, {83.7508, 116.9613}, Time{3.07}},
+ kTol),
+ TipStateNear({{14.7584, 13.3355}, {43.3291, 157.6746}, Time{3.08}},
+ kTol)));
+}
+
+TEST(PositionModelerTest, ModelEndOfStrokeStationary) {
+ PositionModeler modeler;
+ modeler.Reset({{4, -2}, {0, 0}, Time{0}}, PositionModelerParams());
+
+ std::vector<TipState> result;
+ modeler.ModelEndOfStroke({3, -1}, Duration(1. / 180), 20, .01,
+ std::back_inserter(result));
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ TipStateNear({{3.9091, -1.9091}, {-16.3636, 16.3636}, Time{0.0056}},
+ kTol),
+ TipStateNear({{3.7719, -1.7719}, {-24.6942, 24.6942}, Time{0.0111}},
+ kTol),
+ TipStateNear({{3.6194, -1.6194}, {-27.4476, 27.4476}, Time{0.0167}},
+ kTol),
+ TipStateNear({{3.4716, -1.4716}, {-26.6045, 26.6044}, Time{0.0222}},
+ kTol),
+ TipStateNear({{3.3401, -1.3401}, {-23.6799, 23.6799}, Time{0.0278}},
+ kTol),
+ TipStateNear({{3.2302, -1.2302}, {-19.7725, 19.7725}, Time{0.0333}},
+ kTol),
+ TipStateNear({{3.1434, -1.1434}, {-15.6306, 15.6306}, Time{0.0389}},
+ kTol),
+ TipStateNear({{3.0782, -1.0782}, {-11.7244, 11.7244}, Time{0.0444}},
+ kTol),
+ TipStateNear({{3.0320, -1.0320}, {-8.3149, 8.3149}, Time{0.0500}},
+ kTol),
+ TipStateNear({{3.0014, -1.0014}, {-5.5133, 5.5133}, Time{0.0556}},
+ kTol)));
+}
+
+TEST(PositionModelerTest, ModelEndOfStrokeInMotion) {
+ PositionModeler modeler;
+ modeler.Reset({{-1, 2}, {40, 10}, Time{1}}, PositionModelerParams());
+
+ std::vector<TipState> result;
+ modeler.ModelEndOfStroke({7, 2}, Duration(1. / 120), 20, .01,
+ std::back_inserter(result));
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ TipStateNear({{0.7697, 2.0333}, {212.3636, 4.0000}, Time{1.0083}},
+ kTol),
+ TipStateNear({{2.7520, 2.0398}, {237.8711, 0.7818}, Time{1.0167}},
+ kTol),
+ TipStateNear({{4.4138, 2.0343}, {199.4186, -0.6654}, Time{1.0250}},
+ kTol),
+ TipStateNear({{5.6075, 2.0251}, {143.2474, -1.1081}, Time{1.0333}},
+ kTol),
+ TipStateNear({{6.3698, 2.0162}, {91.4784, -1.0586}, Time{1.0417}},
+ kTol),
+ TipStateNear({{6.8037, 2.0094}, {52.0592, -0.8222}, Time{1.0500}},
+ kTol),
+ TipStateNear({{6.9655, 2.0065}, {38.8512, -0.6909}, Time{1.0542}},
+ kTol),
+ TipStateNear({{6.9850, 2.0062}, {37.4471, -0.6750}, Time{1.0547}},
+ kTol)));
+}
+
+TEST(PositionModelerTest, ModelEndOfStrokeMaxIterationsReached) {
+ PositionModeler modeler;
+ modeler.Reset({{8, -3}, {-100, -150}, Time{1}}, PositionModelerParams());
+
+ std::vector<TipState> result;
+ modeler.ModelEndOfStroke({-9, -10}, Duration(.0001), 10, .001,
+ std::back_inserter(result));
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ TipStateNear(
+ {{7.9896, -3.0151}, {-104.2873, -150.9818}, Time{1.0001}}, kTol),
+ TipStateNear(
+ {{7.9787, -3.0303}, {-108.5406, -151.9521}, Time{1.0002}}, kTol),
+ TipStateNear(
+ {{7.9674, -3.0456}, {-112.7601, -152.9110}, Time{1.0003}}, kTol),
+ TipStateNear(
+ {{7.9557, -3.0610}, {-116.9459, -153.8584}, Time{1.0004}}, kTol),
+ TipStateNear(
+ {{7.9436, -3.0764}, {-121.0982, -154.7945}, Time{1.0005}}, kTol),
+ TipStateNear(
+ {{7.9311, -3.0920}, {-125.2169, -155.7193}, Time{1.0006}}, kTol),
+ TipStateNear(
+ {{7.9182, -3.1077}, {-129.3023, -156.6328}, Time{1.0007}}, kTol),
+ TipStateNear(
+ {{7.9048, -3.1234}, {-133.3545, -157.5351}, Time{1.0008}}, kTol),
+ TipStateNear(
+ {{7.8911, -3.1393}, {-137.3736, -158.4263}, Time{1.0009}}, kTol),
+ TipStateNear(
+ {{7.8770, -3.1552}, {-141.3597, -159.3065}, Time{1.0010}},
+ kTol)));
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/BUILD.bazel b/ink_stroke_modeler/internal/prediction/BUILD.bazel
new file mode 100644
index 0000000..2d94289
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/BUILD.bazel
@@ -0,0 +1,83 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//ink_stroke_modeler:__subpackages__"],
+)
+
+licenses(["notice"])
+
+cc_library(
+ name = "input_predictor",
+ hdrs = ["input_predictor.h"],
+ deps = [
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "//ink_stroke_modeler/internal:internal_types",
+ ],
+)
+
+cc_library(
+ name = "kalman_predictor",
+ srcs = ["kalman_predictor.cc"],
+ hdrs = ["kalman_predictor.h"],
+ deps = [
+ ":input_predictor",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "//ink_stroke_modeler/internal:internal_types",
+ "//ink_stroke_modeler/internal:utils",
+ "//ink_stroke_modeler/internal/prediction/kalman_filter",
+ ],
+)
+
+cc_test(
+ name = "kalman_predictor_test",
+ srcs = ["kalman_predictor_test.cc"],
+ deps = [
+ ":input_predictor",
+ ":kalman_predictor",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "//ink_stroke_modeler/internal:type_matchers",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "stroke_end_predictor",
+ srcs = ["stroke_end_predictor.cc"],
+ hdrs = ["stroke_end_predictor.h"],
+ deps = [
+ ":input_predictor",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler:types",
+ "//ink_stroke_modeler/internal:internal_types",
+ "//ink_stroke_modeler/internal:position_modeler",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_test(
+ name = "stroke_end_predictor_test",
+ srcs = ["stroke_end_predictor_test.cc"],
+ deps = [
+ ":input_predictor",
+ ":stroke_end_predictor",
+ "//ink_stroke_modeler:params",
+ "//ink_stroke_modeler/internal:type_matchers",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/ink_stroke_modeler/internal/prediction/CMakeLists.txt b/ink_stroke_modeler/internal/prediction/CMakeLists.txt
new file mode 100644
index 0000000..e54dbbe
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/CMakeLists.txt
@@ -0,0 +1,86 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+add_subdirectory(kalman_filter)
+
+ink_cc_library(
+ NAME
+ input_predictor
+ HDRS
+ input_predictor.h
+ DEPS
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ InkStrokeModeler::internal_types
+)
+
+ink_cc_library(
+ NAME
+ kalman_predictor
+ SRCS
+ kalman_predictor.cc
+ HDRS
+ kalman_predictor.h
+ DEPS
+ InkStrokeModeler::input_predictor
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::utils
+ InkStrokeModeler::kalman_filter
+)
+
+ink_cc_test(
+ NAME
+ kalman_predictor_test
+ SRCS
+ kalman_predictor_test.cc
+ DEPS
+ InkStrokeModeler::input_predictor
+ InkStrokeModeler::kalman_predictor
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ InkStrokeModeler::type_matchers
+ absl::optional
+ GTest::gmock_main
+)
+
+ink_cc_library(
+ NAME
+ stroke_end_predictor
+ SRCS
+ stroke_end_predictor.cc
+ HDRS
+ stroke_end_predictor.h
+ DEPS
+ InkStrokeModeler::input_predictor
+ InkStrokeModeler::params
+ InkStrokeModeler::types
+ InkStrokeModeler::internal_types
+ InkStrokeModeler::position_modeler
+ absl::optional
+)
+
+ink_cc_test(
+ NAME
+ stroke_end_predictor_test
+ SRCS
+ stroke_end_predictor_test.cc
+ DEPS
+ InkStrokeModeler::input_predictor
+ InkStrokeModeler::stroke_end_predictor
+ InkStrokeModeler::params
+ InkStrokeModeler::type_matchers
+ GTest::gmock_main
+)
diff --git a/ink_stroke_modeler/internal/prediction/input_predictor.h b/ink_stroke_modeler/internal/prediction/input_predictor.h
new file mode 100644
index 0000000..4147953
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/input_predictor.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_
+
+#include <vector>
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// Interface for input predictors that generate points based on past input.
+class InputPredictor {
+ public:
+ virtual ~InputPredictor() {}
+
+ // Resets the predictor's internal model.
+ virtual void Reset() = 0;
+
+ // Updates the predictor's internal model with the given input.
+ virtual void Update(Vec2 position, Time time) = 0;
+
+ // Constructs a prediction based from the given state, based on the
+ // predictor's internal model. The result may be empty if the predictor has
+ // not yet accumulated enough data, via Update(), to construct a reasonable
+ // prediction.
+ //
+ // Subclasses are expected to maintain the following invariants:
+ // - The given state must not appear in the prediction.
+ // - The time delta between each state in the prediction, and between the
+ // given state and the first predicted state, must conform to
+ // SamplingParams::min_output_rate.
+ virtual std::vector<TipState> ConstructPrediction(
+ const TipState &last_state) const = 0;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel b/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel
new file mode 100644
index 0000000..e487c16
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel
@@ -0,0 +1,58 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+ default_visibility = ["//ink_stroke_modeler:__subpackages__"],
+)
+
+licenses(["notice"])
+
+cc_library(
+ name = "matrix",
+ hdrs = ["matrix.h"],
+)
+
+cc_test(
+ name = "matrix_test",
+ srcs = ["matrix_test.cc"],
+ deps = [
+ ":matrix",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
+ name = "kalman_filter",
+ srcs = [
+ "axis_predictor.cc",
+ "kalman_filter.cc",
+ ],
+ hdrs = [
+ "axis_predictor.h",
+ "kalman_filter.h",
+ ],
+ deps = [
+ ":matrix",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+cc_test(
+ name = "axis_predictor_test",
+ srcs = ["axis_predictor_test.cc"],
+ deps = [
+ ":kalman_filter",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt b/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt
new file mode 100644
index 0000000..b3c8d88
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt
@@ -0,0 +1,54 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+ink_cc_library(
+ NAME
+ matrix
+ HDRS
+ matrix.h
+)
+
+ink_cc_test(
+ NAME
+ matrix_test
+ SRCS
+ matrix_test.cc
+ DEPS
+ InkStrokeModeler::matrix
+ GTest::gtest_main
+)
+
+ink_cc_library(
+ NAME
+ kalman_filter
+ SRCS
+ axis_predictor.cc
+ kalman_filter.cc
+ HDRS
+ axis_predictor.h
+ kalman_filter.h
+ DEPS
+ InkStrokeModeler::matrix
+ absl::memory
+)
+
+ink_cc_test(
+ NAME
+ axis_predictor_test
+ SRCS
+ axis_predictor_test.cc
+ DEPS
+ InkStrokeModeler::kalman_filter
+ GTest::gtest_main
+)
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc
new file mode 100644
index 0000000..fab9363
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc
@@ -0,0 +1,98 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h"
+
+#include "absl/memory/memory.h"
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+constexpr int kPositionIndex = 0;
+constexpr int kVelocityIndex = 1;
+constexpr int kAccelerationIndex = 2;
+constexpr int kJerkIndex = 3;
+
+constexpr double kDt = 1.0;
+constexpr double kDtSquared = kDt * kDt;
+constexpr double kDtCubed = kDt * kDt * kDt;
+} // namespace
+
+AxisPredictor::AxisPredictor(double process_noise, double measurement_noise,
+ int min_stable_iteration) {
+ // State translation matrix is basic physics.
+ // new_pos = pre_pos + v * dt + 1/2 * a * dt^2 + 1/6 * J * dt^3.
+ // new_v = v + a * dt + 1/2 * J * dt^2.
+ // new_a = a + J * dt.
+ // new_j = J.
+ Matrix4 state_transition(1, kDt, .5 * kDtSquared, 1.0 / 6 * kDtCubed, //
+ 0, 1, kDt, .5 * kDtSquared, //
+ 0, 0, 1, kDt, //
+ 0, 0, 0, 1);
+ // We model the system noise as noisy force on the pen.
+ // The following matrix describes the impact of that noise on each state.
+ Vec4 process_noise_vector(1.0 / 6 * kDtCubed, 0.5 * kDtSquared, kDt, 1.0);
+ Matrix4 process_noise_covariance =
+ OuterProduct(process_noise_vector, process_noise_vector) * process_noise;
+
+ // Sensor only detects location. Thus measurement only impact the position.
+ Vec4 measurement_vector(1.0, 0.0, 0.0, 0.0);
+
+ kalman_filter_ = absl::make_unique<KalmanFilter>(
+ state_transition, process_noise_covariance, measurement_vector,
+ measurement_noise, min_stable_iteration);
+}
+
+bool AxisPredictor::Stable() const {
+ return kalman_filter_ && kalman_filter_->Stable();
+}
+
+void AxisPredictor::Reset() {
+ if (kalman_filter_) kalman_filter_->Reset();
+}
+
+void AxisPredictor::Update(double observation) {
+ if (kalman_filter_) kalman_filter_->Update(observation);
+}
+
+int AxisPredictor::NumIterations() const {
+ return kalman_filter_ ? kalman_filter_->NumIterations() : 0;
+}
+
+double AxisPredictor::GetPosition() const {
+ if (kalman_filter_)
+ return kalman_filter_->GetStateEstimation()[kPositionIndex];
+ else
+ return 0.0;
+}
+
+double AxisPredictor::GetVelocity() const {
+ if (kalman_filter_)
+ return kalman_filter_->GetStateEstimation()[kVelocityIndex];
+ else
+ return 0.0;
+}
+
+double AxisPredictor::GetAcceleration() const {
+ return kalman_filter_->GetStateEstimation()[kAccelerationIndex];
+}
+
+double AxisPredictor::GetJerk() const {
+ return kalman_filter_->GetStateEstimation()[kJerkIndex];
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h
new file mode 100644
index 0000000..3ab02ee
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_
+
+#include <memory>
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h"
+
+namespace ink {
+namespace stroke_model {
+
+// Class to predict on axis.
+//
+// This predictor use one instance of Kalman filter to predict one dimension of
+// stylus movement.
+class AxisPredictor {
+ public:
+ AxisPredictor(double process_noise, double measurement_noise,
+ int min_stable_iteration);
+
+ // Return true if the underlying Kalman filter is stable.
+ bool Stable() const;
+
+ // Reset the underlying Kalman filter.
+ void Reset();
+
+ // Update the predictor with a new observation.
+ void Update(double observation);
+
+ // Returns the number of times Update() has been called since the last time
+ // the AxisPredictor was reset.
+ int NumIterations() const;
+
+ // Get the predicted values from the underlying Kalman filter.
+ double GetPosition() const;
+ double GetVelocity() const;
+ double GetAcceleration() const;
+ double GetJerk() const;
+
+ private:
+ std::unique_ptr<KalmanFilter> kalman_filter_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc
new file mode 100644
index 0000000..40ba9f4
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc
@@ -0,0 +1,100 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h"
+
+#include <vector>
+
+#include "gtest/gtest.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+constexpr int kStableIterNum = 4;
+
+constexpr double kProcessNoise = 0.01;
+constexpr double kMeasurementNoise = 1.0;
+
+} // namespace
+
+struct DataSet {
+ double initial_observation;
+ std::vector<double> observation;
+ std::vector<double> position;
+ std::vector<double> velocity;
+ std::vector<double> acceleration;
+ std::vector<double> jerk;
+};
+
+void ValidateAxisPredictor(AxisPredictor* predictor, const DataSet& data) {
+ predictor->Reset();
+ predictor->Update(data.initial_observation);
+ for (decltype(data.observation.size()) i = 0; i < data.observation.size();
+ i++) {
+ predictor->Update(data.observation[i]);
+ EXPECT_NEAR(data.position[i], predictor->GetPosition(), 0.0001);
+ EXPECT_NEAR(data.velocity[i], predictor->GetVelocity(), 0.0001);
+ EXPECT_NEAR(data.acceleration[i], predictor->GetAcceleration(), 0.0001);
+ EXPECT_NEAR(data.jerk[i], predictor->GetJerk(), 0.0001);
+ }
+}
+
+// Test that the predictor will stable.
+TEST(AxisPredictorTest, ShouldStable) {
+ AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum);
+ for (int i = 0; i < kStableIterNum; i++) {
+ EXPECT_FALSE(predictor.Stable());
+ predictor.Update(1);
+ }
+ EXPECT_TRUE(predictor.Stable());
+}
+
+// Test the kalman filter behavior. The data set is generated by a "known to
+// work" kalman filter.
+TEST(AxisPredictorTest, PredictedValue) {
+ AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum);
+ DataSet data;
+ data.initial_observation = 0;
+ data.observation = {1, 2, 3, 4, 5, 6};
+ data.position = {0.6949411066858742, 1.8880162111305765, 3.0596776689233476,
+ 4.080666568886563, 5.039574058758894, 5.990101744132957};
+ data.velocity = {0.48326413015846115, 1.349212968908908, 1.5150757723942188,
+ 1.2449353797925855, 0.9823147273054352, 0.831418084705206};
+ data.acceleration = {0.20388102703160751, 0.6602537865634062,
+ 0.46392675203046707, 0.0691864035645362,
+ -0.1571001901104591, -0.2303438651979314};
+ data.jerk = {0.051351580374544535, 0.17805019978769315,
+ 0.06592110190532013, -0.06063794909774803,
+ -0.10198612906906362, -0.09541445938944032};
+
+ ValidateAxisPredictor(&predictor, data);
+
+ data.initial_observation = 0;
+ data.observation = {1, 2, 4, 8, 16, 32};
+ data.position = {0.6949411066858742, 1.8880162111305765, 3.9597202826804603,
+ 7.9052737853848285, 15.720340533540115, 31.24662046486774};
+ data.velocity = {0.48326413015846115, 1.349212968908908, 2.492271225870179,
+ 4.610844489557212, 8.828231877380588, 16.987494416071463};
+ data.acceleration = {0.20388102703160751, 0.6602537865634062,
+ 1.090991623810185, 1.885675547541351,
+ 3.4586206593783526, 6.34082285106952};
+ data.jerk = {0.051351580374544535, 0.17805019978769315, 0.25373225050247916,
+ 0.4023497012294069, 0.6945464157568688, 1.1947316519015612};
+
+ ValidateAxisPredictor(&predictor, data);
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc
new file mode 100644
index 0000000..45238f7
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc
@@ -0,0 +1,79 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h"
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h"
+
+namespace ink {
+namespace stroke_model {
+
+KalmanFilter::KalmanFilter(const Matrix4& state_transition,
+ const Matrix4& process_noise_covariance,
+ const Vec4& measurement_vector,
+ double measurement_noise_variance,
+ int min_stable_iteration)
+ : state_transition_matrix_(state_transition),
+ process_noise_covariance_matrix_(process_noise_covariance),
+ measurement_vector_(measurement_vector),
+ measurement_noise_variance_(measurement_noise_variance),
+ min_stable_iteration_(min_stable_iteration),
+ iter_num_(0) {}
+
+void KalmanFilter::Predict() {
+ // X = F * X
+ state_estimation_ = state_transition_matrix_ * state_estimation_;
+ // P = F * P * F' + Q
+ error_covariance_matrix_ = state_transition_matrix_ *
+ error_covariance_matrix_ *
+ state_transition_matrix_.Transpose() +
+ process_noise_covariance_matrix_;
+}
+
+void KalmanFilter::Update(double observation) {
+ if (iter_num_++ == 0) {
+ // We only update the state estimation in the first iteration.
+ state_estimation_[0] = observation;
+ return;
+ }
+ Predict();
+ // Y = z - H * X
+ double y = observation - DotProduct(measurement_vector_, state_estimation_);
+ // S = H * P * H' + R
+ double S = DotProduct(measurement_vector_ * error_covariance_matrix_,
+ measurement_vector_) +
+ measurement_noise_variance_;
+ // K = P * H' * inv(S)
+ Vec4 kalman_gain = measurement_vector_ * error_covariance_matrix_ / S;
+
+ // X = X + K * Y
+ state_estimation_ = state_estimation_ + kalman_gain * y;
+
+ // I_HK = eye(P) - K * H
+ Matrix4 I_KH = Matrix4() - OuterProduct(kalman_gain, measurement_vector_);
+
+ // P = I_KH * P * I_KH' + K * R * K'
+ error_covariance_matrix_ =
+ I_KH * error_covariance_matrix_ * I_KH.Transpose() +
+ OuterProduct(kalman_gain, kalman_gain) * measurement_noise_variance_;
+}
+
+void KalmanFilter::Reset() {
+ state_estimation_ = {0, 0, 0, 0};
+ error_covariance_matrix_ = Matrix4(); // identity
+ iter_num_ = 0;
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h
new file mode 100644
index 0000000..65fb55f
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h"
+
+namespace ink {
+namespace stroke_model {
+
+// Generates a state estimation based upon observations which can then be used
+// to compute predicted values.
+class KalmanFilter {
+ public:
+ KalmanFilter(const Matrix4& state_transition,
+ const Matrix4& process_noise_covariance,
+ const Vec4& measurement_vector,
+ double measurement_noise_variance, int min_stable_iteration);
+
+ // Get the estimation of current state.
+ const Vec4& GetStateEstimation() const { return state_estimation_; }
+
+ // Will return true only if the Kalman filter has seen enough data and is
+ // considered as stable.
+ bool Stable() const { return iter_num_ >= min_stable_iteration_; }
+
+ // Update the observation of the system.
+ void Update(double observation);
+
+ void Reset();
+
+ // Returns the number of times Update() has been called since the last time
+ // the KalmanFilter was reset.
+ int NumIterations() const { return iter_num_; }
+
+ private:
+ void Predict();
+
+ // Estimate of the latent state
+ // Symbol: X
+ // Dimension: state_vector_dim_
+ Vec4 state_estimation_;
+
+ // The covariance of the difference between prior predicted latent
+ // state and posterior estimated latent state (the so-called "innovation".
+ // Symbol: P
+ Matrix4 error_covariance_matrix_;
+
+ // For position, state transition matrix is derived from basic physics:
+ // new_x = x + v * dt + 1/2 * a * dt^2 + 1/6 * jerk * dt^3
+ // new_v = v + a * dt + 1/2 * jerk * dt^2
+ // ...
+ // Matrix that transmit current state to next state
+ // Symbol: F
+ Matrix4 state_transition_matrix_;
+
+ // Process_noise_covariance_matrix_ is a time-varying parameter that will be
+ // estimated as part of the Kalman filter process.
+ // Symbol: Q
+ Matrix4 process_noise_covariance_matrix_;
+
+ // Vector to transform estimate to measurement.
+ // Symbol: H
+ const Vec4 measurement_vector_{0, 0, 0, 0};
+
+ // measurement_noise_ is a time-varying parameter that will be estimated as
+ // part of the Kalman filter process.
+ // Symbol: R
+ double measurement_noise_variance_;
+
+ // The first iteration at which the Kalman filter is considered stable enough
+ // to make a good estimate of the state.
+ int min_stable_iteration_;
+
+ // Tracks the number of update iterations that have occurred.
+ int iter_num_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h
new file mode 100644
index 0000000..2c12bae
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h
@@ -0,0 +1,270 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_
+
+#include <array>
+#include <cstddef>
+#include <ostream>
+
+namespace ink {
+namespace stroke_model {
+
+// These classes provide the matrix arithmetic needed for the Kalman filter.
+//
+// This is intentionally limited to just the required functions, so some
+// common matrix arithmetic operations aren't present (e.g. inversion), and
+// some operators' symmetric counterparts are missing (e.g. Vec4 * double is
+// defined, but double * Vec4 is not).
+
+// A double-precision vector in 4-dimensional space.
+class Vec4 {
+ public:
+ constexpr Vec4() : Vec4(0, 0, 0, 0) {}
+ constexpr Vec4(double x, double y, double z, double w)
+ : array_({x, y, z, w}) {}
+
+ double& operator[](size_t i) { return array_[i]; }
+ double operator[](size_t i) const { return array_[i]; }
+
+ private:
+ std::array<double, 4> array_;
+};
+
+// A double-precision 4x4 matrix.
+class Matrix4 {
+ public:
+ // Constructs an identity matrix.
+ constexpr Matrix4()
+ : Matrix4(1, 0, 0, 0, //
+ 0, 1, 0, 0, //
+ 0, 0, 1, 0, //
+ 0, 0, 0, 1) {}
+
+ // Constructs a matrix with the given values, in row-major order.
+ constexpr Matrix4(double m00, double m01, double m02, double m03, //
+ double m10, double m11, double m12, double m13, //
+ double m20, double m21, double m22, double m23, //
+ double m30, double m31, double m32, double m33)
+ : array_{{{m00, m01, m02, m03},
+ {m10, m11, m12, m13},
+ {m20, m21, m22, m23},
+ {m30, m31, m32, m33}}} {}
+
+ // Constructs a matrix s.t. all values are zero.
+ static constexpr Matrix4 Zero() {
+ return {0, 0, 0, 0, //
+ 0, 0, 0, 0, //
+ 0, 0, 0, 0, //
+ 0, 0, 0, 0};
+ }
+
+ // Returns a copy of the matrix with its rows and columns swapped, i.e.
+ // original.At(i, j) == transposed.At(j, i).
+ Matrix4 Transpose() const {
+ Matrix4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result.At(i, j) = At(j, i);
+ }
+ }
+ return result;
+ }
+
+ double& At(size_t row, size_t column) { return array_[row][column]; }
+ double At(size_t row, size_t column) const { return array_[row][column]; }
+
+ private:
+ std::array<Vec4, 4> array_;
+};
+
+// Computes the dot product of two vectors. Given vectors a and b, this is
+// equivalent to the matrix product:
+// [a₀ a₁ a₂ a₃]⎡b₀⎤
+// ⎢b₁⎥
+// ⎢b₂⎥
+// ⎣b₃⎦
+double DotProduct(const Vec4& lhs, const Vec4& rhs);
+
+// Computes the outer product of two vectors. Given vectors a and b, this is
+// equivalent to the matrix product:
+// ⎡a₀⎤[b₀ b₁ b₂ b₃]
+// ⎢a₁⎥
+// ⎢a₂⎥
+// ⎣a₃⎦
+Matrix4 OuterProduct(const Vec4& lhs, const Vec4& rhs);
+
+bool operator==(const Vec4& lhs, const Vec4& rhs);
+bool operator!=(const Vec4& lhs, const Vec4& rhs);
+Vec4 operator+(const Vec4& lhs, const Vec4& rhs);
+Vec4 operator*(const Vec4& v, double k);
+Vec4 operator/(const Vec4& v, double k);
+
+bool operator==(const Matrix4& lhs, const Matrix4& rhs);
+bool operator!=(const Matrix4& lhs, const Matrix4& rhs);
+Matrix4 operator*(const Matrix4& lhs, const Matrix4& rhs);
+Matrix4 operator+(const Matrix4& lhs, const Matrix4& rhs);
+Matrix4 operator-(const Matrix4& lhs, const Matrix4& rhs);
+
+Matrix4 operator*(const Matrix4& m, double k);
+Vec4 operator*(const Matrix4& m, const Vec4& v);
+Vec4 operator*(const Vec4& v, const Matrix4& m);
+
+std::ostream& operator<<(std::ostream& stream, const Vec4& v);
+std::ostream& operator<<(std::ostream& stream, const Matrix4& m);
+
+// ============================================================================
+// Inline function implementations
+// ============================================================================
+
+inline double DotProduct(const Vec4& lhs, const Vec4& rhs) {
+ double result = 0;
+ for (int i = 0; i < 4; ++i) result += lhs[i] * rhs[i];
+ return result;
+}
+
+inline Matrix4 OuterProduct(const Vec4& lhs, const Vec4& rhs) {
+ Matrix4 result = Matrix4::Zero();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result.At(i, j) = lhs[i] * rhs[j];
+ }
+ }
+ return result;
+}
+
+inline bool operator==(const Vec4& lhs, const Vec4& rhs) {
+ for (int i = 0; i < 4; ++i) {
+ if (lhs[i] != rhs[i]) return false;
+ }
+ return true;
+}
+inline bool operator!=(const Vec4& lhs, const Vec4& rhs) {
+ return !(lhs == rhs);
+}
+
+inline Vec4 operator+(const Vec4& lhs, const Vec4& rhs) {
+ Vec4 result;
+ for (int i = 0; i < 4; ++i) result[i] = lhs[i] + rhs[i];
+ return result;
+}
+
+inline Vec4 operator*(const Vec4& v, double k) {
+ Vec4 result;
+ for (int i = 0; i < 4; ++i) result[i] = v[i] * k;
+ return result;
+}
+
+inline Vec4 operator/(const Vec4& v, double k) {
+ Vec4 result;
+ for (int i = 0; i < 4; ++i) result[i] = v[i] / k;
+ return result;
+}
+
+inline bool operator==(const Matrix4& lhs, const Matrix4& rhs) {
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (lhs.At(i, j) != rhs.At(i, j)) return false;
+ }
+ }
+ return true;
+}
+
+inline bool operator!=(const Matrix4& lhs, const Matrix4& rhs) {
+ return !(lhs == rhs);
+}
+
+inline Matrix4 operator*(const Matrix4& lhs, const Matrix4& rhs) {
+ Matrix4 result = Matrix4::Zero();
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 4; ++k) {
+ result.At(i, j) += lhs.At(i, k) * rhs.At(k, j);
+ }
+ }
+ }
+ return result;
+}
+
+inline Matrix4 operator+(const Matrix4& lhs, const Matrix4& rhs) {
+ Matrix4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result.At(i, j) = lhs.At(i, j) + rhs.At(i, j);
+ }
+ }
+ return result;
+}
+
+inline Matrix4 operator-(const Matrix4& lhs, const Matrix4& rhs) {
+ Matrix4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result.At(i, j) = lhs.At(i, j) - rhs.At(i, j);
+ }
+ }
+ return result;
+}
+
+inline Matrix4 operator*(const Matrix4& m, double k) {
+ Matrix4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result.At(i, j) = m.At(i, j) * k;
+ }
+ }
+ return result;
+}
+
+inline Vec4 operator*(const Matrix4& m, const Vec4& v) {
+ Vec4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result[i] += v[j] * m.At(i, j);
+ }
+ }
+ return result;
+}
+
+inline Vec4 operator*(const Vec4& v, const Matrix4& m) {
+ Vec4 result;
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ result[i] += v[j] * m.At(j, i);
+ }
+ }
+ return result;
+}
+
+inline std::ostream& operator<<(std::ostream& stream, const Vec4& v) {
+ stream << "(" << v[0];
+ for (int i = 1; i < 4; ++i) stream << ", " << v[i];
+ return stream << ")";
+}
+
+inline std::ostream& operator<<(std::ostream& stream, const Matrix4& m) {
+ for (int i = 0; i < 4; ++i) {
+ stream << '\n' << m.At(i, 0);
+ for (int j = 1; j < 4; ++j) stream << '\t' << m.At(i, j);
+ }
+ return stream;
+}
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_
diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc
new file mode 100644
index 0000000..b11970d
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc
@@ -0,0 +1,215 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h"
+
+#include <sstream>
+#include <string>
+
+#include "gtest/gtest.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+TEST(MatrixTest, Vec4Equality) {
+ EXPECT_EQ(Vec4(1, 2, 3, 4), Vec4(1, 2, 3, 4));
+ EXPECT_NE(Vec4(1, 2, 3, 4), Vec4(1, 2, 7, 4));
+ EXPECT_NE(Vec4(.1, .7, -4, 6), Vec4(-1, 64, .3, 200));
+}
+
+TEST(MatrixTest, Vec4Addition) {
+ EXPECT_EQ(Vec4(0, 1, 2, 4) + Vec4(5, -1, 6, 7), Vec4(5, 0, 8, 11));
+ EXPECT_EQ(Vec4(0.25, -3, 17, 0) + Vec4(.5, -.5, -1, 2),
+ Vec4(.75, -3.5, 16, 2));
+}
+
+TEST(MatrixTest, Vec4ScalarMultiplication) {
+ EXPECT_EQ(Vec4(6, -12, 7, .25) * 2, Vec4(12, -24, 14, .5));
+ EXPECT_EQ(Vec4(17, -3, 5.5, 0) * -.25, Vec4(-4.25, .75, -1.375, 0));
+}
+
+TEST(MatrixTest, Vec4ScalarDivision) {
+ EXPECT_EQ(Vec4(13, -8, 0, 100) / -2, Vec4(-6.5, 4, 0, -50));
+ EXPECT_EQ(Vec4(0, -3, 20, 1) / .2, Vec4(0, -15, 100, 5));
+}
+
+TEST(MatrixTest, Matrix4Equality) {
+ EXPECT_EQ(Matrix4(0, 1, 2, 3, //
+ 4, 5, 6, 7, //
+ 8, 9, 10, 11, //
+ 12, 13, 14, 15),
+ Matrix4(0, 1, 2, 3, //
+ 4, 5, 6, 7, //
+ 8, 9, 10, 11, //
+ 12, 13, 14, 15));
+ EXPECT_NE(Matrix4(0, 1, 2, 3, //
+ 4, 5, 6, 7, //
+ 8, 9, 10, 11, //
+ 12, 13, 14, 15),
+ Matrix4(1, 2, 0, 4, //
+ 4, 9, 6, 12, //
+ 9, 9, 9, 9, //
+ -1, -2, 14, 99));
+}
+
+TEST(MatrixTest, Matrix4IdentityCtor) {
+ EXPECT_EQ(Matrix4(), Matrix4(1, 0, 0, 0, //
+ 0, 1, 0, 0, //
+ 0, 0, 1, 0, //
+ 0, 0, 0, 1));
+}
+
+TEST(MatrixTest, Matrix4Zero) {
+ EXPECT_EQ(Matrix4::Zero(), Matrix4(0, 0, 0, 0, //
+ 0, 0, 0, 0, //
+ 0, 0, 0, 0, //
+ 0, 0, 0, 0));
+}
+
+TEST(MatrixTest, Matrix4Transpose) {
+ Matrix4 m(0, 1, 2, 3, //
+ 4, 5, 6, 7, //
+ 8, 9, 10, 11, //
+ 12, 13, 14, 15);
+ EXPECT_EQ(m.Transpose(), Matrix4(0, 4, 8, 12, //
+ 1, 5, 9, 13, //
+ 2, 6, 10, 14, //
+ 3, 7, 11, 15));
+}
+
+TEST(MatrixTest, Matrix4Multiplication) {
+ Matrix4 a(-4, 4, 2, 9, //
+ -2, -5, 6, 1, //
+ -2, 7, 10, 1, //
+ -4, -5, 2, 6);
+ Matrix4 b(-1, 7, 9, 3, //
+ 0, 7, -3, 8, //
+ -9, 7, 7, -10, //
+ 1, -1, -3, -1);
+ EXPECT_EQ(a * b, Matrix4(-5, 5, -61, -9, //
+ -51, -8, 36, -107, //
+ -87, 104, 28, -51, //
+ -8, -55, -25, -78));
+ EXPECT_EQ(b * a, Matrix4(-40, 9, 136, 25, //
+ -40, -96, 28, 52, //
+ 48, 28, 74, -127, //
+ 8, -7, -36, -1));
+}
+
+TEST(MatrixTest, Matrix4Addition) {
+ Matrix4 a(2, 0, -10, -1, //
+ -4, -4, -7, 3, //
+ 7, -1, 7, 3, //
+ -7, -4, -4, -4);
+ Matrix4 b(9, -6, -10, 0, //
+ 6, 1, -5, 9, //
+ -7, -4, -3, -6, //
+ 7, 7, -10, -9);
+ EXPECT_EQ(a + b, Matrix4(11, -6, -20, -1, //
+ 2, -3, -12, 12, //
+ 0, -5, 4, -3, //
+ 0, 3, -14, -13));
+ EXPECT_EQ(b + a, Matrix4(11, -6, -20, -1, //
+ 2, -3, -12, 12, //
+ 0, -5, 4, -3, //
+ 0, 3, -14, -13));
+}
+
+TEST(MatrixTest, Matrix4Subtraction) {
+ Matrix4 a(-7, -9, 9, 9, //
+ -4, 10, -3, -1, //
+ 8, 9, 6, 4, //
+ 9, -7, 7, 4);
+ Matrix4 b(2, -1, 2, 6, //
+ -1, -8, -1, 10, //
+ 3, 0, -6, -1, //
+ -6, -3, 6, 7);
+ EXPECT_EQ(a - b, Matrix4(-9, -8, 7, 3, //
+ -3, 18, -2, -11, //
+ 5, 9, 12, 5, //
+ 15, -4, 1, -3));
+ EXPECT_EQ(b - a, Matrix4(9, 8, -7, -3, //
+ 3, -18, 2, 11, //
+ -5, -9, -12, -5, //
+ -15, 4, -1, 3));
+}
+
+TEST(MatrixTest, Matrix4ScalarMultiplication) {
+ Matrix4 m(6, 8, 6, 7, //
+ 2, -4, 10, 8, //
+ -1, 4, 9, -7, //
+ -2, -9, 10, 10);
+ EXPECT_EQ(m * 3, Matrix4(18, 24, 18, 21, //
+ 6, -12, 30, 24, //
+ -3, 12, 27, -21, //
+ -6, -27, 30, 30));
+ EXPECT_EQ(m * .5, Matrix4(3, 4, 3, 3.5, //
+ 1, -2, 5, 4, //
+ -.5, 2, 4.5, -3.5, //
+ -1, -4.5, 5, 5));
+}
+
+TEST(MatrixTest, Matrix4VectorMultiplication) {
+ Matrix4 m(3, 0, 4, 3, //
+ 7, -6, 7, -10, //
+ 6, -2, -10, -5, //
+ -3, 9, 1, -5);
+ Vec4 v(-6, 9, -7, -10);
+ EXPECT_EQ(m * v, Vec4(-76, -45, 66, 142));
+ EXPECT_EQ(v * m, Vec4(33, -130, 99, -23));
+}
+
+TEST(MatrixTest, DotProduct) {
+ Vec4 a(0, -3, 0, -5);
+ Vec4 b(6, 4, -4, 6);
+ EXPECT_EQ(DotProduct(a, b), -42);
+ EXPECT_EQ(DotProduct(b, a), -42);
+}
+
+TEST(MatrixTest, OuterProduct) {
+ Vec4 a(-8, -3, 6, -8);
+ Vec4 b(4, -9, -1, 10);
+ EXPECT_EQ(OuterProduct(a, b), Matrix4(-32, 72, 8, -80, //
+ -12, 27, 3, -30, //
+ 24, -54, -6, 60, //
+ -32, 72, 8, -80));
+ EXPECT_EQ(OuterProduct(b, a), Matrix4(-32, -12, 24, -32, //
+ 72, 27, -54, 72, //
+ 8, 3, -6, 8, //
+ -80, -30, 60, -80));
+}
+
+TEST(MatrixTest, Vec4Stream) {
+ std::stringstream s;
+ s << Vec4(1.28, -9, .9, 2.7);
+ EXPECT_EQ(s.str(), "(1.28, -9, 0.9, 2.7)");
+}
+
+TEST(MatrixTest, Matrix4Stream) {
+ std::stringstream s;
+ s << Matrix4(7.5, -7.7, -4.6, 8, //
+ 6.4, -8.52, 0, 8.8, //
+ -3.5, -5.2, -.5, 9, //
+ -2.6, -3.4, 5.5, 8.3);
+ EXPECT_EQ(s.str(),
+ "\n7.5\t-7.7\t-4.6\t8"
+ "\n6.4\t-8.52\t0\t8.8"
+ "\n-3.5\t-5.2\t-0.5\t9"
+ "\n-2.6\t-3.4\t5.5\t8.3");
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor.cc b/ink_stroke_modeler/internal/prediction/kalman_predictor.cc
new file mode 100644
index 0000000..cc8949c
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_predictor.cc
@@ -0,0 +1,265 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <vector>
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/utils.h"
+#include "ink_stroke_modeler/params.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+KalmanPredictor::State EvaluateCubic(const KalmanPredictor::State &start_state,
+ Duration delta_time) {
+ float dt = delta_time.Value();
+ auto dt_squared = dt * dt;
+ auto dt_cubed = dt_squared * dt;
+
+ KalmanPredictor::State end_state;
+ end_state.position = start_state.position + start_state.velocity * dt +
+ start_state.acceleration * dt_squared / 2.f +
+ start_state.jerk * dt_cubed / 6.f;
+ end_state.velocity = start_state.velocity + start_state.acceleration * dt +
+ start_state.jerk * dt_squared / 2.f;
+ end_state.acceleration = start_state.acceleration + start_state.jerk * dt;
+ end_state.jerk = start_state.jerk;
+
+ return end_state;
+}
+
+} // namespace
+
+void KalmanPredictor::Reset() {
+ x_predictor_.Reset();
+ y_predictor_.Reset();
+ sample_times_.clear();
+ last_position_received_ = absl::nullopt;
+}
+
+void KalmanPredictor::Update(Vec2 position, Time time) {
+ last_position_received_ = position;
+ sample_times_.push_back(time);
+ if (predictor_params_.max_time_samples < 0 ||
+ sample_times_.size() > (uint)predictor_params_.max_time_samples)
+ sample_times_.pop_front();
+
+ x_predictor_.Update(position.x);
+ y_predictor_.Update(position.y);
+}
+
+absl::optional<KalmanPredictor::State> KalmanPredictor::GetEstimatedState()
+ const {
+ if (!IsStable() || sample_times_.empty()) return absl::nullopt;
+
+ State estimated_state;
+ estimated_state.position = {static_cast<float>(x_predictor_.GetPosition()),
+ static_cast<float>(y_predictor_.GetPosition())};
+ estimated_state.velocity = {static_cast<float>(x_predictor_.GetVelocity()),
+ static_cast<float>(y_predictor_.GetVelocity())};
+ estimated_state.acceleration = {
+ static_cast<float>(x_predictor_.GetAcceleration()),
+ static_cast<float>(y_predictor_.GetAcceleration())};
+ estimated_state.jerk = {static_cast<float>(x_predictor_.GetJerk()),
+ static_cast<float>(y_predictor_.GetJerk())};
+
+ // The axis predictors are not time-aware, assuming that the time delta
+ // between measurements is always 1. To correct for this, we divide the
+ // velocity, acceleration, and jerk by the average observed time delta, raised
+ // to the appropriate power.
+ auto dt = static_cast<float>(
+ (sample_times_.back() - sample_times_.front()).Value()) /
+ sample_times_.size();
+ auto dt_squared = dt * dt;
+ auto dt_cubed = dt_squared * dt;
+ estimated_state.velocity /= dt;
+ estimated_state.acceleration /= dt_squared;
+ estimated_state.jerk /= dt_cubed;
+
+ // We want our predictions to tend more towards linearity -- to achieve this,
+ // we reduce the acceleration and jerk.
+ estimated_state.acceleration *= predictor_params_.acceleration_weight;
+ estimated_state.jerk *= predictor_params_.jerk_weight;
+
+ return estimated_state;
+}
+
+std::vector<TipState> KalmanPredictor::ConstructPrediction(
+ const TipState &last_state) const {
+ auto estimated_state = GetEstimatedState();
+ if (!estimated_state || !last_position_received_) {
+ // We don't yet have enough data to construct a prediction.
+ return {};
+ }
+
+ Duration sample_dt{1. / sampling_params_.min_output_rate};
+ std::vector<TipState> prediction;
+ ConstructCubicConnector(last_state, *estimated_state, predictor_params_,
+ sample_dt, &prediction);
+ auto start_time =
+ prediction.empty() ? last_state.time : prediction.back().time;
+ ConstructCubicPrediction(*estimated_state, predictor_params_, start_time,
+ sample_dt, NumberOfPointsToPredict(*estimated_state),
+ &prediction);
+ return prediction;
+}
+
+void KalmanPredictor::ConstructCubicPrediction(
+ const State &estimated_state, const KalmanPredictorParams &params,
+ Time start_time, Duration sample_dt, int n_samples,
+ std::vector<TipState> *output) {
+ auto current_state = estimated_state;
+ auto current_time = start_time;
+ for (int i = 0; i < n_samples; ++i) {
+ auto next_state = EvaluateCubic(current_state, sample_dt);
+ current_time += sample_dt;
+ output->push_back({next_state.position, next_state.velocity, current_time});
+ current_state = next_state;
+ }
+}
+
+void KalmanPredictor::ConstructCubicConnector(
+ const TipState &last_tip_state, const State &estimated_state,
+ const KalmanPredictorParams &params, Duration sample_dt,
+ std::vector<TipState> *output) {
+ // Estimate how long it will take for the tip to travel from its last position
+ // to the estimated position, based on the start and end velocities. We define
+ // a minimum "reasonable" velocity to avoid division by zero.
+ auto distance_traveled =
+ Distance(last_tip_state.position, estimated_state.position);
+ auto max_velocity_at_ends = std::max(last_tip_state.velocity.Magnitude(),
+ estimated_state.velocity.Magnitude());
+ Duration target_duration{
+ distance_traveled /
+ std::max(max_velocity_at_ends, params.min_catchup_velocity)};
+
+ // Determine how many samples this will give us, ensuring that there's always
+ // at least one. Then, pick a duration that's a multiple of the sample dt.
+ int n_points = std::max(std::ceil(static_cast<float>(target_duration.Value() /
+ sample_dt.Value())),
+ 1.f);
+ auto duration = n_points * sample_dt;
+
+ // We want to construct a cubic curve connecting the last tip state and the
+ // estimated state. Given positions p₀ and p₁, velocities v₀ and v₁, and times
+ // t₀ and t₁ at the start and end of the curve, we define a pair of functions,
+ // f and g, such that the curve is described by the composite function
+ // f(g(t)):
+ // f(x) = ax³ + bx² + cx + d
+ // g(t) = (t - t₀) / (t₁ - t₀)
+ // We then find the derivatives:
+ // f'(x) = 3ax² + 2bx + c
+ // g'(t) = 1 / (t₁ - t₀)
+ // (f∘g)'(t) = f'(g(t)) ⋅ g'(t) = (3ax² + 2bx + c) / (t₁ - t₀)
+ // We then plug in the given values:
+ // f(g(t₀)) = f(0) = p₀
+ // ax³ + bx² + cx + d
+ // f(g(t₁)) = f(1) = p₁
+ // (f∘g)'(t₀) = f'(0) ⋅ g'(t₀) = v₀
+ // (f∘g)'(t₁) = f'(1) ⋅ g'(t₁) = v₁
+ // This gives us four linear equations:
+ // a⋅0³ + b⋅0² + c⋅0 + d = p₀
+ // a⋅1³ + b⋅1² + c⋅1 + d = p₁
+ // (3a⋅0² + 2b⋅0 + c) / (t₁ - t₀) = v₀
+ // (3a⋅1² + 2b⋅1 + c) / (t₁ - t₀) = v₁
+ // Finally, we can solve for a, b, c, and d:
+ // a = 2p₀ - 2p₁ + (v₀ + v₁)(t₁ - t₀)
+ // b = -3p₀ + 3p₁ - (2v₀ + v₁)(t₁ - t₀)
+ // c = v₀(t₁ - t₀)
+ // d = p₀
+ float float_duration = duration.Value();
+ auto a =
+ 2.f * last_tip_state.position - 2.f * estimated_state.position +
+ (last_tip_state.velocity + estimated_state.velocity) * float_duration;
+ auto b = -3.f * last_tip_state.position + 3.f * estimated_state.position -
+ (2.f * last_tip_state.velocity + estimated_state.velocity) *
+ float_duration;
+ auto c = last_tip_state.velocity * float_duration;
+ auto d = last_tip_state.position;
+
+ output->reserve(output->size() + n_points);
+ for (int i = 1; i <= n_points; ++i) {
+ float t = static_cast<float>(i) / n_points;
+ float t_squared = t * t;
+ float t_cubed = t_squared * t;
+ auto position = a * t_cubed + b * t_squared + c * t + d;
+ auto velocity = 3.f * a * t_squared + 2.f * b * t + c;
+ auto time = last_tip_state.time + duration * t;
+ output->push_back({position, velocity / float_duration, time});
+ }
+}
+
+int KalmanPredictor::NumberOfPointsToPredict(
+ const State &estimated_state) const {
+ const KalmanPredictorParams::ConfidenceParams &confidence_params =
+ predictor_params_.confidence_params;
+
+ auto target_number =
+ static_cast<float>(predictor_params_.prediction_interval.Value() *
+ sampling_params_.min_output_rate);
+
+ // The more samples we've received, the less effect the noise from each
+ // individual input affects the result.
+ float sample_ratio =
+ std::min(1.f, static_cast<float>(x_predictor_.NumIterations()) /
+ confidence_params.desired_number_of_samples);
+
+ // The further the last given position is from the estimated position, the
+ // less confidence we have in the result.
+ float estimated_error =
+ Distance(*last_position_received_, estimated_state.position);
+ float normalized_error =
+ 1.f - Normalize01(0.f, confidence_params.max_estimation_distance,
+ estimated_error);
+
+ // This is the state that the prediction would end at if we predicted the full
+ // interval (i.e. if confidence == 1).
+ auto end_state =
+ EvaluateCubic(estimated_state, predictor_params_.prediction_interval);
+
+ // If the prediction is not traveling quickly, then changes in direction
+ // become more apparent, making the prediction appear wobbly.
+ float travel_speed =
+ Distance(estimated_state.position, end_state.position) /
+ static_cast<float>(predictor_params_.prediction_interval.Value());
+ float normalized_distance =
+ Normalize01(confidence_params.min_travel_speed,
+ confidence_params.max_travel_speed, travel_speed);
+
+ // If the actual prediction differs too much from the linear prediction, it
+ // suggests that the acceleration and jerk components overtake the velocity,
+ // resulting in a prediction that flies far off from the stroke.
+ float deviation_from_linear_prediction = Distance(
+ end_state.position,
+ estimated_state.position +
+ static_cast<float>(predictor_params_.prediction_interval.Value()) *
+ estimated_state.velocity);
+ float linearity =
+ Interp(confidence_params.baseline_linearity_confidence, 1.f,
+ 1.f - Normalize01(0.f, confidence_params.max_linear_deviation,
+ deviation_from_linear_prediction));
+
+ auto confidence =
+ sample_ratio * normalized_error * normalized_distance * linearity;
+ return std::ceil(target_number * confidence);
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor.h b/ink_stroke_modeler/internal/prediction/kalman_predictor.h
new file mode 100644
index 0000000..b848770
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_predictor.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_
+
+#include <deque>
+#include <vector>
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class constructs a prediction by using a pair of Kalman filters (one
+// each for the x- and y-dimension) to model the true state of the tip, assuming
+// that the data we receive contains some noise.
+// To construct a prediction, we first fetch the estimation of the position,
+// velocity, acceleration, and jerk from the Kalman filters. The prediction is
+// then constructed in two parts: one cubic spline that connects the last tip
+// state to the estimation, constructed from the positions and velocities at the
+// endpoints; and one cubic spline that extends into the future, constructed
+// from the estimated position, velocity, acceleration, and jerk.
+class KalmanPredictor : public InputPredictor {
+ public:
+ explicit KalmanPredictor(const KalmanPredictorParams &predictor_params,
+ const SamplingParams &sampling_params)
+ : predictor_params_(predictor_params),
+ sampling_params_(sampling_params),
+ x_predictor_(predictor_params_.process_noise,
+ predictor_params_.measurement_noise,
+ predictor_params_.min_stable_iteration),
+ y_predictor_(predictor_params_.process_noise,
+ predictor_params_.measurement_noise,
+ predictor_params_.min_stable_iteration) {}
+
+ void Reset() override;
+ void Update(Vec2 position, Time time) override;
+ std::vector<TipState> ConstructPrediction(
+ const TipState &last_state) const override;
+
+ struct State {
+ Vec2 position{0};
+ Vec2 velocity{0};
+ Vec2 acceleration{0};
+ Vec2 jerk{0};
+ };
+
+ // Returns the current estimate of the tip's true state, as modeled by the
+ // Kalman filters, or absl::nullopt if the predictor does not yet have enough
+ // data to make a reasonable estimate.
+ absl::optional<State> GetEstimatedState() const;
+
+ private:
+ bool IsStable() const {
+ return x_predictor_.Stable() && y_predictor_.Stable();
+ }
+
+ static void ConstructCubicConnector(const TipState &last_tip_state,
+ const State &estimated_state,
+ const KalmanPredictorParams &params,
+ Duration sample_dt,
+ std::vector<TipState> *output);
+
+ static void ConstructCubicPrediction(const State &estimated_state,
+ const KalmanPredictorParams &params,
+ Time start_time, Duration sample_dt,
+ int n_samples,
+ std::vector<TipState> *output);
+
+ int NumberOfPointsToPredict(const State &estimated_state) const;
+
+ KalmanPredictorParams predictor_params_;
+ SamplingParams sampling_params_;
+
+ absl::optional<Vec2> last_position_received_;
+
+ std::deque<Time> sample_times_;
+
+ AxisPredictor x_predictor_;
+ AxisPredictor y_predictor_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_
diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc
new file mode 100644
index 0000000..66bac66
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc
@@ -0,0 +1,215 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/types/optional.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Matcher;
+using ::testing::Optional;
+
+constexpr float kTol = 1e-4;
+
+const KalmanPredictorParams kDefaultKalmanParams{
+ .process_noise = .00026458,
+ .measurement_noise = .026458,
+ .min_catchup_velocity = .01,
+ .prediction_interval = Duration(1. / 60),
+ .confidence_params{.max_estimation_distance = .04,
+ .min_travel_speed = 3,
+ .max_travel_speed = 15,
+ .max_linear_deviation = .2}};
+constexpr SamplingParams kDefaultSamplingParams{
+ .min_output_rate = 180, .end_of_stroke_stopping_distance = .001};
+
+// Matcher for the State. Note that, because each of position,
+// velocity, acceleration, and jerk are divided by increasing powers of the time
+// delta, the values grow exponentially. As such, this uses a relative tolerance
+// unless one of the arguments is exactly zero.
+MATCHER_P5(StateNearMatcher, position, velocity, acceleration, jerk,
+ relative_tol, "") {
+ auto within_relative_tol = [](float lhs, float rhs, float tol) {
+ if (lhs == 0 || rhs == 0) {
+ return std::abs(lhs - rhs) < tol;
+ }
+ return std::abs(lhs / rhs - 1.f) < tol;
+ };
+
+ if (within_relative_tol(position.x, arg.position.x, relative_tol) &&
+ within_relative_tol(position.y, arg.position.y, relative_tol) &&
+ within_relative_tol(velocity.x, arg.velocity.x, relative_tol) &&
+ within_relative_tol(velocity.y, arg.velocity.y, relative_tol) &&
+ within_relative_tol(acceleration.x, arg.acceleration.x, relative_tol) &&
+ within_relative_tol(acceleration.y, arg.acceleration.y, relative_tol) &&
+ within_relative_tol(jerk.x, arg.jerk.x, relative_tol) &&
+ within_relative_tol(jerk.y, arg.jerk.y, relative_tol)) {
+ return true;
+ }
+
+ *result_listener << "\n expected:" //
+ << "\n p = " << position //
+ << "\n v = " << velocity //
+ << "\n a = " << acceleration //
+ << "\n j = " << jerk //
+ << "\n actual:" //
+ << "\n p = " << arg.position //
+ << "\n v = " << arg.velocity //
+ << "\n a = " << arg.acceleration //
+ << "\n j = " << arg.jerk;
+ return false;
+}
+
+// Wrapping the matcher in a function allows the compiler to perform template
+// deduction, so we can specify arguments as initializer lists.
+Matcher<KalmanPredictor::State> StateNear(Vec2 position, Vec2 velocity,
+ Vec2 acceleration, Vec2 jerk,
+ float tolerance) {
+ return StateNearMatcher(position, velocity, acceleration, jerk, tolerance);
+}
+
+TEST(KalmanPredictorTest, EmptyPrediction) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty());
+
+ predictor.Update({1, 3}, Time{4});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 3}, {0, 0}, Time{3.1}}).empty());
+}
+
+TEST(KalmanPredictorTest, TypicalCase) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+
+ predictor.Update({0, 0}, Time{0});
+ predictor.Update({.1, 0}, Time{.01});
+ predictor.Update({.2, 0}, Time{.02});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty());
+
+ predictor.Update({.3, 0}, Time{.03});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({.30078, 0}, {13.584, 0}, {-66.806, 0},
+ {-3382.8, 0}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{.2, 0}, {10, 0}, Time{.03}}),
+ ElementsAre(TipStateNear({{.2454, 0}, {7.7094, 0}, Time{.0356}}, kTol),
+ TipStateNear({{.3008, 0}, {13.5837, 0}, Time{.0411}}, kTol),
+ TipStateNear({{.3751, 0}, {13.1604, 0}, Time{.0467}}, kTol)));
+
+ predictor.Update({.5, .1}, Time{.04});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({.49705, .097146}, {28.217, 16.732},
+ {671.91, 813.82}, {4454.3, 6998.2}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{.3, 0}, {10, 0}, Time{.04}}),
+ ElementsAre(
+ TipStateNear({{.3732, .0253}, {17.047, 8.9317}, Time{.0456}}, kTol),
+ TipStateNear({{.497, .0971}, {28.2172, 16.7319}, Time{.0511}}, kTol),
+ TipStateNear({{.6643, .2029}, {32.0188, 21.3611}, Time{.0567}},
+ kTol)));
+}
+
+TEST(KalmanPredictorTest, AlternateParams) {
+ auto kalman_params = kDefaultKalmanParams;
+ auto sampling_params = kDefaultSamplingParams;
+ kalman_params.prediction_interval = Duration(.025);
+ sampling_params.min_output_rate = 200;
+ KalmanPredictor predictor{kalman_params, sampling_params};
+
+ predictor.Update({2, 5}, Time{1});
+ predictor.Update({2.2, 4.9}, Time{1.02});
+ predictor.Update({2.3, 4.7}, Time{1.04});
+ predictor.Update({2.3, 4.4}, Time{1.06});
+ EXPECT_THAT(
+ predictor.GetEstimatedState(),
+ Optional(StateNear({2.3016, 4.3992}, {-3.9981, -24.374},
+ {-338.22, -288.12}, {-1852.9, -584.31}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{2.25, 4.75}, {1, -20}, Time{1.06}}),
+ ElementsAre(
+ TipStateNear({{2.27, 4.6417}, {5.917, -23.0547}, Time{1.065}}, kTol),
+ TipStateNear({{2.2982, 4.5221}, {4.251, -24.5126}, Time{1.07}}, kTol),
+ TipStateNear({{2.3016, 4.3992}, {-3.9981, -24.3736}, Time{1.075}},
+ kTol),
+ TipStateNear({{2.2773, 4.2738}, {-5.7123, -25.8215}, Time{1.08}},
+ kTol)));
+
+ predictor.Update({2.2, 4.2}, Time{1.08});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({2.1987, 4.1933}, {-11.457, -11.953},
+ {-328.01, 185.32}, {-1133.8, 1569.8}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{2.25, 4.5}, {-1, -20}, Time{1.08}}),
+ ElementsAre(
+ TipStateNear({{2.2499, 4.407}, {.5082, -17.2661}, Time{1.085}}, kTol),
+ TipStateNear({{2.2505, 4.3265}, {-.7319, -15.0137}, Time{1.09}},
+ kTol),
+ TipStateNear({{2.238, 4.2561}, {-4.7203, -13.2427}, Time{1.095}},
+ kTol),
+ TipStateNear({{2.1987, 4.1933}, {-11.4569, -11.9531}, Time{1.1}},
+ kTol),
+ TipStateNear({{2.1373, 4.1359}, {-13.1112, -11.0068}, Time{1.105}},
+ kTol)));
+}
+
+TEST(KalmanPredictorTest, Reset) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+
+ predictor.Update({4, -4}, Time{6});
+ predictor.Update({-6, 9}, Time{6.03});
+ predictor.Update({10, 5}, Time{6.06});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty());
+
+ predictor.Update({2, 4}, Time{6.09});
+ EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_FALSE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty());
+
+ predictor.Reset();
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.09}}).empty());
+
+ predictor.Update({-9, 3}, Time{2});
+ predictor.Update({-6, -1}, Time{2.1});
+ predictor.Update({6, -6}, Time{2.2});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.2}}).empty());
+
+ predictor.Update({3, 6}, Time{2.3});
+ EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_FALSE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.3}}).empty());
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc
new file mode 100644
index 0000000..3e83720
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc
@@ -0,0 +1,52 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h"
+
+#include <iterator>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/position_modeler.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+void StrokeEndPredictor::Update(Vec2 position, Time time) {
+ last_position_ = position;
+}
+
+std::vector<TipState> StrokeEndPredictor::ConstructPrediction(
+ const TipState &last_state) const {
+ if (!last_position_) {
+ // We don't yet have enough data to construct a prediction.
+ return {};
+ }
+
+ std::vector<TipState> prediction;
+ prediction.reserve(sampling_params_.end_of_stroke_max_iterations);
+ PositionModeler modeler;
+ modeler.Reset(last_state, position_modeler_params_);
+ modeler.ModelEndOfStroke(*last_position_,
+ Duration(1. / sampling_params_.min_output_rate),
+ sampling_params_.end_of_stroke_max_iterations,
+ sampling_params_.end_of_stroke_stopping_distance,
+ std::back_inserter(prediction));
+ return prediction;
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h
new file mode 100644
index 0000000..420d596
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_
+#define INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_
+
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class constructs a prediction using the same PositionModeler class as
+// the SpringBasedModeler, fixing the anchor position and allowing the stroke to
+// "catch up". The way the prediction is constructed is very similar to how the
+// SpringBasedModeler models the end of a stroke.
+class StrokeEndPredictor : public InputPredictor {
+ public:
+ explicit StrokeEndPredictor(
+ const PositionModelerParams &position_modeler_params,
+ const SamplingParams &sampling_params)
+ : position_modeler_params_(position_modeler_params),
+ sampling_params_(sampling_params) {}
+
+ void Reset() override { last_position_ = absl::nullopt; }
+ void Update(Vec2 position, Time time) override;
+ std::vector<TipState> ConstructPrediction(
+ const TipState &last_state) const override;
+
+ private:
+ PositionModelerParams position_modeler_params_;
+ SamplingParams sampling_params_;
+
+ absl::optional<Vec2> last_position_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_
diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc b/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc
new file mode 100644
index 0000000..5a7750d
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc
@@ -0,0 +1,137 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::IsEmpty;
+using ::testing::Not;
+
+constexpr float kTol = 1e-4;
+
+constexpr SamplingParams kDefaultSamplingParams{
+ .min_output_rate = 180,
+ .end_of_stroke_stopping_distance = .001,
+ .end_of_stroke_max_iterations = 20};
+
+TEST(StrokeEndPredictorTest, EmptyPrediction) {
+ StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams};
+ EXPECT_THAT(predictor.ConstructPrediction({{4, 6}, {-1, 1}, Time{5}}),
+ IsEmpty());
+
+ predictor.Reset();
+ EXPECT_THAT(predictor.ConstructPrediction({{-2, 11}, {0, 0}, Time{1}}),
+ IsEmpty());
+}
+
+TEST(StrokeEndPredictorTest, SingleInput) {
+ StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams};
+ predictor.Update({4, 5}, Time{2});
+
+ EXPECT_THAT(predictor.ConstructPrediction({{4, 5}, {0, 0}, Time{2}}),
+ IsEmpty());
+}
+
+TEST(StrokeEndPredictorTest, MultipleInputs) {
+ StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams};
+
+ predictor.Update({-1, 1}, Time{1});
+ EXPECT_THAT(predictor.ConstructPrediction({{-1, 1}, {0, 0}, Time{1}}),
+ IsEmpty());
+
+ predictor.Update({-1, 1.2}, Time{1.02});
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{-1, 1.1}, {0, 5}, Time{1.02}}),
+ ElementsAre(
+ TipStateNear({{-1, 1.1258}, {0, 4.6364}, Time{1.0256}}, kTol),
+ TipStateNear({{-1, 1.1480}, {0, 3.9967}, Time{1.0311}}, kTol),
+ TipStateNear({{-1, 1.1660}, {0, 3.2496}, Time{1.0367}}, kTol),
+ TipStateNear({{-1, 1.1799}, {0, 2.5059}, Time{1.0422}}, kTol),
+ TipStateNear({{-1, 1.1901}, {0, 1.8318}, Time{1.0478}}, kTol),
+ TipStateNear({{-1, 1.1971}, {0, 1.2609}, Time{1.0533}}, kTol),
+ TipStateNear({{-1, 1.2000}, {0, 1.0323}, Time{1.0561}}, kTol)));
+
+ predictor.Update({-1, 1.4}, Time{1.04});
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{-1, 1.2}, {0, 5}, Time{1.04}}),
+ ElementsAre(
+ TipStateNear({{-1, 1.2348}, {0, 6.2727}, Time{1.0455}}, kTol),
+ TipStateNear({{-1, 1.2708}, {0, 6.4661}, Time{1.0511}}, kTol),
+ TipStateNear({{-1, 1.3041}, {0, 5.9943}, Time{1.0566}}, kTol),
+ TipStateNear({{-1, 1.3328}, {0, 5.1663}, Time{1.0622}}, kTol),
+ TipStateNear({{-1, 1.3561}, {0, 4.1998}, Time{1.0677}}, kTol),
+ TipStateNear({{-1, 1.3741}, {0, 3.2381}, Time{1.0733}}, kTol),
+ TipStateNear({{-1, 1.3872}, {0, 2.3668}, Time{1.0788}}, kTol),
+ TipStateNear({{-1, 1.3963}, {0, 1.6288}, Time{1.0844}}, kTol),
+ TipStateNear({{-1, 1.4000}, {0, 1.3333}, Time{1.0872}}, kTol)));
+}
+
+TEST(StrokeEndPredictorTest, Reset) {
+ StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams};
+
+ predictor.Update({-9, 6}, Time{5});
+ EXPECT_THAT(predictor.ConstructPrediction({{-9, 6}, {0, 0}, Time{5}}),
+ IsEmpty());
+ predictor.Update({1, 4}, Time{7});
+ EXPECT_THAT(predictor.ConstructPrediction({{-4, 5}, {5, -1}, Time{7}}),
+ Not(IsEmpty()));
+
+ predictor.Reset();
+ EXPECT_THAT(predictor.ConstructPrediction({{0, 1}, {0, 0}, Time{1}}),
+ IsEmpty());
+}
+
+TEST(StrokeEndPredictorTest, AlternateSamplingParams) {
+ StrokeEndPredictor predictor{
+ PositionModelerParams{},
+ SamplingParams{.min_output_rate = 200,
+ .end_of_stroke_stopping_distance = .005}};
+
+ predictor.Update({4, -7}, Time{3});
+ EXPECT_THAT(predictor.ConstructPrediction({{4, -7}, {0, 0}, Time{3}}),
+ IsEmpty());
+
+ predictor.Update({4.2, -6.8}, Time{3.01});
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{4.1, -6.9}, {2, 2}, Time{3.01}}),
+ ElementsAre(
+ TipStateNear({{4.1138, -6.8862}, {2.7527, 2.7527}, Time{3.015}},
+ kTol),
+ TipStateNear({{4.1289, -6.8711}, {3.0318, 3.0318}, Time{3.02}}, kTol),
+ TipStateNear({{4.1439, -6.8561}, {2.9871, 2.9871}, Time{3.025}},
+ kTol),
+ TipStateNear({{4.1576, -6.8424}, {2.7386, 2.7386}, Time{3.03}}, kTol),
+ TipStateNear({{4.1694, -6.8306}, {2.3779, 2.3779}, Time{3.035}},
+ kTol),
+ TipStateNear({{4.1793, -6.8207}, {1.9719, 1.9719}, Time{3.04}}, kTol),
+ TipStateNear({{4.1871, -6.8129}, {1.5669, 1.5669}, Time{3.045}},
+ kTol),
+ TipStateNear({{4.1931, -6.8069}, {1.1923, 1.1923}, Time{3.05}}, kTol),
+ TipStateNear({{4.1974, -6.8026}, {0.8647, 0.8647}, Time{3.055}},
+ kTol)));
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/stylus_state_modeler.cc b/ink_stroke_modeler/internal/stylus_state_modeler.cc
new file mode 100644
index 0000000..e7c1bbc
--- /dev/null
+++ b/ink_stroke_modeler/internal/stylus_state_modeler.cc
@@ -0,0 +1,101 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/stylus_state_modeler.h"
+
+#include <limits>
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/utils.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+void StylusStateModeler::Update(Vec2 position, const StylusState &state) {
+ if (state.pressure < 0) received_unknown_pressure_ = true;
+ if (state.tilt < 0) received_unknown_tilt_ = true;
+ if (state.orientation < 0) received_unknown_orientation_ = true;
+
+ if (received_unknown_pressure_ && received_unknown_tilt_ &&
+ received_unknown_orientation_) {
+ // We've stopped tracking all fields, so there's no need to keep updating.
+ positions_and_states_.clear();
+ return;
+ }
+
+ positions_and_states_.push_back({position, state});
+
+ if (params_.max_input_samples < 0 ||
+ positions_and_states_.size() > (uint)params_.max_input_samples) {
+ positions_and_states_.pop_front();
+ }
+}
+
+void StylusStateModeler::Reset(const StylusStateModelerParams &params) {
+ params_ = params;
+ positions_and_states_.clear();
+ received_unknown_pressure_ = false;
+ received_unknown_tilt_ = false;
+ received_unknown_orientation_ = false;
+}
+
+StylusState StylusStateModeler::Query(Vec2 position) const {
+ if (positions_and_states_.empty())
+ return {.pressure = -1, .tilt = -1, .orientation = -1};
+
+ if (positions_and_states_.size() == 1) {
+ const auto &state = positions_and_states_.front().state;
+ return {
+ .pressure = received_unknown_pressure_ ? -1 : state.pressure,
+ .tilt = received_unknown_tilt_ ? -1 : state.tilt,
+ .orientation = received_unknown_orientation_ ? -1 : state.orientation};
+ }
+
+ int closest_segment = -1;
+ float min_distance = std::numeric_limits<float>::infinity();
+ float interp_value = 0;
+ for (decltype(positions_and_states_.size()) i = 0;
+ i < positions_and_states_.size() - 1; ++i) {
+ const Vec2 segment_start = positions_and_states_[i].position;
+ const Vec2 segment_end = positions_and_states_[i + 1].position;
+ float param = NearestPointOnSegment(segment_start, segment_end, position);
+ float distance =
+ Distance(position, Interp(segment_start, segment_end, param));
+ if (distance <= min_distance) {
+ closest_segment = i;
+ min_distance = distance;
+ interp_value = param;
+ }
+ }
+
+ auto from_state = positions_and_states_[closest_segment].state;
+ auto to_state = positions_and_states_[closest_segment + 1].state;
+ return StylusState{
+ .pressure =
+ received_unknown_pressure_
+ ? -1
+ : Interp(from_state.pressure, to_state.pressure, interp_value),
+ .tilt = received_unknown_tilt_
+ ? -1
+ : Interp(from_state.tilt, to_state.tilt, interp_value),
+ .orientation = received_unknown_orientation_
+ ? -1
+ : InterpAngle(from_state.orientation,
+ to_state.orientation, interp_value)};
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/stylus_state_modeler.h b/ink_stroke_modeler/internal/stylus_state_modeler.h
new file mode 100644
index 0000000..a5b925f
--- /dev/null
+++ b/ink_stroke_modeler/internal/stylus_state_modeler.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_
+#define INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_
+
+#include <deque>
+#include <ostream>
+
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class is used to model the state of the stylus for a given position,
+// based on the state of the stylus at the original input points.
+//
+// The stylus is modeled by storing the last max_input_samples positions and
+// states received via Update(); when queried, it treats the stored positions as
+// a polyline, and finds the closest segment. The returned stylus state is a
+// linear interpolation between the states associated with the endpoints of the
+// segment, correcting angles to account for the "wraparound" that occurs at 0
+// and 2π. The value used for interpolation is based on how far along the
+// segment the closest point lies.
+//
+// If Update() is called with a state in which a field (i.e. pressure, tilt, or
+// orientation) has a negative value (indicating no information), then the
+// results of Query() will be -1 for that field until Reset() is called. This is
+// tracked independently for each field; e.g., if you pass in tilt = -1, then
+// pressure and orientation will continue to be interpolated normally.
+class StylusStateModeler {
+ public:
+ // Adds a position and state pair to the model. During stroke modeling, these
+ // values will be taken from the raw input.
+ void Update(Vec2 position, const StylusState &state);
+
+ // Clear the model and reset.
+ void Reset(const StylusStateModelerParams &params);
+
+ // Query the model for the state at the given position. During stroke
+ // modeling, the position will be taken from the modeled input.
+ //
+ // If no Update() calls have been received since the last Reset(), this will
+ // return {.pressure = -1, .tilt = -1, .orientation = -1}.
+ StylusState Query(Vec2 position) const;
+
+ private:
+ struct PositionAndState {
+ Vec2 position{0};
+ StylusState state;
+
+ PositionAndState(Vec2 position_in, const StylusState &state_in)
+ : position(position_in), state(state_in) {}
+ };
+
+ bool received_unknown_pressure_ = false;
+ bool received_unknown_tilt_ = false;
+ bool received_unknown_orientation_ = false;
+
+ std::deque<PositionAndState> positions_and_states_;
+ StylusStateModelerParams params_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_
diff --git a/ink_stroke_modeler/internal/stylus_state_modeler_test.cc b/ink_stroke_modeler/internal/stylus_state_modeler_test.cc
new file mode 100644
index 0000000..c905ec4
--- /dev/null
+++ b/ink_stroke_modeler/internal/stylus_state_modeler_test.cc
@@ -0,0 +1,269 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/stylus_state_modeler.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+constexpr float kTol = 1e-5;
+constexpr StylusState kUnknown{.pressure = -1, .tilt = -1, .orientation = -1};
+
+TEST(StylusStateModelerTest, QueryEmpty) {
+ StylusStateModeler modeler;
+ EXPECT_EQ(modeler.Query({0, 0}), kUnknown);
+ EXPECT_EQ(modeler.Query({-5, 3}), kUnknown);
+}
+
+TEST(StylusStateModelerTest, QuerySingleInput) {
+ StylusStateModeler modeler;
+ modeler.Update({0, 0}, {.pressure = 0.75, .tilt = 0.75, .orientation = 0.75});
+ EXPECT_THAT(modeler.Query({0, 0}),
+ StylusStateNear(
+ {.pressure = .75, .tilt = .75, .orientation = .75}, kTol));
+ EXPECT_THAT(modeler.Query({1, 1}),
+ StylusStateNear(
+ {.pressure = .75, .tilt = .75, .orientation = .75}, kTol));
+}
+
+TEST(StylusStateModelerTest, QueryMultipleInputs) {
+ StylusStateModeler modeler;
+ modeler.Update({.5, 1.5}, {.pressure = .3, .tilt = .8, .orientation = .1});
+ modeler.Update({2, 1.5}, {.pressure = .6, .tilt = .5, .orientation = .7});
+ modeler.Update({3, 3.5}, {.pressure = .8, .tilt = .1, .orientation = .3});
+ modeler.Update({3.5, 4}, {.pressure = .2, .tilt = .2, .orientation = .2});
+
+ EXPECT_THAT(
+ modeler.Query({0, 2}),
+ StylusStateNear({.pressure = .3, .tilt = .8, .orientation = .1}, kTol));
+ EXPECT_THAT(
+ modeler.Query({1, 2}),
+ StylusStateNear({.pressure = .4, .tilt = .7, .orientation = .3}, kTol));
+ EXPECT_THAT(
+ modeler.Query({2, 1.5}),
+ StylusStateNear({.pressure = .6, .tilt = .5, .orientation = .7}, kTol));
+ EXPECT_THAT(
+ modeler.Query({2.5, 1.875}),
+ StylusStateNear({.pressure = .65, .tilt = .4, .orientation = .6}, kTol));
+ EXPECT_THAT(
+ modeler.Query({2.5, 3.125}),
+ StylusStateNear({.pressure = .75, .tilt = .2, .orientation = .4}, kTol));
+ EXPECT_THAT(
+ modeler.Query({2.5, 4}),
+ StylusStateNear({.pressure = .8, .tilt = .1, .orientation = .3}, kTol));
+ EXPECT_THAT(
+ modeler.Query({3, 4}),
+ StylusStateNear({.pressure = .5, .tilt = .15, .orientation = .25}, kTol));
+ EXPECT_THAT(
+ modeler.Query({4, 4}),
+ StylusStateNear({.pressure = .2, .tilt = .2, .orientation = .2}, kTol));
+}
+
+TEST(StylusStateModelerTest, QueryStaleInputsAreDiscarded) {
+ StylusStateModeler modeler;
+ modeler.Update({1, 1}, {.pressure = .6, .tilt = .5, .orientation = .4});
+ modeler.Update({-1, 2}, {.pressure = .3, .tilt = .7, .orientation = .6});
+ modeler.Update({-4, 0}, {.pressure = .9, .tilt = .7, .orientation = .3});
+ modeler.Update({-6, -3}, {.pressure = .4, .tilt = .3, .orientation = .5});
+ modeler.Update({-5, -5}, {.pressure = .3, .tilt = .3, .orientation = .1});
+ modeler.Update({-3, -4}, {.pressure = .6, .tilt = .8, .orientation = .3});
+ modeler.Update({-6, -7}, {.pressure = .9, .tilt = .8, .orientation = .1});
+ modeler.Update({-9, -8}, {.pressure = .8, .tilt = .2, .orientation = .2});
+ modeler.Update({-11, -5}, {.pressure = .2, .tilt = .4, .orientation = .7});
+ modeler.Update({-10, -2}, {.pressure = .7, .tilt = .3, .orientation = .2});
+
+ EXPECT_THAT(
+ modeler.Query({2, 0}),
+ StylusStateNear({.pressure = .6, .tilt = .5, .orientation = .4}, kTol));
+ EXPECT_THAT(
+ modeler.Query({1, 3.5}),
+ StylusStateNear({.pressure = .45, .tilt = .6, .orientation = .5}, kTol));
+ EXPECT_THAT(
+ modeler.Query({-3, 17. / 6}),
+ StylusStateNear({.pressure = .5, .tilt = .7, .orientation = .5}, kTol));
+
+ // This causes the point at {1, 1} to be discarded.
+ modeler.Update({-8, 0}, {.pressure = .6, .tilt = .8, .orientation = .9});
+ EXPECT_THAT(
+ modeler.Query({2, 0}),
+ StylusStateNear({.pressure = .3, .tilt = .7, .orientation = .6}, kTol));
+ EXPECT_THAT(
+ modeler.Query({1, 3.5}),
+ StylusStateNear({.pressure = .3, .tilt = .7, .orientation = .6}, kTol));
+ EXPECT_THAT(
+ modeler.Query({-3, 17. / 6}),
+ StylusStateNear({.pressure = .5, .tilt = .7, .orientation = .5}, kTol));
+
+ // This causes the point at {-1, 2} to be discarded.
+ modeler.Update({-8, 0}, {.pressure = .6, .tilt = .8, .orientation = .9});
+ EXPECT_THAT(
+ modeler.Query({2, 0}),
+ StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol));
+ EXPECT_THAT(
+ modeler.Query({1, 3.5}),
+ StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol));
+ EXPECT_THAT(
+ modeler.Query({-3, 17. / 6}),
+ StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol));
+}
+
+TEST(StylusStateModelerTest, QueryCyclicOrientationInterpolation) {
+ StylusStateModeler modeler;
+ modeler.Update({0, 0}, {.pressure = 0, .tilt = 0, .orientation = 1.8 * M_PI});
+ modeler.Update({0, 1}, {.pressure = 0, .tilt = 0, .orientation = .2 * M_PI});
+ modeler.Update({0, 2}, {.pressure = 0, .tilt = 0, .orientation = 1.6 * M_PI});
+
+ EXPECT_NEAR(modeler.Query({0, .25}).orientation, 1.9 * M_PI, 1e-5);
+ EXPECT_NEAR(modeler.Query({0, .75}).orientation, .1 * M_PI, 1e-5);
+ EXPECT_NEAR(modeler.Query({0, 1.25}).orientation, .05 * M_PI, 1e-5);
+ EXPECT_NEAR(modeler.Query({0, 1.75}).orientation, 1.75 * M_PI, 1e-5);
+}
+
+TEST(StylusStateModelerTest, QueryAndReset) {
+ StylusStateModeler modeler;
+
+ modeler.Update({4, 5}, {.pressure = .4, .tilt = .9, .orientation = .1});
+ modeler.Update({7, 8}, {.pressure = .1, .tilt = .2, .orientation = .5});
+ EXPECT_THAT(
+ modeler.Query({10, 12}),
+ StylusStateNear({.pressure = .1, .tilt = .2, .orientation = .5}, kTol));
+
+ modeler.Reset(StylusStateModelerParams{});
+ EXPECT_EQ(modeler.Query({10, 12}), kUnknown);
+
+ modeler.Update({-1, 4}, {.pressure = .4, .tilt = .6, .orientation = .8});
+ EXPECT_THAT(
+ modeler.Query({6, 7}),
+ StylusStateNear({.pressure = .4, .tilt = .6, .orientation = .8}, kTol));
+
+ modeler.Update({-3, 0}, {.pressure = .7, .tilt = .2, .orientation = .5});
+ EXPECT_THAT(
+ modeler.Query({-2, 2}),
+ StylusStateNear({.pressure = .55, .tilt = .4, .orientation = .65}, kTol));
+ EXPECT_THAT(
+ modeler.Query({0, 5}),
+ StylusStateNear({.pressure = .4, .tilt = .6, .orientation = .8}, kTol));
+}
+
+TEST(StylusStateModelerTest, UpdateWithUnknownState) {
+ StylusStateModeler modeler;
+
+ modeler.Update({1, 2}, {.pressure = .1, .tilt = .2, .orientation = .3});
+ modeler.Update({2, 3}, {.pressure = .3, .tilt = .4, .orientation = .5});
+ EXPECT_THAT(
+ modeler.Query({2, 2}),
+ StylusStateNear({.pressure = .2, .tilt = .3, .orientation = .4}, kTol));
+
+ modeler.Update({5, 5}, kUnknown);
+ EXPECT_EQ(modeler.Query({5, 5}), kUnknown);
+
+ modeler.Update({2, 3}, {.pressure = .3, .tilt = .4, .orientation = .5});
+ EXPECT_EQ(modeler.Query({1, 2}), kUnknown);
+
+ modeler.Update({-1, 3}, kUnknown);
+ EXPECT_EQ(modeler.Query({7, 9}), kUnknown);
+
+ modeler.Reset(StylusStateModelerParams{});
+ modeler.Update({3, 3}, {.pressure = .7, .tilt = .6, .orientation = .5});
+ EXPECT_THAT(
+ modeler.Query({3, 3}),
+ StylusStateNear({.pressure = .7, .tilt = .6, .orientation = .5}, kTol));
+}
+
+TEST(StylusStateModelerTest, ModelPressureOnly) {
+ StylusStateModeler modeler;
+
+ modeler.Update({0, 0}, {.pressure = .5, .tilt = -2, .orientation = -.1});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = .5, .tilt = -1, .orientation = -1}, kTol));
+
+ modeler.Update({2, 0}, {.pressure = .7, .tilt = -2, .orientation = -.1});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = .6, .tilt = -1, .orientation = -1}, kTol));
+}
+
+TEST(StylusStateModelerTest, ModelTiltOnly) {
+ StylusStateModeler modeler;
+
+ modeler.Update({0, 0}, {.pressure = -2, .tilt = .5, .orientation = -.1});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = -1, .tilt = .5, .orientation = -1}, kTol));
+
+ modeler.Update({2, 0}, {.pressure = -2, .tilt = .3, .orientation = -.1});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = -1, .tilt = .4, .orientation = -1}, kTol));
+}
+
+TEST(StylusStateModelerTest, ModelOrientationOnly) {
+ StylusStateModeler modeler;
+
+ modeler.Update({0, 0}, {.pressure = -2, .tilt = -.1, .orientation = 1});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = -1, .tilt = -1, .orientation = 1}, kTol));
+
+ modeler.Update({2, 0}, {.pressure = -2, .tilt = -.3, .orientation = 2});
+ EXPECT_THAT(
+ modeler.Query({1, 1}),
+ StylusStateNear({.pressure = -1, .tilt = -1, .orientation = 1.5}, kTol));
+}
+
+TEST(StylusStateModelerTest, DropFieldsOneByOne) {
+ StylusStateModeler modeler;
+
+ modeler.Update({0, 0}, {.pressure = .5, .tilt = .5, .orientation = .5});
+ EXPECT_THAT(
+ modeler.Query({1, 0}),
+ StylusStateNear({.pressure = .5, .tilt = .5, .orientation = .5}, kTol));
+
+ modeler.Update({2, 0}, {.pressure = .3, .tilt = .7, .orientation = -1});
+ EXPECT_THAT(
+ modeler.Query({1, 0}),
+ StylusStateNear({.pressure = .4, .tilt = .6, .orientation = -1}, kTol));
+
+ modeler.Update({4, 0}, {.pressure = .1, .tilt = -1, .orientation = 1});
+ EXPECT_THAT(
+ modeler.Query({3, 0}),
+ StylusStateNear({.pressure = .2, .tilt = -1, .orientation = -1}, kTol));
+
+ modeler.Update({6, 0}, {.pressure = -1, .tilt = .2, .orientation = 0});
+ EXPECT_THAT(modeler.Query({5, 0}), StylusStateNear(kUnknown, kTol));
+
+ modeler.Update({8, 0}, {.pressure = .3, .tilt = .4, .orientation = .5});
+ EXPECT_THAT(modeler.Query({7, 0}), StylusStateNear(kUnknown, kTol));
+
+ modeler.Reset(StylusStateModelerParams{});
+ EXPECT_THAT(modeler.Query({1, 0}), StylusStateNear(kUnknown, kTol));
+
+ modeler.Update({0, 0}, {.pressure = .1, .tilt = .8, .orientation = .3});
+ EXPECT_THAT(
+ modeler.Query({1, 0}),
+ StylusStateNear({.pressure = .1, .tilt = .8, .orientation = .3}, kTol));
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/type_matchers.cc b/ink_stroke_modeler/internal/type_matchers.cc
new file mode 100644
index 0000000..f974b61
--- /dev/null
+++ b/ink_stroke_modeler/internal/type_matchers.cc
@@ -0,0 +1,69 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/type_matchers.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::DoubleNear;
+using ::testing::FloatEq;
+using ::testing::FloatNear;
+using ::testing::Matcher;
+using ::testing::Matches;
+
+MATCHER_P(Vec2EqMatcher, expected, "") {
+ return Matches(FloatEq(expected.x))(arg.x) &&
+ Matches(FloatEq(expected.y))(arg.y);
+}
+
+MATCHER_P2(Vec2NearMatcher, expected, tolerance, "") {
+ return Matches(FloatNear(expected.x, tolerance))(arg.x) &&
+ Matches(FloatNear(expected.y, tolerance))(arg.y);
+}
+MATCHER_P2(TipStateNearMatcher, expected, tolerance, "") {
+ return Matches(Vec2Near(expected.position, tolerance))(arg.position) &&
+ Matches(Vec2Near(expected.velocity, tolerance))(arg.velocity) &&
+ Matches(DoubleNear(expected.time.Value(), tolerance))(
+ arg.time.Value());
+}
+
+MATCHER_P2(StylusStateNearMatcher, expected, tolerance, "") {
+ return Matches(FloatNear(expected.pressure, tolerance))(arg.pressure) &&
+ Matches(FloatNear(expected.tilt, tolerance))(arg.tilt) &&
+ Matches(FloatNear(expected.orientation, tolerance))(arg.orientation);
+}
+
+} // namespace
+
+Matcher<Vec2> Vec2Eq(const Vec2 v) { return Vec2EqMatcher(v); }
+Matcher<Vec2> Vec2Near(const Vec2 v, float tolerance) {
+ return Vec2NearMatcher(v, tolerance);
+}
+Matcher<TipState> TipStateNear(const TipState &expected, float tolerance) {
+ return TipStateNearMatcher(expected, tolerance);
+}
+Matcher<StylusState> StylusStateNear(const StylusState &expected,
+ float tolerance) {
+ return StylusStateNearMatcher(expected, tolerance);
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/type_matchers.h b/ink_stroke_modeler/internal/type_matchers.h
new file mode 100644
index 0000000..97f8876
--- /dev/null
+++ b/ink_stroke_modeler/internal/type_matchers.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_
+#define INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_
+
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// These matchers compare Vec2s component-wise, delegating to
+// ::testing::FloatEq() and ::testing::FloatNear(), respectively.
+::testing::Matcher<Vec2> Vec2Eq(const Vec2 v);
+::testing::Matcher<Vec2> Vec2Near(const Vec2 v, float tolerance);
+
+// These convenience matchers perform comparisons using ::testing::FloatNear(),
+// ::testing::DoubleNear(), and Vec2Near().
+::testing::Matcher<TipState> TipStateNear(const TipState &expected,
+ float tolerance);
+::testing::Matcher<StylusState> StylusStateNear(const StylusState &expected,
+ float tolerance);
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_
diff --git a/ink_stroke_modeler/internal/utils.h b/ink_stroke_modeler/internal/utils.h
new file mode 100644
index 0000000..6bf2a27
--- /dev/null
+++ b/ink_stroke_modeler/internal/utils.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_UTILS_H_
+#define INK_STROKE_MODELER_INTERNAL_UTILS_H_
+
+#include <cmath>
+
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// General utility functions for use within the stroke model.
+
+// Clamps the given value to the range [0, 1].
+inline float Clamp01(float value) {
+ if (value < 0.f) return 0.f;
+ if (value > 1.f) return 1.f;
+ return value;
+}
+
+// Returns the ratio of the difference from `start` to `value` and the
+// difference from `start` to `end`, clamped to the range [0, 1]. If
+// `start` == `end`, returns 1 if `value` > `start`, 0 otherwise.
+inline float Normalize01(float start, float end, float value) {
+ if (start == end) {
+ return value > start ? 1 : 0;
+ }
+ return Clamp01((value - start) / (end - start));
+}
+
+// Linearly interpolates between `start` and `end`, clamping the interpolation
+// value to the range [0, 1].
+template <typename ValueType>
+inline ValueType Interp(ValueType start, ValueType end, float interp_amount) {
+ return start + (end - start) * Clamp01(interp_amount);
+}
+
+// Linearly interpolates from `start` to `end`, traveling around the shorter
+// path (e.g. interpolating from π/4 to 7π/4 is equivalent to interpolating from
+// π/4 to 0, then 2π to 7π/4). The returned angle will be normalized to the
+// interval [0, 2π). All angles are measured in radians.
+inline float InterpAngle(float start, float end, float interp_amount) {
+ auto normalize_angle = [](float angle) {
+ while (angle < 0) angle += 2 * M_PI;
+ while (angle > 2 * M_PI) angle -= 2 * M_PI;
+ return angle;
+ };
+
+ start = normalize_angle(start);
+ end = normalize_angle(end);
+ float delta = end - start;
+ if (delta < -M_PI) {
+ end += 2 * M_PI;
+ } else if (delta > M_PI) {
+ end -= 2 * M_PI;
+ }
+ return normalize_angle(Interp(start, end, interp_amount));
+}
+
+// Returns the distance between two points.
+inline float Distance(Vec2 start, Vec2 end) {
+ return (end - start).Magnitude();
+}
+
+// Returns the point on the line segment from `segment_start` to `segment_end`
+// that is closest to `point`, represented as the ratio of the length along the
+// segment.
+inline float NearestPointOnSegment(Vec2 segment_start, Vec2 segment_end,
+ Vec2 point) {
+ if (segment_start == segment_end) return 0;
+
+ auto dot_product = [](Vec2 lhs, Vec2 rhs) {
+ return lhs.x * rhs.x + lhs.y * rhs.y;
+ };
+ Vec2 segment_vector = segment_end - segment_start;
+ Vec2 projection_vector = point - segment_start;
+ return Clamp01(dot_product(projection_vector, segment_vector) /
+ dot_product(segment_vector, segment_vector));
+}
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_UTILS_H_
diff --git a/ink_stroke_modeler/internal/utils_test.cc b/ink_stroke_modeler/internal/utils_test.cc
new file mode 100644
index 0000000..b50d21a
--- /dev/null
+++ b/ink_stroke_modeler/internal/utils_test.cc
@@ -0,0 +1,87 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/utils.h"
+
+#include <cmath>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+TEST(UtilsTest, Clamp01) {
+ EXPECT_FLOAT_EQ(Clamp01(-2), 0);
+ EXPECT_FLOAT_EQ(Clamp01(0), 0);
+ EXPECT_FLOAT_EQ(Clamp01(.3), .3);
+ EXPECT_FLOAT_EQ(Clamp01(.7), .7);
+ EXPECT_FLOAT_EQ(Clamp01(1), 1);
+ EXPECT_FLOAT_EQ(Clamp01(1.1), 1);
+}
+
+TEST(UtilsTest, Normalize01) {
+ EXPECT_FLOAT_EQ(Normalize01(1, 2, 1.5), .5);
+ EXPECT_FLOAT_EQ(Normalize01(7, 3, 4), .75);
+ EXPECT_FLOAT_EQ(Normalize01(-1, 1, 2), 1);
+ EXPECT_FLOAT_EQ(Normalize01(1, 1, 1), 0);
+ EXPECT_FLOAT_EQ(Normalize01(1, 1, 0), 0);
+ EXPECT_FLOAT_EQ(Normalize01(1, 1, 2), 1);
+}
+
+TEST(UtilsTest, InterpFloat) {
+ EXPECT_FLOAT_EQ(Interp(5, 10, .2), 6);
+ EXPECT_FLOAT_EQ(Interp(10, -2, .75), 1);
+ EXPECT_FLOAT_EQ(Interp(-1, 2, -3), -1);
+ EXPECT_FLOAT_EQ(Interp(5, 7, 20), 7);
+}
+
+TEST(UtilsTest, InterpVec2) {
+ EXPECT_THAT(Interp(Vec2{1, 2}, {3, 5}, .5), Vec2Eq({2, 3.5}));
+ EXPECT_THAT(Interp(Vec2{-5, 5}, {-15, 0}, .4), Vec2Eq({-9, 3}));
+ EXPECT_THAT(Interp(Vec2{7, 9}, {25, 30}, -.1), Vec2Eq({7, 9}));
+ EXPECT_THAT(Interp(Vec2{12, 5}, {13, 14}, 3.2), Vec2Eq({13, 14}));
+}
+
+TEST(UtilsTest, InterpAngle) {
+ EXPECT_NEAR(InterpAngle(.25 * M_PI, .5 * M_PI, .4), .35 * M_PI, 1e-6);
+ EXPECT_NEAR(InterpAngle(1.05 * M_PI, .25 * M_PI, .5), .65 * M_PI, 1e-6);
+ EXPECT_NEAR(InterpAngle(.25 * M_PI, 1.75 * M_PI, .1), .2 * M_PI, 1e-6);
+ EXPECT_NEAR(InterpAngle(.25 * M_PI, 1.75 * M_PI, .7), 1.9 * M_PI, 1e-6);
+ EXPECT_NEAR(InterpAngle(1.6 * M_PI, .4 * M_PI, .25), 1.8 * M_PI, 1e-6);
+ EXPECT_NEAR(InterpAngle(1.6 * M_PI, .4 * M_PI, .625), .1 * M_PI, 1e-6);
+}
+
+TEST(UtilsTest, Distance) {
+ EXPECT_FLOAT_EQ(Distance({0, 0}, {1, 0}), 1);
+ EXPECT_FLOAT_EQ(Distance({1, 1}, {-2, 5}), 5);
+}
+
+TEST(UtilsTest, NearestPointOnSegment) {
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 0}, {1, 0}, {.25, .5}), .25);
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({3, 4}, {5, 6}, {-1, -1}), 0);
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({20, 10}, {10, 5}, {2, 2}), 1);
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 5}, {5, 0}, {3, 3}), .5);
+}
+
+TEST(UtilsTest, NearestPointOnSegmentDegenerateCase) {
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 0}, {0, 0}, {5, 10}), 0);
+ EXPECT_FLOAT_EQ(NearestPointOnSegment({3, 7}, {3, 7}, {0, -20}), 0);
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/validation.h b/ink_stroke_modeler/internal/validation.h
new file mode 100644
index 0000000..e01aa1a
--- /dev/null
+++ b/ink_stroke_modeler/internal/validation.h
@@ -0,0 +1,48 @@
+#ifndef INK_STROKE_MODELER_INTERNAL_VALIDATION_H_
+#define INK_STROKE_MODELER_INTERNAL_VALIDATION_H_
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+
+template <typename T>
+absl::Status ValidateIsFiniteNumber(T value, absl::string_view label) {
+ if (std::isnan(value)) {
+ return absl::InvalidArgumentError(absl::Substitute("$0 is NaN", label));
+ }
+ if (std::isinf(value)) {
+ return absl::InvalidArgumentError(
+ absl::Substitute("$0 is infinite", label));
+ }
+ return absl::OkStatus();
+}
+
+template <typename T>
+absl::Status ValidateGreaterThanZero(T value, absl::string_view label) {
+ if (absl::Status status = ValidateIsFiniteNumber(value, label);
+ !status.ok()) {
+ return status;
+ }
+ if (value <= 0) {
+ return absl::InvalidArgumentError(absl::Substitute(
+ "$0 must be greater than zero. Actual value: $1", label, value));
+ }
+ return absl::OkStatus();
+}
+
+template <typename T>
+absl::Status ValidateGreaterThanOrEqualToZero(T value,
+ absl::string_view label) {
+ if (absl::Status status = ValidateIsFiniteNumber(value, label);
+ !status.ok()) {
+ return status;
+ }
+ if (value < 0) {
+ return absl::InvalidArgumentError(absl::Substitute(
+ "$0 must be greater than or equal to zero. Actual value: $1", label,
+ value));
+ }
+ return absl::OkStatus();
+}
+
+#endif // INK_STROKE_MODELER_INTERNAL_VALIDATION_H_
diff --git a/ink_stroke_modeler/internal/validation_test.cc b/ink_stroke_modeler/internal/validation_test.cc
new file mode 100644
index 0000000..51482ff
--- /dev/null
+++ b/ink_stroke_modeler/internal/validation_test.cc
@@ -0,0 +1,82 @@
+#include "ink_stroke_modeler/internal/validation.h"
+
+#include <cmath>
+
+#include "gtest/gtest.h"
+
+namespace {
+
+TEST(ValidateIsFiniteNumberTest, AcceptFinite) {
+ ASSERT_TRUE(ValidateIsFiniteNumber(1, "foo").ok());
+}
+
+TEST(ValidateIsFiniteNumberTest, RejectNan) {
+ absl::Status status = ValidateIsFiniteNumber(NAN, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is NaN");
+}
+
+TEST(ValidateIsFiniteNumberTest, RejectInf) {
+ absl::Status status = ValidateIsFiniteNumber(INFINITY, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is infinite");
+}
+
+TEST(ValidateGreaterThanZeroTest, AcceptPositive) {
+ ASSERT_TRUE(ValidateGreaterThanZero(1, "foo").ok());
+}
+
+TEST(ValidateGreaterThanZeroTest, RejectZero) {
+ absl::Status status = ValidateGreaterThanZero(0, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo must be greater than zero. Actual value: 0");
+}
+
+TEST(ValidateGreaterThanZeroTest, RejectNegative) {
+ absl::Status status = ValidateGreaterThanZero(-1, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(),
+ "foo must be greater than zero. Actual value: -1");
+}
+
+TEST(ValidateGreaterThanZeroTest, RejectNan) {
+ absl::Status status = ValidateGreaterThanZero(NAN, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is NaN");
+}
+
+TEST(ValidateGreaterThanZeroTest, RejectInf) {
+ absl::Status status = ValidateGreaterThanZero(INFINITY, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is infinite");
+}
+
+TEST(ValidateGreaterThanOrEqualToZeroTest, AcceptPositive) {
+ ASSERT_TRUE(ValidateGreaterThanOrEqualToZero(1, "foo").ok());
+}
+
+TEST(ValidateGreaterThanOrEqualToZeroTest, AcceptZero) {
+ absl::Status status = ValidateGreaterThanOrEqualToZero(0, "foo");
+ ASSERT_TRUE(ValidateGreaterThanOrEqualToZero(0, "foo").ok());
+}
+
+TEST(ValidateGreaterThanOrEqualToZeroTest, RejectNegative) {
+ absl::Status status = ValidateGreaterThanOrEqualToZero(-1, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(),
+ "foo must be greater than or equal to zero. Actual value: -1");
+}
+
+TEST(ValidateGreaterThanOrEqualToZeroTest, RejectNan) {
+ absl::Status status = ValidateGreaterThanOrEqualToZero(NAN, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is NaN");
+}
+
+TEST(ValidateGreaterThanOrEqualToZeroTest, RejectInf) {
+ absl::Status status = ValidateGreaterThanOrEqualToZero(INFINITY, "foo");
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.message(), "foo is infinite");
+}
+
+} // namespace
diff --git a/ink_stroke_modeler/internal/wobble_smoother.cc b/ink_stroke_modeler/internal/wobble_smoother.cc
new file mode 100644
index 0000000..a403ecb
--- /dev/null
+++ b/ink_stroke_modeler/internal/wobble_smoother.cc
@@ -0,0 +1,70 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/wobble_smoother.h"
+
+#include <algorithm>
+
+#include "ink_stroke_modeler/internal/utils.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+void WobbleSmoother::Reset(const WobbleSmootherParams& params, Vec2 position,
+ Time time) {
+ params_ = params;
+ samples_.clear();
+ // Initialize with the "fast" speed -- otherwise, we'll lag behind at the
+ // start of the stroke.
+ position_sum_ = position;
+ speed_sum_ = params_.speed_ceiling;
+ samples_.push_back(
+ {.position = position, .speed = params_.speed_ceiling, .time = time});
+}
+
+Vec2 WobbleSmoother::Update(Vec2 position, Time time) {
+ // The moving average acts as a low-pass signal filter, removing
+ // high-frequency fluctuations in the position caused by the discrete nature
+ // of the touch digitizer. To compensate for the distance between the average
+ // position and the actual position, we interpolate between them, based on
+ // speed, to determine the position to use for the input model.
+ float distance = Distance(position, samples_.back().position);
+ Duration delta_time = time - samples_.back().time;
+ float speed = 0;
+ if (delta_time == Duration(0)) {
+ // We're going to assume that you're not actually moving infinitely fast.
+ speed = std::max(params_.speed_ceiling, speed_sum_ / samples_.size());
+ } else {
+ speed = distance / delta_time.Value();
+ }
+
+ samples_.push_back({.position = position, .speed = speed, .time = time});
+ position_sum_ += position;
+ speed_sum_ += speed;
+ while (samples_.front().time < time - params_.timeout) {
+ position_sum_ -= samples_.front().position;
+ speed_sum_ -= samples_.front().speed;
+ samples_.pop_front();
+ }
+
+ Vec2 avg_position = position_sum_ / samples_.size();
+ float avg_speed = speed_sum_ / samples_.size();
+ return Interp(
+ avg_position, position,
+ Normalize01(params_.speed_floor, params_.speed_ceiling, avg_speed));
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/internal/wobble_smoother.h b/ink_stroke_modeler/internal/wobble_smoother.h
new file mode 100644
index 0000000..7eb85b8
--- /dev/null
+++ b/ink_stroke_modeler/internal/wobble_smoother.h
@@ -0,0 +1,57 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_
+#define INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_
+
+#include <deque>
+
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class smooths "wobble" in input positions from high-frequency noise. It
+// does so by maintaining a moving average of the positions, and interpolating
+// between the given input and the moving average based on how quickly it's
+// moving. When moving at a speed above the ceiling in the WobbleSmootherParams,
+// the result will be the unmodified input; when moving at a speed below the
+// floor, the result will be the moving average.
+class WobbleSmoother {
+ public:
+ void Reset(const WobbleSmootherParams &params, Vec2 position, Time time);
+
+ // Updates the average position and speed, and returns the smoothed position.
+ Vec2 Update(Vec2 position, Time time);
+
+ private:
+ struct Sample {
+ Vec2 position{0};
+ float speed{0};
+ Time time{0};
+ };
+ std::deque<Sample> samples_;
+ Vec2 position_sum_{0};
+ float speed_sum_{0};
+
+ WobbleSmootherParams params_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_
diff --git a/ink_stroke_modeler/internal/wobble_smoother_test.cc b/ink_stroke_modeler/internal/wobble_smoother_test.cc
new file mode 100644
index 0000000..7bffd37
--- /dev/null
+++ b/ink_stroke_modeler/internal/wobble_smoother_test.cc
@@ -0,0 +1,104 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/wobble_smoother.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+const WobbleSmootherParams kDefaultParams{
+ .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.44};
+
+TEST(WobbleSmootherTest, SlowStraightLine) {
+ // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s.
+ WobbleSmoother filter;
+ filter.Reset(kDefaultParams, {3, 4}, Time{1});
+ EXPECT_THAT(filter.Update({3.016, 4}, Time{1.016}), Vec2Eq({3.008, 4}));
+ EXPECT_THAT(filter.Update({3.032, 4}, Time{1.032}), Vec2Eq({3.016, 4}));
+ EXPECT_THAT(filter.Update({3.048, 4}, Time{1.048}), Vec2Eq({3.032, 4}));
+ EXPECT_THAT(filter.Update({3.064, 4}, Time{1.064}), Vec2Eq({3.048, 4}));
+}
+
+TEST(WobbleSmootherTest, SlowStraightLineEqualFloorAndCeiling) {
+ // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s.
+ WobbleSmootherParams equal_floor_and_ceiling_params{
+ .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.31};
+ WobbleSmoother filter;
+ filter.Reset(equal_floor_and_ceiling_params, {3, 4}, Time{1});
+ EXPECT_THAT(filter.Update({3.016, 4}, Time{1.016}), Vec2Eq({3.008, 4}));
+ EXPECT_THAT(filter.Update({3.032, 4}, Time{1.032}), Vec2Eq({3.016, 4}));
+ EXPECT_THAT(filter.Update({3.048, 4}, Time{1.048}), Vec2Eq({3.032, 4}));
+ EXPECT_THAT(filter.Update({3.064, 4}, Time{1.064}), Vec2Eq({3.048, 4}));
+}
+
+TEST(WobbleSmootherTest, FastStraightLine) {
+ // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s.
+ WobbleSmoother filter;
+ filter.Reset(kDefaultParams, {-1, 0}, Time{0});
+ EXPECT_THAT(filter.Update({-1, .024}, Time{.016}), Vec2Eq({-1, .024}));
+ EXPECT_THAT(filter.Update({-1, .048}, Time{.032}), Vec2Eq({-1, .048}));
+ EXPECT_THAT(filter.Update({-1, .072}, Time{.048}), Vec2Eq({-1, .072}));
+}
+
+TEST(WobbleSmootherTest, FastStraightLineEqualFloorAndCeiling) {
+ // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s.
+ WobbleSmoother filter;
+ WobbleSmootherParams equal_floor_and_ceiling_params{
+ .timeout = Duration(.04), .speed_floor = 1.41, .speed_ceiling = 1.41};
+ filter.Reset(equal_floor_and_ceiling_params, {-1, 0}, Time{0});
+ EXPECT_THAT(filter.Update({-1, .024}, Time{.016}), Vec2Eq({-1, .024}));
+ EXPECT_THAT(filter.Update({-1, .048}, Time{.032}), Vec2Eq({-1, .048}));
+ EXPECT_THAT(filter.Update({-1, .072}, Time{.048}), Vec2Eq({-1, .072}));
+}
+
+TEST(WobbleSmootherTest, SlowZigZag) {
+ // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s.
+ WobbleSmoother filter;
+ filter.Reset(kDefaultParams, {1, 2}, Time{5});
+ EXPECT_THAT(filter.Update({1.016, 2}, Time{5.016}), Vec2Eq({1.008, 2}));
+ EXPECT_THAT(filter.Update({1.016, 2.016}, Time{5.032}),
+ Vec2Eq({1.0106667, 2.0053333}));
+ EXPECT_THAT(filter.Update({1.032, 2.016}, Time{5.048}),
+ Vec2Eq({1.0213333, 2.0106667}));
+ EXPECT_THAT(filter.Update({1.032, 2.032}, Time{5.064}),
+ Vec2Eq({1.0266667, 2.0213333}));
+ EXPECT_THAT(filter.Update({1.048, 2.032}, Time{5.080}),
+ Vec2Eq({1.0373333, 2.0266667}));
+ EXPECT_THAT(filter.Update({1.048, 2.048}, Time{5.096}),
+ Vec2Eq({1.0426667, 2.0373333}));
+}
+
+TEST(WobbleSmootherTest, FastZigZag) {
+ // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s.
+ WobbleSmoother filter;
+ filter.Reset(kDefaultParams, {7, 3}, Time{8});
+ EXPECT_THAT(filter.Update({7, 3.024}, Time{8.016}), Vec2Eq({7, 3.024}));
+ EXPECT_THAT(filter.Update({7.024, 3.024}, Time{8.032}),
+ Vec2Eq({7.024, 3.024}));
+ EXPECT_THAT(filter.Update({7.024, 3.048}, Time{8.048}),
+ Vec2Eq({7.024, 3.048}));
+ EXPECT_THAT(filter.Update({7.048, 3.048}, Time{8.064}),
+ Vec2Eq({7.048, 3.048}));
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/params.cc b/ink_stroke_modeler/params.cc
new file mode 100644
index 0000000..99542e1
--- /dev/null
+++ b/ink_stroke_modeler/params.cc
@@ -0,0 +1,157 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/params.h"
+
+#include <cmath>
+
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/substitute.h"
+#include "absl/types/variant.h"
+#include "ink_stroke_modeler/internal/validation.h"
+
+// This convenience macro evaluates the given expression, and if it does not
+// return an OK status, returns and propagates the status.
+#define RETURN_IF_ERROR(expr) \
+ do { \
+ if (auto status = (expr); !status.ok()) return status; \
+ } while (false)
+
+namespace ink {
+namespace stroke_model {
+
+absl::Status ValidatePositionModelerParams(
+ const PositionModelerParams& params) {
+ RETURN_IF_ERROR(ValidateGreaterThanZero(params.spring_mass_constant,
+ "PredictionParams::spring_mass"));
+ return ValidateGreaterThanZero(params.drag_constant,
+ "PredictionParams::drag_ratio");
+}
+
+absl::Status ValidateSamplingParams(const SamplingParams& params) {
+ RETURN_IF_ERROR(ValidateGreaterThanZero(params.min_output_rate,
+ "PredictionParams::min_output_rate"));
+ RETURN_IF_ERROR(ValidateGreaterThanZero(
+ params.end_of_stroke_stopping_distance,
+ "PredictionParams::end_of_stroke_stopping_distance"));
+ return ValidateGreaterThanZero(
+ params.end_of_stroke_max_iterations,
+ "PredictionParams::end_of_stroke_stopping_distance");
+}
+
+absl::Status ValidateStylusStateModelerParams(
+ const StylusStateModelerParams& params) {
+ return ValidateGreaterThanZero(params.max_input_samples,
+ "StylusStateModelerParams::max_input_samples");
+}
+
+absl::Status ValidateWobbleSmootherParams(const WobbleSmootherParams& params) {
+ RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero(
+ params.timeout.Value(), "WobbleSmootherParams::timeout"));
+ RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero(
+ params.speed_floor, "WobbleSmootherParams::speed_floor"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(
+ params.speed_ceiling, "WobbleSmootherParams::speed_ceiling"));
+ if (params.speed_ceiling < params.speed_floor) {
+ return absl::InvalidArgumentError(absl::Substitute(
+ "WobbleSmootherParams::speed_ceiling must be greater than or "
+ "equal to WobbleSmootherParams::speed_floor ($0). Actual "
+ "value: $1",
+ params.speed_floor, params.speed_ceiling));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ValidatePredictionParams(const PredictionParams& params) {
+ if (absl::holds_alternative<StrokeEndPredictorParams>(params)) {
+ // Nothing to validate.
+ return absl::OkStatus();
+ }
+
+ const KalmanPredictorParams& kalman_params =
+ absl::get<KalmanPredictorParams>(params);
+ RETURN_IF_ERROR(ValidateGreaterThanZero(
+ kalman_params.process_noise, "KalmanPredictorParams::process_noise"));
+ RETURN_IF_ERROR(
+ ValidateGreaterThanZero(kalman_params.measurement_noise,
+ "KalmanPredictorParams::measurement_noise"));
+ RETURN_IF_ERROR(
+ ValidateGreaterThanZero(kalman_params.min_stable_iteration,
+ "KalmanPredictorParams::min_stable_iteration"));
+ RETURN_IF_ERROR(
+ ValidateGreaterThanZero(kalman_params.max_time_samples,
+ "KalmanPredictorParams::max_time_samples"));
+ RETURN_IF_ERROR(
+ ValidateGreaterThanZero(kalman_params.min_catchup_velocity,
+ "KalmanPredictorParams::min_catchup_velocity"));
+ RETURN_IF_ERROR(
+ ValidateIsFiniteNumber(kalman_params.acceleration_weight,
+ "KalmanPredictorParams::acceleration_weight"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(kalman_params.jerk_weight,
+ "KalmanPredictorParams::jerk_weight"));
+ RETURN_IF_ERROR(
+ ValidateGreaterThanZero(kalman_params.prediction_interval.Value(),
+ "KalmanPredictorParams::jerk_weight"));
+
+ const KalmanPredictorParams::ConfidenceParams& confidence_params =
+ kalman_params.confidence_params;
+ RETURN_IF_ERROR(ValidateGreaterThanZero(
+ confidence_params.desired_number_of_samples,
+ "KalmanPredictorParams::ConfidenceParams::desired_number_of_samples"));
+ RETURN_IF_ERROR(ValidateGreaterThanZero(
+ confidence_params.max_estimation_distance,
+ "KalmanPredictorParams::ConfidenceParams::max_estimation_distance"));
+ RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero(
+ confidence_params.min_travel_speed,
+ "KalmanPredictorParams::ConfidenceParams::min_travel_speed"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(
+ confidence_params.max_travel_speed,
+ "KalmanPredictorParams::ConfidenceParams::max_travel_speed"));
+ if (confidence_params.max_travel_speed < confidence_params.min_travel_speed) {
+ return absl::InvalidArgumentError(
+ absl::Substitute("KalmanPredictorParams::ConfidenceParams::max_"
+ "travel_speed must be greater than or equal to "
+ "KalmanPredictorParams::ConfidenceParams::min_"
+ "travel_speed ($0). Actual value: $1",
+ confidence_params.min_travel_speed,
+ confidence_params.max_travel_speed));
+ }
+ RETURN_IF_ERROR(ValidateGreaterThanZero(
+ confidence_params.max_linear_deviation,
+ "KalmanPredictorParams::ConfidenceParams::max_linear_deviation"));
+ if (confidence_params.baseline_linearity_confidence < 0 ||
+ confidence_params.baseline_linearity_confidence > 1) {
+ return absl::InvalidArgumentError(absl::Substitute(
+ "KalmanPredictorParams::ConfidenceParams::baseline_linearity_"
+ "confidence must lie in the interval [0, 1]. Actual value: $0",
+ confidence_params.baseline_linearity_confidence));
+ }
+ return absl::OkStatus();
+}
+
+absl::Status ValidateStrokeModelParams(const StrokeModelParams& params) {
+ RETURN_IF_ERROR(ValidateWobbleSmootherParams(params.wobble_smoother_params));
+ RETURN_IF_ERROR(
+ ValidatePositionModelerParams(params.position_modeler_params));
+ RETURN_IF_ERROR(ValidateSamplingParams(params.sampling_params));
+ RETURN_IF_ERROR(
+ ValidateStylusStateModelerParams(params.stylus_state_modeler_params));
+ return ValidatePredictionParams(params.prediction_params);
+}
+
+} // namespace stroke_model
+} // namespace ink
+
+#undef RETURN_IF_ERROR
diff --git a/ink_stroke_modeler/params.h b/ink_stroke_modeler/params.h
new file mode 100644
index 0000000..102edba
--- /dev/null
+++ b/ink_stroke_modeler/params.h
@@ -0,0 +1,224 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_PARAMS_H_
+#define INK_STROKE_MODELER_PARAMS_H_
+
+#include "absl/status/status.h"
+#include "absl/types/variant.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// These structs contain parameters for tuning the behavior of the stroke
+// modeler.
+//
+// The stroke modeler is unit-agnostic, in both time and space. That is, the
+// stroke modeler does not know or care whether the inputs and parameters are
+// specified in feet and minutes, meters and seconds, or millimeters and years.
+// As such, instead of referring to specific units, we refer to "unit distance"
+// and "unit time".
+//
+// These parameters will need to be "tuned" to your use case. Because of this,
+// and because of the modeler's unit-agnosticism, it's impossible to define
+// "reasonable" default values for many of the parameters -- these parameters
+// instead default to -1, which will cause the validation functions to return an
+// error.
+//
+// Where possible, we've indicated what a good starting point for tuning might
+// be, but you'll likely need to adjust these for best results.
+
+// These parameters are used for modeling the position of the pen.
+struct PositionModelerParams {
+ // The mass of the "weight" being pulled along the path, multiplied by the
+ // spring constant.
+ float spring_mass_constant = 11.f / 32400;
+
+ // The ratio of the pen's velocity that is subtracted from the pen's
+ // acceleration, to simulate drag.
+ float drag_constant = 72.f;
+};
+
+// These parameters are used for sampling.
+struct SamplingParams {
+ // The minimum number of modeled inputs to output per unit time. If inputs are
+ // received at a lower rate, they will be upsampled to produce output of at
+ // least min_output_rate. If inputs are received at a higher rate, the
+ // output rate will match the input rate.
+ double min_output_rate = -1;
+
+ // This determines stop condition for end-of-stroke modeling; if the position
+ // is within this distance of the final raw input, or if the last update
+ // iteration moved less than this distance, it stop iterating.
+ //
+ // This should be a small distance; a good starting point is 2-3 orders of
+ // magnitude smaller than the expected distance between input points.
+ float end_of_stroke_stopping_distance = -1;
+
+ // The maximum number of iterations to perform at the end of the stroke, if it
+ // does not stop due to the constraints of end_of_stroke_stopping_distance.
+ int end_of_stroke_max_iterations = 20;
+};
+
+// These parameters are used modeling the state of the stylus once the position
+// has been modeled.
+struct StylusStateModelerParams {
+ // The maximum number of raw inputs to look at when finding the nearest states
+ // for interpolation.
+ int max_input_samples = 10;
+};
+
+// These parameters are used for applying smoothing to the input to reduce
+// wobble in the prediction.
+struct WobbleSmootherParams {
+ // The length of the window over which the moving average of speed and
+ // position are calculated.
+ //
+ // A good starting point is 2.5 divided by the expected number of inputs per
+ // unit time.
+ Duration timeout{-1};
+
+ // The range of speeds considered for wobble smoothing. At speed_floor, the
+ // maximum amount of smoothing is applied. At speed_ceiling, no smoothing is
+ // applied.
+ //
+ // Good starting points are 2% and 3% of the expected speed of the inputs.
+ float speed_floor = -1;
+ float speed_ceiling = -1;
+};
+
+// This struct indicates the "stroke end" prediction strategy should be used,
+// which models a prediction as though the last seen input was the
+// end-of-stroke. There aren't actually any tunable parameters for this; it uses
+// the same PositionModelerParams and SamplingParams as the overall model. Note
+// that this "prediction" doesn't actually predict substantially into the
+// future, it only allows for very quickly "catching up" to the position of the
+// raw input.
+struct StrokeEndPredictorParams {};
+
+// This struct indicates that the Kalman filter-based prediction strategy should
+// be used, and provides the parameters for tuning it.
+//
+// Unlike the "stroke end" predictor, this strategy can predict an extension
+// of the stroke beyond the last Input position, in addition to the "catch up"
+// step.
+struct KalmanPredictorParams {
+ // The variance of the noise inherent to the stroke itself.
+ double process_noise = -1;
+
+ // The variance of the noise that rises from errors in measurement of the
+ // stroke.
+ double measurement_noise = -1;
+
+ // The minimum number of inputs received before the Kalman predictor is
+ // considered stable enough to make a prediction.
+ int min_stable_iteration = 4;
+
+ // The Kalman filter assumes that input is received in uniform time steps, but
+ // this is not always the case. We hold on to the most recent input timestamps
+ // for use in calculating the correction for this. This determines the maximum
+ // number of timestamps to save.
+ int max_time_samples = 20;
+
+ // The minimum allowed velocity of the "catch up" portion of the prediction,
+ // which covers the distance between the last Result (the last corrected
+ // position) and the
+ //
+ // A good starting point is 3 orders of magnitude smaller than the expected
+ // speed of the inputs.
+ float min_catchup_velocity = -1;
+
+ // These weights are applied to the acceleration (x²) and jerk (x³) terms of
+ // the cubic prediction polynomial. The closer they are to zero, the more
+ // linear the prediction will be.
+ float acceleration_weight = .5;
+ float jerk_weight = .1;
+
+ // This value is a hint to the predictor, indicating the desired duration of
+ // of the portion of the prediction extending beyond the position of the last
+ // input. The actual duration of that portion of the prediction may be less
+ // than this, based on the predictor's confidence, but it will never be
+ // greater.
+ Duration prediction_interval{-1};
+
+ // The Kalman predictor uses several heuristics to evaluate confidence in the
+ // prediction. Each heuristic produces a confidence value between 0 and 1, and
+ // then we take their product as the total confidence.
+ // These parameters may be used to tune those heuristics.
+ struct ConfidenceParams {
+ // The first heuristic simply increases confidence as we receive more sample
+ // (i.e. input points). It evaluates to 0 at no samples, and 1 at
+ // desired_number_of_samples.
+ int desired_number_of_samples = 20;
+
+ // The second heuristic is based on the distance between the last sample
+ // and the current estimate. If the distance is 0, it evaluates to 1, and if
+ // the distance is greater than or equal to max_estimation_distance, it
+ // evaluates to 0.
+ //
+ // A good starting point is 1.5 times measurement_noise.
+ float max_estimation_distance = -1;
+
+ // The third heuristic is based on the speed of the prediction, which is
+ // approximated by measuring the from the start of the prediction to the
+ // projected endpoint (if it were extended for the full
+ // prediction_interval). It evaluates to 0 at min_travel_speed, and 1
+ // at max_travel_speed.
+ //
+ // Good starting points are 5% and 25% of the expected speed of the inputs.
+ float min_travel_speed = -1;
+ float max_travel_speed = -1;
+
+ // The fourth heuristic is based on the linearity of the prediction, which
+ // is approximated by comparing the endpoint of the prediction with the
+ // endpoint of a linear prediction (again, extended for the full
+ // prediction_interval). It evaluates to 1 at zero distance, and
+ // baseline_linearity_confidence at a distance of max_linear_deviation.
+ //
+ // A good starting point is an 10 times the measurement_noise.
+ float max_linear_deviation = -1;
+ float baseline_linearity_confidence = .4;
+ };
+ ConfidenceParams confidence_params;
+};
+using PredictionParams =
+ absl::variant<StrokeEndPredictorParams, KalmanPredictorParams>;
+
+// This convenience struct is a collection of the parameters for the individual
+// parameter structs.
+struct StrokeModelParams {
+ WobbleSmootherParams wobble_smoother_params;
+ PositionModelerParams position_modeler_params;
+ SamplingParams sampling_params;
+ StylusStateModelerParams stylus_state_modeler_params;
+ PredictionParams prediction_params = StrokeEndPredictorParams{};
+};
+
+// These validation functions will return an error if the given parameters are
+// invalid.
+absl::Status ValidatePositionModelerParams(const PositionModelerParams& params);
+absl::Status ValidateSamplingParams(const SamplingParams& params);
+absl::Status ValidateStylusStateModelerParams(
+ const StylusStateModelerParams& params);
+absl::Status ValidateWobbleSmootherParams(const WobbleSmootherParams& params);
+absl::Status ValidatePredictionParams(const PredictionParams& params);
+absl::Status ValidateStrokeModelParams(const StrokeModelParams& params);
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_PARAMS_H_
diff --git a/ink_stroke_modeler/params_test.cc b/ink_stroke_modeler/params_test.cc
new file mode 100644
index 0000000..15bad8a
--- /dev/null
+++ b/ink_stroke_modeler/params_test.cc
@@ -0,0 +1,259 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/params.h"
+
+#include <limits>
+
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+const KalmanPredictorParams kGoodKalmanParams{
+ .process_noise = .01,
+ .measurement_noise = .1,
+ .min_stable_iteration = 2,
+ .max_time_samples = 10,
+ .min_catchup_velocity = 1,
+ .acceleration_weight = -1,
+ .jerk_weight = 200,
+ .prediction_interval{Duration(1)},
+ .confidence_params{.desired_number_of_samples = 10,
+ .max_estimation_distance = 1,
+ .min_travel_speed = 6,
+ .max_travel_speed = 50,
+ .max_linear_deviation = 2,
+ .baseline_linearity_confidence = .5}};
+
+const StrokeModelParams kGoodStrokeModelParams{
+ .wobble_smoother_params{
+ .timeout = Duration(.5), .speed_floor = 1, .speed_ceiling = 20},
+ .position_modeler_params{.spring_mass_constant = .2, .drag_constant = 4},
+ .sampling_params{.min_output_rate = 3,
+ .end_of_stroke_stopping_distance = 1e-6,
+ .end_of_stroke_max_iterations = 1},
+ .stylus_state_modeler_params{.max_input_samples = 7},
+ .prediction_params = StrokeEndPredictorParams{}};
+
+TEST(ParamsTest, ValidatePositionModelerParams) {
+ EXPECT_TRUE(ValidatePositionModelerParams(
+ {.spring_mass_constant = 1, .drag_constant = 3})
+ .ok());
+
+ EXPECT_EQ(ValidatePositionModelerParams(
+ {.spring_mass_constant = 0, .drag_constant = 1})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+ EXPECT_EQ(ValidatePositionModelerParams(
+ {.spring_mass_constant = 1, .drag_constant = 0})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(ParamsTest, ValidateSamplingParams) {
+ EXPECT_TRUE(ValidateSamplingParams({.min_output_rate = 10,
+ .end_of_stroke_stopping_distance = .1,
+ .end_of_stroke_max_iterations = 3})
+ .ok());
+
+ EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 0,
+ .end_of_stroke_stopping_distance = .1,
+ .end_of_stroke_max_iterations = 3})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+ EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 1,
+ .end_of_stroke_stopping_distance = 0,
+ .end_of_stroke_max_iterations = 3})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+ EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 1,
+ .end_of_stroke_stopping_distance = 5,
+ .end_of_stroke_max_iterations = 0})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(ParamsTest, ValidateStylusStateModelerParams) {
+ EXPECT_TRUE(ValidateStylusStateModelerParams({.max_input_samples = 1}).ok());
+
+ EXPECT_EQ(ValidateStylusStateModelerParams({.max_input_samples = 0}).code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(ParamsTest, ValidateWobbleSmootherParams) {
+ EXPECT_TRUE(
+ ValidateWobbleSmootherParams(
+ {.timeout = Duration(1), .speed_floor = 2, .speed_ceiling = 3})
+ .ok());
+ EXPECT_TRUE(
+ ValidateWobbleSmootherParams(
+ {.timeout = Duration(0), .speed_floor = 0, .speed_ceiling = 0})
+ .ok());
+
+ EXPECT_EQ(ValidateWobbleSmootherParams(
+ {.timeout = Duration(-1), .speed_floor = 2, .speed_ceiling = 5})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+ EXPECT_EQ(ValidateWobbleSmootherParams(
+ {.timeout = Duration(1), .speed_floor = -2, .speed_ceiling = 1})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+ EXPECT_EQ(ValidateWobbleSmootherParams(
+ {.timeout = Duration(1), .speed_floor = 7, .speed_ceiling = 4})
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(ParamsTest, ValidateStrokeEndPredictorParams) {
+ EXPECT_TRUE(ValidatePredictionParams(StrokeEndPredictorParams()).ok());
+}
+
+TEST(ParamsTest, ValidateKalmanPredictorParams) {
+ EXPECT_TRUE(ValidatePredictionParams(kGoodKalmanParams).ok());
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.process_noise = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.measurement_noise = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.min_stable_iteration = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.max_time_samples = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.prediction_interval = Duration(0);
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+}
+
+TEST(ParamsTest, ValidateKalmanPredictorConfidenceParams) {
+ EXPECT_TRUE(ValidatePredictionParams(kGoodKalmanParams).ok());
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.desired_number_of_samples = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.max_estimation_distance = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.min_travel_speed = -1;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.min_travel_speed = 10;
+ bad_params.confidence_params.max_travel_speed = 1;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.max_linear_deviation = 0;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.baseline_linearity_confidence = -.3;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodKalmanParams;
+ bad_params.confidence_params.baseline_linearity_confidence = 1.01;
+ EXPECT_EQ(ValidatePredictionParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+}
+
+TEST(ParamsTest, ValidateStrokeModelParams) {
+ EXPECT_TRUE(ValidateStrokeModelParams(kGoodStrokeModelParams).ok());
+ {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.wobble_smoother_params.timeout = Duration(-10);
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.position_modeler_params.spring_mass_constant = -1;
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.stylus_state_modeler_params.max_input_samples = 0;
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.sampling_params.end_of_stroke_max_iterations = -3;
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+ {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.prediction_params =
+ KalmanPredictorParams{.prediction_interval = Duration(-1)};
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+ }
+}
+
+TEST(ParamsTest, NaNIsNotAValidValue) {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.position_modeler_params.spring_mass_constant =
+ std::numeric_limits<float>::quiet_NaN();
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(ParamsTest, InfinityIsNotAValidValue) {
+ auto bad_params = kGoodStrokeModelParams;
+ bad_params.position_modeler_params.spring_mass_constant =
+ std::numeric_limits<float>::infinity();
+ EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/stroke_modeler.cc b/ink_stroke_modeler/stroke_modeler.cc
new file mode 100644
index 0000000..37d17f5
--- /dev/null
+++ b/ink_stroke_modeler/stroke_modeler.cc
@@ -0,0 +1,253 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/stroke_modeler.h"
+
+#include <iterator>
+#include <type_traits>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/memory/memory.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/optional.h"
+#include "absl/types/variant.h"
+#include "ink_stroke_modeler/internal/internal_types.h"
+#include "ink_stroke_modeler/internal/position_modeler.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h"
+#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h"
+#include "ink_stroke_modeler/internal/stylus_state_modeler.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+std::vector<Result> ModelStylus(
+ const std::vector<TipState> &tip_states,
+ const StylusStateModeler &stylus_state_modeler) {
+ std::vector<Result> result;
+ result.reserve(tip_states.size());
+ for (const auto &tip_state : tip_states) {
+ auto stylus_state = stylus_state_modeler.Query(tip_state.position);
+ result.push_back({.position = tip_state.position,
+ .velocity = tip_state.velocity,
+ .time = tip_state.time,
+ .pressure = stylus_state.pressure,
+ .tilt = stylus_state.tilt,
+ .orientation = stylus_state.orientation});
+ }
+ return result;
+}
+
+int GetNumberOfSteps(Time start_time, Time end_time, double min_rate) {
+ float float_delta = (end_time - start_time).Value();
+ return std::ceil(float_delta * min_rate);
+}
+
+template <typename>
+ABSL_ATTRIBUTE_UNUSED inline constexpr bool kAlwaysFalse = false;
+
+} // namespace
+
+absl::Status StrokeModeler::Reset(
+ const StrokeModelParams &stroke_model_params) {
+ if (auto status = ValidateStrokeModelParams(stroke_model_params);
+ !status.ok()) {
+ return status;
+ }
+
+ // Note that many of the sub-modelers require some knowledge about the stroke
+ // (e.g. start position, input type) when resetting, and as such are reset in
+ // ProcessTDown() instead.
+ stroke_model_params_ = stroke_model_params;
+ last_input_ = absl::nullopt;
+
+ absl::visit(
+ [this](auto &&params) {
+ using ParamType = std::decay_t<decltype(params)>;
+ if constexpr (std::is_same_v<ParamType, KalmanPredictorParams>) {
+ predictor_ = absl::make_unique<KalmanPredictor>(
+ params, stroke_model_params_->sampling_params);
+ } else if constexpr (std::is_same_v<ParamType,
+ StrokeEndPredictorParams>) {
+ predictor_ = absl::make_unique<StrokeEndPredictor>(
+ stroke_model_params_->position_modeler_params,
+ stroke_model_params_->sampling_params);
+ } else {
+ static_assert(kAlwaysFalse<ParamType>,
+ "Unknown prediction parameter type");
+ }
+ },
+ stroke_model_params_->prediction_params);
+ return absl::OkStatus();
+}
+
+absl::StatusOr<std::vector<Result>> StrokeModeler::Update(const Input &input) {
+ if (!stroke_model_params_.has_value()) {
+ return absl::FailedPreconditionError(
+ "Stroke model has not yet been initialized");
+ }
+
+ if (absl::Status status = ValidateInput(input); !status.ok()) {
+ return status;
+ }
+
+ if (last_input_) {
+ if (last_input_->input == input) {
+ return absl::InvalidArgumentError("Received duplicate input");
+ }
+
+ if (input.time < last_input_->input.time) {
+ return absl::InvalidArgumentError("Inputs travel backwards in time");
+ }
+ }
+
+ switch (input.event_type) {
+ case Input::EventType::kDown:
+ return ProcessDownEvent(input);
+ case Input::EventType::kMove:
+ return ProcessMoveEvent(input);
+ case Input::EventType::kUp:
+ return ProcessUpEvent(input);
+ }
+ return absl::InvalidArgumentError("Invalid EventType.");
+}
+
+absl::StatusOr<std::vector<Result>> StrokeModeler::Predict() const {
+ if (!stroke_model_params_.has_value()) {
+ return absl::FailedPreconditionError(
+ "Stroke model has not yet been initialized");
+ }
+
+ if (last_input_ == std::nullopt) {
+ return absl::FailedPreconditionError(
+ "Cannot construct prediction when no stroke is in-progress");
+ }
+
+ return ModelStylus(
+ predictor_->ConstructPrediction(position_modeler_.CurrentState()),
+ stylus_state_modeler_);
+}
+
+absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessDownEvent(
+ const Input &input) {
+ if (last_input_) {
+ return absl::FailedPreconditionError(
+ "Received down event while stroke is in-progress");
+ }
+
+ // Note that many of the sub-modelers require some knowledge about the stroke
+ // (e.g. start position, input type) when resetting, and as such are reset
+ // here instead of in Reset().
+ wobble_smoother_.Reset(stroke_model_params_->wobble_smoother_params,
+ input.position, input.time);
+ position_modeler_.Reset({input.position, {0, 0}, input.time},
+ stroke_model_params_->position_modeler_params);
+ stylus_state_modeler_.Reset(
+ stroke_model_params_->stylus_state_modeler_params);
+ stylus_state_modeler_.Update(input.position,
+ {.pressure = input.pressure,
+ .tilt = input.tilt,
+ .orientation = input.orientation});
+
+ const TipState &tip_state = position_modeler_.CurrentState();
+ predictor_->Reset();
+ predictor_->Update(input.position, input.time);
+
+ // We don't correct the position on the down event, so we set
+ // corrected_position to use the input position.
+ last_input_ = {.input = input, .corrected_position = input.position};
+ return {{{.position = tip_state.position,
+ .velocity = tip_state.velocity,
+ .time = tip_state.time,
+ .pressure = input.pressure,
+ .tilt = input.tilt,
+ .orientation = input.orientation}}};
+}
+
+absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessUpEvent(
+ const Input &input) {
+ if (!last_input_) {
+ return absl::FailedPreconditionError(
+ "Received up event while no stroke is in-progress");
+ }
+
+ int n_steps =
+ GetNumberOfSteps(last_input_->input.time, input.time,
+ stroke_model_params_->sampling_params.min_output_rate);
+ std::vector<TipState> tip_states;
+ tip_states.reserve(
+ n_steps +
+ stroke_model_params_->sampling_params.end_of_stroke_max_iterations);
+ position_modeler_.UpdateAlongLinearPath(
+ last_input_->corrected_position, last_input_->input.time, input.position,
+ input.time, n_steps, std::back_inserter(tip_states));
+
+ position_modeler_.ModelEndOfStroke(
+ input.position,
+ Duration(1. / stroke_model_params_->sampling_params.min_output_rate),
+ stroke_model_params_->sampling_params.end_of_stroke_max_iterations,
+ stroke_model_params_->sampling_params.end_of_stroke_stopping_distance,
+ std::back_inserter(tip_states));
+
+ if (tip_states.empty()) {
+ // If we haven't generated any new states, add the current state. This can
+ // happen if the TUp has the same timestamp as the last in-contact input.
+ tip_states.push_back(position_modeler_.CurrentState());
+ }
+
+ stylus_state_modeler_.Update(input.position,
+ {.pressure = input.pressure,
+ .tilt = input.tilt,
+ .orientation = input.orientation});
+
+ // This indicates that we've finished the stroke.
+ last_input_ = absl::nullopt;
+
+ return ModelStylus(tip_states, stylus_state_modeler_);
+}
+
+absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessMoveEvent(
+ const Input &input) {
+ if (!last_input_) {
+ return absl::FailedPreconditionError(
+ "Received move event while no stroke is in-progress");
+ }
+
+ Vec2 corrected_position = wobble_smoother_.Update(input.position, input.time);
+ stylus_state_modeler_.Update(corrected_position,
+ {.pressure = input.pressure,
+ .tilt = input.tilt,
+ .orientation = input.orientation});
+
+ int n_steps =
+ GetNumberOfSteps(last_input_->input.time, input.time,
+ stroke_model_params_->sampling_params.min_output_rate);
+ std::vector<TipState> tip_states;
+ tip_states.reserve(n_steps);
+ position_modeler_.UpdateAlongLinearPath(
+ last_input_->corrected_position, last_input_->input.time,
+ corrected_position, input.time, n_steps, std::back_inserter(tip_states));
+
+ predictor_->Update(corrected_position, input.time);
+ last_input_ = {.input = input, .corrected_position = corrected_position};
+ return ModelStylus(tip_states, stylus_state_modeler_);
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/stroke_modeler.h b/ink_stroke_modeler/stroke_modeler.h
new file mode 100644
index 0000000..a25b971
--- /dev/null
+++ b/ink_stroke_modeler/stroke_modeler.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_STROKE_MODELER_H_
+#define INK_STROKE_MODELER_STROKE_MODELER_H_
+
+#include <memory>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/types/optional.h"
+#include "ink_stroke_modeler/internal/position_modeler.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/stylus_state_modeler.h"
+#include "ink_stroke_modeler/internal/wobble_smoother.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+
+// This class models a stroke from a raw input stream. The modeling is performed
+// in several stages, which are delegated to component classes:
+// - Wobble Smoothing: Dampens high-frequency noise from quantization error.
+// - Position Modeling: Models the pen tip as a mass, connected by a spring, to
+// a moving anchor.
+// - Stylus State Modeling: Constructs stylus states for modeled positions by
+// interpolating over the raw input.
+//
+// Additionally, this class provides prediction of the modeled stroke.
+//
+// StrokeModeler is completely unit-agnostic. That is, it doesn't matter what
+// units or coordinate-system the input is given in; the output will be given in
+// the same coordinate-system and units.
+class StrokeModeler {
+ public:
+ // Clears any in-progress stroke, and initializes (or re-initializes) the
+ // model with the given parameters. Returns an error if the parameters are
+ // invalid.
+ absl::Status Reset(const StrokeModelParams &stroke_model_params);
+
+ // Updates the model with a raw input, returning the generated results. Any
+ // previously generated results are stable, i.e. any previously returned
+ // Results are still valid.
+ //
+ // Returns an error if the the model has not yet been initialized (via Reset)
+ // or if the input stream is malformed (e.g decreasing time, Up event before
+ // Down event).
+ //
+ // If this does not return an error, the result will contain at least one
+ // Result, and potentially more than one if the inputs are slower than
+ // the minimum output rate.
+ absl::StatusOr<std::vector<Result>> Update(const Input &input);
+
+ // Model the given input prediction without changing the internal model state.
+ //
+ // Returns an error if the the model has not yet been initialized (via Reset),
+ // or if there is no stroke in progress. The output is limited to results
+ // where the predictor has sufficient confidence,
+ absl::StatusOr<std::vector<Result>> Predict() const;
+
+ private:
+ absl::StatusOr<std::vector<Result>> ProcessDownEvent(const Input &input);
+ absl::StatusOr<std::vector<Result>> ProcessMoveEvent(const Input &input);
+ absl::StatusOr<std::vector<Result>> ProcessUpEvent(const Input &input);
+
+ std::unique_ptr<InputPredictor> predictor_;
+
+ absl::optional<StrokeModelParams> stroke_model_params_;
+
+ WobbleSmoother wobble_smoother_;
+ PositionModeler position_modeler_;
+ StylusStateModeler stylus_state_modeler_;
+
+ struct InputAndCorrectedPosition {
+ Input input;
+ Vec2 corrected_position{0};
+ };
+ absl::optional<InputAndCorrectedPosition> last_input_;
+};
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_STROKE_MODELER_H_
diff --git a/ink_stroke_modeler/stroke_modeler_test.cc b/ink_stroke_modeler/stroke_modeler_test.cc
new file mode 100644
index 0000000..00f1dba
--- /dev/null
+++ b/ink_stroke_modeler/stroke_modeler_test.cc
@@ -0,0 +1,1395 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/stroke_modeler.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/status/status.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::DoubleNear;
+using ::testing::ElementsAre;
+using ::testing::FloatNear;
+using ::testing::IsEmpty;
+using ::testing::Matches;
+using ::testing::Not;
+
+constexpr float kTol = 1e-4;
+
+// These parameters use cm for distance and seconds for time.
+const StrokeModelParams kDefaultParams{
+ .wobble_smoother_params{
+ .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.44},
+ .position_modeler_params{.spring_mass_constant = 11.f / 32400,
+ .drag_constant = 72.f},
+ .sampling_params{.min_output_rate = 180,
+ .end_of_stroke_stopping_distance = .001,
+ .end_of_stroke_max_iterations = 20},
+ .stylus_state_modeler_params{.max_input_samples = 20},
+ .prediction_params = StrokeEndPredictorParams()};
+
+MATCHER_P2(ResultNearMatcher, expected, tolerance, "") {
+ if (Matches(Vec2Near(expected.position, tolerance))(arg.position) &&
+ Matches(Vec2Near(expected.velocity, tolerance))(arg.velocity) &&
+ Matches(DoubleNear(expected.time.Value(), tolerance))(arg.time.Value()) &&
+ Matches(FloatNear(expected.pressure, tolerance))(arg.pressure) &&
+ Matches(FloatNear(expected.tilt, tolerance))(arg.tilt) &&
+ Matches(FloatNear(expected.orientation, tolerance))(arg.orientation)) {
+ return true;
+ }
+
+ return false;
+}
+
+::testing::Matcher<Result> ResultNear(const Result &expected, float tolerance) {
+ return ResultNearMatcher(expected, tolerance);
+}
+
+TEST(StrokeModelerTest, NoPredictionUponInit) {
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+ EXPECT_EQ(modeler.Predict().status().code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, InputRateSlowerThanMinOutputRate) {
+ const Duration kDeltaTime{1. / 30};
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ Time time{0};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {3, 4},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(ResultNear(
+ {.position = {3, 4}, .velocity = {0, 0}, .time = Time(0)}, kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, IsEmpty());
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {3.2, 4.2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.0019, 4.0019},
+ .velocity = {0.4007, 0.4007},
+ .time = Time(0.0048)},
+ kTol),
+ ResultNear({.position = {3.0069, 4.0069},
+ .velocity = {1.0381, 1.0381},
+ .time = Time(0.0095)},
+ kTol),
+ ResultNear({.position = {3.0154, 4.0154},
+ .velocity = {1.7883, 1.7883},
+ .time = Time(0.0143)},
+ kTol),
+ ResultNear({.position = {3.0276, 4.0276},
+ .velocity = {2.5626, 2.5626},
+ .time = Time(0.0190)},
+ kTol),
+ ResultNear({.position = {3.0433, 4.0433},
+ .velocity = {3.3010, 3.3010},
+ .time = Time(0.0238)},
+ kTol),
+ ResultNear({.position = {3.0622, 4.0622},
+ .velocity = {3.9665, 3.9665},
+ .time = Time(0.0286)},
+ kTol),
+ ResultNear({.position = {3.0838, 4.0838},
+ .velocity = {4.5397, 4.5397},
+ .time = Time(0.0333)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.1095, 4.1095},
+ .velocity = {4.6253, 4.6253},
+ .time = Time(0.0389)},
+ kTol),
+ ResultNear({.position = {3.1331, 4.1331},
+ .velocity = {4.2563, 4.2563},
+ .time = Time(0.0444)},
+ kTol),
+ ResultNear({.position = {3.1534, 4.1534},
+ .velocity = {3.6479, 3.6479},
+ .time = Time(0.0500)},
+ kTol),
+ ResultNear({.position = {3.1698, 4.1698},
+ .velocity = {2.9512, 2.9512},
+ .time = Time(0.0556)},
+ kTol),
+ ResultNear({.position = {3.1824, 4.1824},
+ .velocity = {2.2649, 2.2649},
+ .time = Time(0.0611)},
+ kTol),
+ ResultNear({.position = {3.1915, 4.1915},
+ .velocity = {1.6473, 1.6473},
+ .time = Time(0.0667)},
+ kTol),
+ ResultNear({.position = {3.1978, 4.1978},
+ .velocity = {1.1269, 1.1269},
+ .time = Time(0.0722)},
+ kTol),
+ ResultNear({.position = {3.1992, 4.1992},
+ .velocity = {1.0232, 1.0232},
+ .time = Time(0.0736)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {3.5, 4.2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.1086, 4.1058},
+ .velocity = {5.2142, 4.6131},
+ .time = Time(0.0381)},
+ kTol),
+ ResultNear({.position = {3.1368, 4.1265},
+ .velocity = {5.9103, 4.3532},
+ .time = Time(0.0429)},
+ kTol),
+ ResultNear({.position = {3.1681, 4.1450},
+ .velocity = {6.5742, 3.8917},
+ .time = Time(0.0476)},
+ kTol),
+ ResultNear({.position = {3.2022, 4.1609},
+ .velocity = {7.1724, 3.3285},
+ .time = Time(0.0524)},
+ kTol),
+ ResultNear({.position = {3.2388, 4.1739},
+ .velocity = {7.6876, 2.7361},
+ .time = Time(0.0571)},
+ kTol),
+ ResultNear({.position = {3.2775, 4.1842},
+ .velocity = {8.1138, 2.1640},
+ .time = Time(0.0619)},
+ kTol),
+ ResultNear({.position = {3.3177, 4.1920},
+ .velocity = {8.4531, 1.6436},
+ .time = Time(0.0667)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.3625, 4.1982},
+ .velocity = {8.0545, 1.1165},
+ .time = Time(0.0722)},
+ kTol),
+ ResultNear({.position = {3.4018, 4.2021},
+ .velocity = {7.0831, 0.6987},
+ .time = Time(0.0778)},
+ kTol),
+ ResultNear({.position = {3.4344, 4.2043},
+ .velocity = {5.8564, 0.3846},
+ .time = Time(0.0833)},
+ kTol),
+ ResultNear({.position = {3.4598, 4.2052},
+ .velocity = {4.5880, 0.1611},
+ .time = Time(0.0889)},
+ kTol),
+ ResultNear({.position = {3.4788, 4.2052},
+ .velocity = {3.4098, 0.0124},
+ .time = Time(0.0944)},
+ kTol),
+ ResultNear({.position = {3.4921, 4.2048},
+ .velocity = {2.3929, -0.0780},
+ .time = Time(0.1000)},
+ kTol),
+ ResultNear({.position = {3.4976, 4.2045},
+ .velocity = {1.9791, -0.1015},
+ .time = Time(0.1028)},
+ kTol),
+ ResultNear({.position = {3.5001, 4.2044},
+ .velocity = {1.7911, -0.1098},
+ .time = Time(0.1042)},
+ kTol)));
+
+ time += kDeltaTime;
+ // We get more results at the end of the stroke as it tries to "catch up" to
+ // the raw input.
+ results = modeler.Update({.event_type = Input::EventType::kUp,
+ .position = {3.7, 4.4},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.3583, 4.1996},
+ .velocity = {8.5122, 1.5925},
+ .time = Time(0.0714)},
+ kTol),
+ ResultNear({.position = {3.3982, 4.2084},
+ .velocity = {8.3832, 1.8534},
+ .time = Time(0.0762)},
+ kTol),
+ ResultNear({.position = {3.4369, 4.2194},
+ .velocity = {8.1393, 2.3017},
+ .time = Time(0.0810)},
+ kTol),
+ ResultNear({.position = {3.4743, 4.2329},
+ .velocity = {7.8362, 2.8434},
+ .time = Time(0.0857)},
+ kTol),
+ ResultNear({.position = {3.5100, 4.2492},
+ .velocity = {7.5143, 3.4101},
+ .time = Time(0.0905)},
+ kTol),
+ ResultNear({.position = {3.5443, 4.2680},
+ .velocity = {7.2016, 3.9556},
+ .time = Time(0.0952)},
+ kTol),
+ ResultNear({.position = {3.5773, 4.2892},
+ .velocity = {6.9159, 4.4505},
+ .time = Time(0.1000)},
+ kTol),
+ ResultNear({.position = {3.6115, 4.3141},
+ .velocity = {6.1580, 4.4832},
+ .time = Time(0.1056)},
+ kTol),
+ ResultNear({.position = {3.6400, 4.3369},
+ .velocity = {5.1434, 4.0953},
+ .time = Time(0.1111)},
+ kTol),
+ ResultNear({.position = {3.6626, 4.3563},
+ .velocity = {4.0671, 3.4902},
+ .time = Time(0.1167)},
+ kTol),
+ ResultNear({.position = {3.6796, 4.3719},
+ .velocity = {3.0515, 2.8099},
+ .time = Time(0.1222)},
+ kTol),
+ ResultNear({.position = {3.6916, 4.3838},
+ .velocity = {2.1648, 2.1462},
+ .time = Time(0.1278)},
+ kTol),
+ ResultNear({.position = {3.6996, 4.3924},
+ .velocity = {1.4360, 1.5529},
+ .time = Time(0.1333)},
+ kTol),
+ ResultNear({.position = {3.7028, 4.3960},
+ .velocity = {1.1520, 1.3044},
+ .time = Time(0.1361)},
+ kTol)));
+
+ // The stroke is finished, so there's nothing to predict anymore.
+ EXPECT_EQ(modeler.Predict().status().code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, InputRateFasterThanMinOutputRate) {
+ const Duration kDeltaTime{1. / 300};
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ Time time{2};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {5, -3},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(ResultNear(
+ {.position = {5, -3}, .velocity = {0, 0}, .time = Time(2)}, kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, IsEmpty());
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {5, -3.1},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {5, -3.0033},
+ .velocity = {0, -0.9818},
+ .time = Time(2.0033)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {5, -3.0153},
+ .velocity = {0, -2.1719},
+ .time = Time(2.0089)},
+ kTol),
+ ResultNear({.position = {5, -3.0303},
+ .velocity = {0, -2.6885},
+ .time = Time(2.0144)},
+ kTol),
+ ResultNear({.position = {5, -3.0456},
+ .velocity = {0, -2.7541},
+ .time = Time(2.0200)},
+ kTol),
+ ResultNear({.position = {5, -3.0597},
+ .velocity = {0, -2.5430},
+ .time = Time(2.0256)},
+ kTol),
+ ResultNear({.position = {5, -3.0718},
+ .velocity = {0, -2.1852},
+ .time = Time(2.0311)},
+ kTol),
+ ResultNear({.position = {5, -3.0817},
+ .velocity = {0, -1.7719},
+ .time = Time(2.0367)},
+ kTol),
+ ResultNear({.position = {5, -3.0893},
+ .velocity = {0, -1.3628},
+ .time = Time(2.0422)},
+ kTol),
+ ResultNear({.position = {5, -3.0948},
+ .velocity = {0, -0.9934},
+ .time = Time(2.0478)},
+ kTol),
+ ResultNear({.position = {5, -3.0986},
+ .velocity = {0, -0.6815},
+ .time = Time(2.0533)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.975, -3.175},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9992, -3.0114},
+ .velocity = {-0.2455, -2.4322},
+ .time = Time(2.0067)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9962, -3.0344},
+ .velocity = {-0.5430, -4.1368},
+ .time = Time(2.0122)},
+ kTol),
+ ResultNear({.position = {4.9924, -3.0609},
+ .velocity = {-0.6721, -4.7834},
+ .time = Time(2.0178)},
+ kTol),
+ ResultNear({.position = {4.9886, -3.0873},
+ .velocity = {-0.6885, -4.7365},
+ .time = Time(2.0233)},
+ kTol),
+ ResultNear({.position = {4.9851, -3.1110},
+ .velocity = {-0.6358, -4.2778},
+ .time = Time(2.0289)},
+ kTol),
+ ResultNear({.position = {4.9820, -3.1311},
+ .velocity = {-0.5463, -3.6137},
+ .time = Time(2.0344)},
+ kTol),
+ ResultNear({.position = {4.9796, -3.1471},
+ .velocity = {-0.4430, -2.8867},
+ .time = Time(2.0400)},
+ kTol),
+ ResultNear({.position = {4.9777, -3.1593},
+ .velocity = {-0.3407, -2.1881},
+ .time = Time(2.0456)},
+ kTol),
+ ResultNear({.position = {4.9763, -3.1680},
+ .velocity = {-0.2484, -1.5700},
+ .time = Time(2.0511)},
+ kTol),
+ ResultNear({.position = {4.9754, -3.1739},
+ .velocity = {-0.1704, -1.0564},
+ .time = Time(2.0567)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.9, -3.2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9953, -3.0237},
+ .velocity = {-1.1603, -3.7004},
+ .time = Time(2.0100)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9828, -3.0521},
+ .velocity = {-2.2559, -5.1049},
+ .time = Time(2.0156)},
+ kTol),
+ ResultNear({.position = {4.9677, -3.0825},
+ .velocity = {-2.7081, -5.4835},
+ .time = Time(2.0211)},
+ kTol),
+ ResultNear({.position = {4.9526, -3.1115},
+ .velocity = {-2.7333, -5.2122},
+ .time = Time(2.0267)},
+ kTol),
+ ResultNear({.position = {4.9387, -3.1369},
+ .velocity = {-2.4999, -4.5756},
+ .time = Time(2.0322)},
+ kTol),
+ ResultNear({.position = {4.9268, -3.1579},
+ .velocity = {-2.1326, -3.7776},
+ .time = Time(2.0378)},
+ kTol),
+ ResultNear({.position = {4.9173, -3.1743},
+ .velocity = {-1.7184, -2.9554},
+ .time = Time(2.0433)},
+ kTol),
+ ResultNear({.position = {4.9100, -3.1865},
+ .velocity = {-1.3136, -2.1935},
+ .time = Time(2.0489)},
+ kTol),
+ ResultNear({.position = {4.9047, -3.1950},
+ .velocity = {-0.9513, -1.5369},
+ .time = Time(2.0544)},
+ kTol),
+ ResultNear({.position = {4.9011, -3.2006},
+ .velocity = {-0.6475, -1.0032},
+ .time = Time(2.0600)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.825, -3.2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9868, -3.0389},
+ .velocity = {-2.5540, -4.5431},
+ .time = Time(2.0133)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9636, -3.0687},
+ .velocity = {-4.1801, -5.3627},
+ .time = Time(2.0189)},
+ kTol),
+ ResultNear({.position = {4.9370, -3.0985},
+ .velocity = {-4.7757, -5.3670},
+ .time = Time(2.0244)},
+ kTol),
+ ResultNear({.position = {4.9109, -3.1256},
+ .velocity = {-4.6989, -4.8816},
+ .time = Time(2.0300)},
+ kTol),
+ ResultNear({.position = {4.8875, -3.1486},
+ .velocity = {-4.2257, -4.1466},
+ .time = Time(2.0356)},
+ kTol),
+ ResultNear({.position = {4.8677, -3.1671},
+ .velocity = {-3.5576, -3.3287},
+ .time = Time(2.0411)},
+ kTol),
+ ResultNear({.position = {4.8520, -3.1812},
+ .velocity = {-2.8333, -2.5353},
+ .time = Time(2.0467)},
+ kTol),
+ ResultNear({.position = {4.8401, -3.1914},
+ .velocity = {-2.1411, -1.8288},
+ .time = Time(2.0522)},
+ kTol),
+ ResultNear({.position = {4.8316, -3.1982},
+ .velocity = {-1.5312, -1.2386},
+ .time = Time(2.0578)},
+ kTol),
+ ResultNear({.position = {4.8280, -3.2010},
+ .velocity = {-1.2786, -1.0053},
+ .time = Time(2.0606)},
+ kTol),
+ ResultNear({.position = {4.8272, -3.2017},
+ .velocity = {-1.2209, -0.9529},
+ .time = Time(2.0613)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.75, -3.225},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9726, -3.0565},
+ .velocity = {-4.2660, -5.2803},
+ .time = Time(2.0167)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9381, -3.0894},
+ .velocity = {-6.2018, -5.9261},
+ .time = Time(2.0222)},
+ kTol),
+ ResultNear({.position = {4.9004, -3.1215},
+ .velocity = {-6.7995, -5.7749},
+ .time = Time(2.0278)},
+ kTol),
+ ResultNear({.position = {4.8640, -3.1501},
+ .velocity = {-6.5400, -5.1591},
+ .time = Time(2.0333)},
+ kTol),
+ ResultNear({.position = {4.8319, -3.1741},
+ .velocity = {-5.7897, -4.3207},
+ .time = Time(2.0389)},
+ kTol),
+ ResultNear({.position = {4.8051, -3.1932},
+ .velocity = {-4.8133, -3.4248},
+ .time = Time(2.0444)},
+ kTol),
+ ResultNear({.position = {4.7841, -3.2075},
+ .velocity = {-3.7898, -2.5759},
+ .time = Time(2.0500)},
+ kTol),
+ ResultNear({.position = {4.7683, -3.2176},
+ .velocity = {-2.8312, -1.8324},
+ .time = Time(2.0556)},
+ kTol),
+ ResultNear({.position = {4.7572, -3.2244},
+ .velocity = {-1.9986, -1.2198},
+ .time = Time(2.0611)},
+ kTol),
+ ResultNear({.position = {4.7526, -3.2271},
+ .velocity = {-1.6580, -0.9805},
+ .time = Time(2.0639)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.7, -3.3},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9529, -3.0778},
+ .velocity = {-5.9184, -6.4042},
+ .time = Time(2.0200)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9101, -3.1194},
+ .velocity = {-7.6886, -7.4784},
+ .time = Time(2.0256)},
+ kTol),
+ ResultNear({.position = {4.8654, -3.1607},
+ .velocity = {-8.0518, -7.4431},
+ .time = Time(2.0311)},
+ kTol),
+ ResultNear({.position = {4.8235, -3.1982},
+ .velocity = {-7.5377, -6.7452},
+ .time = Time(2.0367)},
+ kTol),
+ ResultNear({.position = {4.7872, -3.2299},
+ .velocity = {-6.5440, -5.7133},
+ .time = Time(2.0422)},
+ kTol),
+ ResultNear({.position = {4.7574, -3.2553},
+ .velocity = {-5.3529, -4.5748},
+ .time = Time(2.0478)},
+ kTol),
+ ResultNear({.position = {4.7344, -3.2746},
+ .velocity = {-4.1516, -3.4758},
+ .time = Time(2.0533)},
+ kTol),
+ ResultNear({.position = {4.7174, -3.2885},
+ .velocity = {-3.0534, -2.5004},
+ .time = Time(2.0589)},
+ kTol),
+ ResultNear({.position = {4.7056, -3.2979},
+ .velocity = {-2.1169, -1.6879},
+ .time = Time(2.0644)},
+ kTol),
+ ResultNear({.position = {4.7030, -3.3000},
+ .velocity = {-1.9283, -1.5276},
+ .time = Time(2.0658)},
+ kTol),
+ ResultNear({.position = {4.7017, -3.3010},
+ .velocity = {-1.8380, -1.4512},
+ .time = Time(2.0665)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.675, -3.4},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9288, -3.1046},
+ .velocity = {-7.2260, -8.0305},
+ .time = Time(2.0233)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8816, -3.1582},
+ .velocity = {-8.4881, -9.6525},
+ .time = Time(2.0289)},
+ kTol),
+ ResultNear({.position = {4.8345, -3.2124},
+ .velocity = {-8.4738, -9.7482},
+ .time = Time(2.0344)},
+ kTol),
+ ResultNear({.position = {4.7918, -3.2619},
+ .velocity = {-7.6948, -8.9195},
+ .time = Time(2.0400)},
+ kTol),
+ ResultNear({.position = {4.7555, -3.3042},
+ .velocity = {-6.5279, -7.6113},
+ .time = Time(2.0456)},
+ kTol),
+ ResultNear({.position = {4.7264, -3.3383},
+ .velocity = {-5.2343, -6.1345},
+ .time = Time(2.0511)},
+ kTol),
+ ResultNear({.position = {4.7043, -3.3643},
+ .velocity = {-3.9823, -4.6907},
+ .time = Time(2.0567)},
+ kTol),
+ ResultNear({.position = {4.6884, -3.3832},
+ .velocity = {-2.8691, -3.3980},
+ .time = Time(2.0622)},
+ kTol),
+ ResultNear({.position = {4.6776, -3.3961},
+ .velocity = {-1.9403, -2.3135},
+ .time = Time(2.0678)},
+ kTol),
+ ResultNear({.position = {4.6752, -3.3990},
+ .velocity = {-1.7569, -2.0983},
+ .time = Time(2.0692)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {4.675, -3.525},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9022, -3.1387},
+ .velocity = {-7.9833, -10.2310},
+ .time = Time(2.0267)},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8549, -3.2079},
+ .velocity = {-8.5070, -12.4602},
+ .time = Time(2.0322)},
+ kTol),
+ ResultNear({.position = {4.8102, -3.2783},
+ .velocity = {-8.0479, -12.6650},
+ .time = Time(2.0378)},
+ kTol),
+ ResultNear({.position = {4.7711, -3.3429},
+ .velocity = {-7.0408, -11.6365},
+ .time = Time(2.0433)},
+ kTol),
+ ResultNear({.position = {4.7389, -3.3983},
+ .velocity = {-5.7965, -9.9616},
+ .time = Time(2.0489)},
+ kTol),
+ ResultNear({.position = {4.7137, -3.4430},
+ .velocity = {-4.5230, -8.0510},
+ .time = Time(2.0544)},
+ kTol),
+ ResultNear({.position = {4.6951, -3.4773},
+ .velocity = {-3.3477, -6.1727},
+ .time = Time(2.0600)},
+ kTol),
+ ResultNear({.position = {4.6821, -3.5022},
+ .velocity = {-2.3381, -4.4846},
+ .time = Time(2.0656)},
+ kTol),
+ ResultNear({.position = {4.6737, -3.5192},
+ .velocity = {-1.5199, -3.0641},
+ .time = Time(2.0711)},
+ kTol),
+ ResultNear({.position = {4.6718, -3.5231},
+ .velocity = {-1.3626, -2.7813},
+ .time = Time(2.0725)},
+ kTol)));
+
+ time += kDeltaTime;
+ // We get more results at the end of the stroke as it tries to "catch up" to
+ // the raw input.
+ results = modeler.Update({.event_type = Input::EventType::kUp,
+ .position = {4.7, -3.6},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8753, -3.1797},
+ .velocity = {-8.0521, -12.3049},
+ .time = Time(2.0300)},
+ kTol),
+ ResultNear({.position = {4.8325, -3.2589},
+ .velocity = {-7.7000, -14.2607},
+ .time = Time(2.0356)},
+ kTol),
+ ResultNear({.position = {4.7948, -3.3375},
+ .velocity = {-6.7888, -14.1377},
+ .time = Time(2.0411)},
+ kTol),
+ ResultNear({.position = {4.7636, -3.4085},
+ .velocity = {-5.6249, -12.7787},
+ .time = Time(2.0467)},
+ kTol),
+ ResultNear({.position = {4.7390, -3.4685},
+ .velocity = {-4.4152, -10.8015},
+ .time = Time(2.0522)},
+ kTol),
+ ResultNear({.position = {4.7208, -3.5164},
+ .velocity = {-3.2880, -8.6333},
+ .time = Time(2.0578)},
+ kTol),
+ ResultNear({.position = {4.7079, -3.5528},
+ .velocity = {-2.3128, -6.5475},
+ .time = Time(2.0633)},
+ kTol),
+ ResultNear({.position = {4.6995, -3.5789},
+ .velocity = {-1.5174, -4.7008},
+ .time = Time(2.0689)},
+ kTol),
+ ResultNear({.position = {4.6945, -3.5965},
+ .velocity = {-0.9022, -3.1655},
+ .time = Time(2.0744)},
+ kTol),
+ ResultNear({.position = {4.6942, -3.5976},
+ .velocity = {-0.8740, -3.0899},
+ .time = Time(2.0748)},
+ kTol)));
+
+ // The stroke is finished, so there's nothing to predict anymore.
+ EXPECT_EQ(modeler.Predict().status().code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, WobbleSmoothed) {
+ const Duration kDeltaTime{.0167};
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ Time time{4};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {-6, -2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(ResultNear(
+ {.position = {-6, -2}, .velocity = {0, 0}, .time = Time(4)}, kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-6.02, -2},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0001, -2},
+ .velocity = {-0.0328, 0},
+ .time = Time(4.0042)},
+ kTol),
+ ResultNear({.position = {-6.0005, -2},
+ .velocity = {-0.0869, 0},
+ .time = Time(4.0084)},
+ kTol),
+ ResultNear({.position = {-6.0011, -2},
+ .velocity = {-0.1531, 0},
+ .time = Time(4.0125)},
+ kTol),
+ ResultNear({.position = {-6.0021, -2},
+ .velocity = {-0.2244, 0},
+ .time = Time(4.0167)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-6.02, -2.02},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0032, -2.0001},
+ .velocity = {-0.2709, -0.0205},
+ .time = Time(4.0209)},
+ kTol),
+ ResultNear({.position = {-6.0044, -2.0003},
+ .velocity = {-0.2977, -0.0543},
+ .time = Time(4.0251)},
+ kTol),
+ ResultNear({.position = {-6.0057, -2.0007},
+ .velocity = {-0.3093, -0.0956},
+ .time = Time(4.0292)},
+ kTol),
+ ResultNear({.position = {-6.0070, -2.0013},
+ .velocity = {-0.3097, -0.1401},
+ .time = Time(4.0334)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-6.04, -2.02},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0084, -2.0021},
+ .velocity = {-0.3350, -0.1845},
+ .time = Time(4.0376)},
+ kTol),
+ ResultNear({.position = {-6.0100, -2.0030},
+ .velocity = {-0.3766, -0.2266},
+ .time = Time(4.0418)},
+ kTol),
+ ResultNear({.position = {-6.0118, -2.0041},
+ .velocity = {-0.4273, -0.2649},
+ .time = Time(4.0459)},
+ kTol),
+ ResultNear({.position = {-6.0138, -2.0054},
+ .velocity = {-0.4818, -0.2986},
+ .time = Time(4.0501)},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-6.04, -2.04},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0160, -2.0068},
+ .velocity = {-0.5157, -0.3478},
+ .time = Time(4.0543)},
+ kTol),
+ ResultNear({.position = {-6.0182, -2.0085},
+ .velocity = {-0.5334, -0.4054},
+ .time = Time(4.0585)},
+ kTol),
+ ResultNear({.position = {-6.0204, -2.0105},
+ .velocity = {-0.5389, -0.4658},
+ .time = Time(4.0626)},
+ kTol),
+ ResultNear({.position = {-6.0227, -2.0126},
+ .velocity = {-0.5356, -0.5251},
+ .time = Time(4.0668)},
+ kTol)));
+}
+
+TEST(StrokeModelerTest, Reset) {
+ const Duration kDeltaTime{1. / 50};
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ Time time{0};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {-8, -10},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, IsEmpty());
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-10, -8},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {-11, -5},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+ EXPECT_EQ(modeler.Predict().status().code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, IgnoreInputsBeforeTDown) {
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kMove,
+ .position = {0, 0},
+ .time = Time(0)})
+ .status()
+ .code(),
+ absl::StatusCode::kFailedPrecondition);
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kUp,
+ .position = {0, 0},
+ .time = Time(1)})
+ .status()
+ .code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, IgnoreTDownWhileStrokeIsInProgress) {
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {0, 0},
+ .time = Time(0)});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kDown,
+ .position = {1, 1},
+ .time = Time(1)})
+ .status()
+ .code(),
+ absl::StatusCode::kFailedPrecondition);
+
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {1, 1},
+ .time = Time(1)});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kDown,
+ .position = {2, 2},
+ .time = Time(2)})
+ .status()
+ .code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, AlternateParams) {
+ const Duration kDeltaTime{1. / 50};
+
+ StrokeModelParams stroke_model_params = kDefaultParams;
+ stroke_model_params.sampling_params.min_output_rate = 70;
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(stroke_model_params).ok());
+
+ Time time{3};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {0, 0},
+ .time = time,
+ .pressure = .5,
+ .tilt = .2,
+ .orientation = .4});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear(
+ {{0, 0}, {0, 0}, Time{3}, .5, .2, .4}, kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, IsEmpty());
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {0, .5},
+ .time = time,
+ .pressure = .4,
+ .tilt = .3,
+ .orientation = .3});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(
+ ResultNear(
+ {{0, 0.0736}, {0, 7.3636}, Time{3.0100}, 0.4853, 0.2147, 0.3853},
+ kTol),
+ ResultNear(
+ {{0, 0.2198}, {0, 14.6202}, Time{3.0200}, 0.4560, 0.2440, 0.3560},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(
+ ResultNear(
+ {{0, 0.3823}, {0, 11.3709}, Time{3.0343}, 0.4235, 0.2765, 0.3235},
+ kTol),
+ ResultNear(
+ {{0, 0.4484}, {0, 4.6285}, Time{3.0486}, 0.4103, 0.2897, 0.3103},
+ kTol),
+ ResultNear(
+ {{0, 0.4775}, {0, 2.0389}, Time{3.0629}, 0.4045, 0.2955, 0.3045},
+ kTol),
+ ResultNear(
+ {{0, 0.4902}, {0, 0.8873}, Time{3.0771}, 0.4020, 0.2980, 0.3020},
+ kTol),
+ ResultNear(
+ {{0, 0.4957}, {0, 0.3868}, Time{3.0914}, 0.4009, 0.2991, 0.3009},
+ kTol),
+ ResultNear(
+ {{0, 0.4981}, {0, 0.1686}, Time{3.1057}, 0.4004, 0.2996, 0.3004},
+ kTol),
+ ResultNear(
+ {{0, 0.4992}, {0, 0.0735}, Time{3.1200}, 0.4002, 0.2998, 0.3002},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {.2, 1},
+ .time = time,
+ .pressure = .3,
+ .tilt = .4,
+ .orientation = .2});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({{0.0295, 0.4169},
+ {2.9455, 19.7093},
+ Time{3.0300},
+ 0.4166,
+ 0.2834,
+ 0.3166},
+ kTol),
+ ResultNear({{0.0879, 0.6439},
+ {5.8481, 22.6926},
+ Time{3.0400},
+ 0.3691,
+ 0.3309,
+ 0.2691},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({{0.1529, 0.8487},
+ {4.5484, 14.3374},
+ Time{3.0543},
+ 0.3293,
+ 0.3707,
+ 0.2293},
+ kTol),
+ ResultNear({{0.1794, 0.9338},
+ {1.8514, 5.9577},
+ Time{3.0686},
+ 0.3128,
+ 0.3872,
+ 0.2128},
+ kTol),
+ ResultNear({{0.1910, 0.9712},
+ {0.8156, 2.6159},
+ Time{3.0829},
+ 0.3056,
+ 0.3944,
+ 0.2056},
+ kTol),
+ ResultNear({{0.1961, 0.9874},
+ {0.3549, 1.1389},
+ Time{3.0971},
+ 0.3024,
+ 0.3976,
+ 0.2024},
+ kTol),
+ ResultNear({{0.1983, 0.9945},
+ {0.1547, 0.4965},
+ Time{3.1114},
+ 0.3011,
+ 0.3989,
+ 0.2011},
+ kTol),
+ ResultNear({{0.1993, 0.9976},
+ {0.0674, 0.2164},
+ Time{3.1257},
+ 0.3005,
+ 0.3995,
+ 0.2005},
+ kTol),
+ ResultNear({{0.1997, 0.9990},
+ {0.0294, 0.0943},
+ Time{3.1400},
+ 0.3002,
+ 0.3998,
+ 0.2002},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {.4, 1.4},
+ .time = time,
+ .pressure = .2,
+ .tilt = .7,
+ .orientation = 0});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({{0.1668, 0.8712},
+ {7.8837, 22.7349},
+ Time{3.0500},
+ 0.3245,
+ 0.3755,
+ 0.2245},
+ kTol),
+ ResultNear({{0.2575, 1.0906},
+ {9.0771, 21.9411},
+ Time{3.0600},
+ 0.2761,
+ 0.4716,
+ 0.1522},
+ kTol)));
+
+ results = modeler.Predict();
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({{0.3395, 1.2676},
+ {5.7349, 12.3913},
+ Time{3.0743},
+ 0.2325,
+ 0.6024,
+ 0.0651},
+ kTol),
+ ResultNear({{0.3735, 1.3421},
+ {2.3831, 5.2156},
+ Time{3.0886},
+ 0.2142,
+ 0.6573,
+ 0.0284},
+ kTol),
+ ResultNear({{0.3885, 1.3748},
+ {1.0463, 2.2854},
+ Time{3.1029},
+ 0.2062,
+ 0.6814,
+ 0.0124},
+ kTol),
+ ResultNear({{0.3950, 1.3890},
+ {0.4556, 0.9954},
+ Time{3.1171},
+ 0.2027,
+ 0.6919,
+ 0.0054},
+ kTol),
+ ResultNear({{0.3978, 1.3952},
+ {0.1986, 0.4339},
+ Time{3.1314},
+ 0.2012,
+ 0.6965,
+ 0.0024},
+ kTol),
+ ResultNear({{0.3990, 1.3979},
+ {0.0866, 0.1891},
+ Time{3.1457},
+ 0.2005,
+ 0.6985,
+ 0.0010},
+ kTol),
+ ResultNear({{0.3996, 1.3991},
+ {0.0377, 0.0824},
+ Time{3.1600},
+ 0.2002,
+ 0.6993,
+ 0.0004},
+ kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kUp,
+ .position = {.7, 1.7},
+ .time = time,
+ .pressure = .1,
+ .tilt = 1,
+ .orientation = 0});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, ElementsAre(ResultNear({{0.3691, 1.2874},
+ {11.1558, 19.6744},
+ Time{3.0700},
+ 0.2256,
+ 0.6231,
+ 0.0512},
+ kTol),
+ ResultNear({{0.4978, 1.4640},
+ {12.8701, 17.6629},
+ Time{3.0800},
+ 0.1730,
+ 0.7809,
+ 0},
+ kTol),
+ ResultNear({{0.6141, 1.5986},
+ {8.1404, 9.4261},
+ Time{3.0943},
+ 0.1312,
+ 0.9064,
+ 0},
+ kTol),
+ ResultNear({{0.6624, 1.6557},
+ {3.3822, 3.9953},
+ Time{3.1086},
+ 0.1136,
+ 0.9591,
+ 0},
+ kTol),
+ ResultNear({{0.6836, 1.6807},
+ {1.4851, 1.7488},
+ Time{3.1229},
+ 0.1059,
+ 0.9822,
+ 0},
+ kTol),
+ ResultNear({{0.6929, 1.6916},
+ {0.6466, 0.7618},
+ Time{3.1371},
+ 0.1026,
+ 0.9922,
+ 0},
+ kTol),
+ ResultNear({{0.6969, 1.6963},
+ {0.2819, 0.3321},
+ Time{3.1514},
+ 0.1011,
+ 0.9966,
+ 0},
+ kTol),
+ ResultNear({{0.6986, 1.6984},
+ {0.1229, 0.1447},
+ Time{3.1657},
+ 0.1005,
+ 0.9985,
+ 0},
+ kTol),
+ ResultNear({{0.6994, 1.6993},
+ {0.0535, 0.0631},
+ Time{3.1800},
+ 0.1002,
+ 0.9994,
+ 0},
+ kTol)));
+
+ EXPECT_EQ(modeler.Predict().status().code(),
+ absl::StatusCode::kFailedPrecondition);
+}
+
+TEST(StrokeModelerTest, GenerateOutputOnTUpEvenIfNoTimeDelta) {
+ const Duration kDeltaTime{1. / 500};
+
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ Time time{0};
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {5, 5},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(ResultNear(
+ {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0)}, kTol)));
+
+ time += kDeltaTime;
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {5, 5},
+ .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results,
+ ElementsAre(ResultNear(
+ {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0.002)},
+ kTol)));
+
+ results = modeler.Update(
+ {.event_type = Input::EventType::kUp, .position = {5, 5}, .time = time});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(
+ *results,
+ ElementsAre(ResultNear(
+ {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0.0076)},
+ kTol)));
+}
+
+TEST(StrokeModelerTest, RejectInputIfNegativeTimeDelta) {
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {0, 0},
+ .time = Time(0)});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kMove,
+ .position = {1, 1},
+ .time = Time(-.1)})
+ .status()
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {1, 1},
+ .time = Time(1)});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kUp,
+ .position = {1, 1},
+ .time = Time(.9)})
+ .status()
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+TEST(StrokeModelerTest, RejectDuplicateInput) {
+ StrokeModeler modeler;
+ ASSERT_TRUE(modeler.Reset(kDefaultParams).ok());
+
+ absl::StatusOr<std::vector<Result>> results =
+ modeler.Update({.event_type = Input::EventType::kDown,
+ .position = {0, 0},
+ .time = Time(0),
+ .pressure = .2,
+ .tilt = .3,
+ .orientation = .4});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kDown,
+ .position = {0, 0},
+ .time = Time(0),
+ .pressure = .2,
+ .tilt = .3,
+ .orientation = .4})
+ .status()
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+
+ results = modeler.Update({.event_type = Input::EventType::kMove,
+ .position = {1, 2},
+ .time = Time(1),
+ .pressure = .1,
+ .tilt = .2,
+ .orientation = .3});
+ ASSERT_TRUE(results.ok());
+ EXPECT_THAT(*results, Not(IsEmpty()));
+
+ EXPECT_EQ(modeler
+ .Update({.event_type = Input::EventType::kMove,
+ .position = {1, 2},
+ .time = Time(1),
+ .pressure = .1,
+ .tilt = .2,
+ .orientation = .3})
+ .status()
+ .code(),
+ absl::StatusCode::kInvalidArgument);
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/types.cc b/ink_stroke_modeler/types.cc
new file mode 100644
index 0000000..9864ea3
--- /dev/null
+++ b/ink_stroke_modeler/types.cc
@@ -0,0 +1,36 @@
+#include "ink_stroke_modeler/types.h"
+
+#include "absl/status/status.h"
+#include "ink_stroke_modeler/internal/validation.h"
+
+// This convenience macro evaluates the given expression, and if it does not
+// return an OK status, returns and propagates the status.
+#define RETURN_IF_ERROR(expr) \
+ do { \
+ if (auto status = (expr); !status.ok()) return status; \
+ } while (false)
+
+namespace ink {
+namespace stroke_model {
+
+absl::Status ValidateInput(const Input& input) {
+ switch (input.event_type) {
+ case Input::EventType::kUp:
+ case Input::EventType::kMove:
+ case Input::EventType::kDown:
+ break;
+ default:
+ return absl::InvalidArgumentError("Unknown Input.event_type.");
+ }
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(input.position.x, "Input.position.x"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(input.position.y, "Input.position.y"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(input.time.Value(), "Input.time"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(input.pressure, "Input.pressure"));
+ RETURN_IF_ERROR(ValidateIsFiniteNumber(input.tilt, "Input.tilt"));
+ RETURN_IF_ERROR(
+ ValidateIsFiniteNumber(input.orientation, "Input.orientation"));
+ return absl::OkStatus();
+}
+
+} // namespace stroke_model
+} // namespace ink
diff --git a/ink_stroke_modeler/types.h b/ink_stroke_modeler/types.h
new file mode 100644
index 0000000..cc77565
--- /dev/null
+++ b/ink_stroke_modeler/types.h
@@ -0,0 +1,356 @@
+/*
+ * Copyright 2022 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INK_STROKE_MODELER_TYPES_H_
+#define INK_STROKE_MODELER_TYPES_H_
+
+#include <cmath>
+#include <ostream>
+
+#include "absl/status/status.h"
+
+namespace ink {
+namespace stroke_model {
+
+// A vector (or point) in 2D space.
+struct Vec2 {
+ float x = 0;
+ float y = 0;
+
+ // The length of the vector, i.e. its distance from the origin.
+ float Magnitude() const { return std::sqrt(x * x + y * y); }
+};
+
+bool operator==(Vec2 lhs, Vec2 rhs);
+bool operator!=(Vec2 lhs, Vec2 rhs);
+
+Vec2 operator+(Vec2 lhs, Vec2 rhs);
+Vec2 operator-(Vec2 lhs, Vec2 rhs);
+Vec2 operator*(float scalar, Vec2 v);
+Vec2 operator*(Vec2 v, float scalar);
+Vec2 operator/(Vec2 v, float scalar);
+
+Vec2 &operator+=(Vec2 &lhs, Vec2 rhs);
+Vec2 &operator-=(Vec2 &lhs, Vec2 rhs);
+Vec2 &operator*=(Vec2 &lhs, float scalar);
+Vec2 &operator/=(Vec2 &lhs, float scalar);
+
+std::ostream &operator<<(std::ostream &stream, Vec2 v);
+
+// This represents a duration of time, i.e. the difference between two points in
+// time (as represented by class Time, below). This class is unit-agnostic; it
+// could represent e.g. hours, seconds, or years.
+class Duration {
+ public:
+ Duration() : Duration(0) {}
+ explicit Duration(double value) : value_(value) {}
+ double Value() const { return value_; }
+
+ private:
+ double value_ = 0;
+};
+
+Duration operator+(Duration lhs, Duration rhs);
+Duration operator-(Duration lhs, Duration rhs);
+Duration operator*(Duration duration, double scalar);
+Duration operator*(double scalar, Duration duration);
+Duration operator/(Duration duration, double scalar);
+
+Duration &operator+=(Duration &lhs, Duration rhs);
+Duration &operator-=(Duration &lhs, Duration rhs);
+Duration &operator*=(Duration &duration, double scalar);
+Duration &operator/=(Duration &duration, double scalar);
+
+bool operator==(Duration lhs, Duration rhs);
+bool operator!=(Duration lhs, Duration rhs);
+bool operator<(Duration lhs, Duration rhs);
+bool operator>(Duration lhs, Duration rhs);
+bool operator<=(Duration lhs, Duration rhs);
+bool operator>=(Duration lhs, Duration rhs);
+
+std::ostream &operator<<(std::ostream &s, Duration duration);
+
+// This represents a point in time. This class is unit- and offset-agnostic; it
+// could be measured in e.g. hours, seconds, or years, and Time(0) has no
+// specific meaning outside of the context it is used in.
+class Time {
+ public:
+ Time() : Time(0) {}
+ explicit Time(double value) : value_(value) {}
+ double Value() const { return value_; }
+
+ private:
+ double value_ = 0;
+};
+
+Time operator+(Time time, Duration duration);
+Time operator+(Duration duration, Time time);
+Time operator-(Time time, Duration duration);
+Duration operator-(Time lhs, Time rhs);
+
+Time &operator+=(Time &time, Duration duration);
+Time &operator-=(Time &time, Duration duration);
+
+bool operator==(Time lhs, Time rhs);
+bool operator!=(Time lhs, Time rhs);
+bool operator<(Time lhs, Time rhs);
+bool operator>(Time lhs, Time rhs);
+bool operator<=(Time lhs, Time rhs);
+bool operator>=(Time lhs, Time rhs);
+
+std::ostream &operator<<(std::ostream &s, Time time);
+
+// The input passed to the stroke modeler.
+struct Input {
+ // The type of event represented by the input. A "kDown" event represents
+ // the beginning of the stroke, a "kUp" event represents the end of the
+ // stroke, and all events in between are "kMove" events.
+ enum class EventType { kDown, kMove, kUp };
+ EventType event_type;
+
+ // The position of the input.
+ Vec2 position{0};
+
+ // The time at which the input occurs.
+ Time time{0};
+
+ // The amount of pressure applied to the stylus. This is expected to lie in
+ // the range [0, 1]. A negative value indicates unknown pressure.
+ float pressure = -1;
+
+ // The angle between the stylus and the plane of the device's screen. This
+ // is expected to lie in the range [0, π/2]. A value of 0 indicates that the
+ // stylus is perpendicular to the screen, while a value of π/2 indicates
+ // that it is flush with the screen. A negative value indicates unknown
+ // tilt.
+ float tilt = -1;
+
+ // The angle between the projection of the stylus onto the screen and the
+ // positive x-axis, measured counter-clockwise. This is expected to lie in
+ // the range [0, 2π). A negative value indicates unknown orientation.
+ float orientation = -1;
+};
+
+bool operator==(const Input &lhs, const Input &rhs);
+bool operator!=(const Input &lhs, const Input &rhs);
+
+std::ostream &operator<<(std::ostream &s, Input::EventType event_type);
+std::ostream &operator<<(std::ostream &s, const Input &input);
+
+absl::Status ValidateInput(const Input &input);
+
+// A modeled input produced by the stroke modeler.
+struct Result {
+ // The position and velocity of the stroke tip.
+ Vec2 position{0};
+ Vec2 velocity{0};
+
+ // The time at which this input occurs.
+ Time time{0};
+
+ // These pressure, tilt, and orientation of the stylus. See the
+ // corresponding fields on the Input struct for more info.
+ float pressure = -1;
+ float tilt = -1;
+ float orientation = -1;
+};
+
+std::ostream &operator<<(std::ostream &s, const Result &result);
+
+////////////////////////////////////////////////////////////////////////////////
+// Inline function definitions
+////////////////////////////////////////////////////////////////////////////////
+
+inline bool operator==(Vec2 lhs, Vec2 rhs) {
+ return lhs.x == rhs.x && lhs.y == rhs.y;
+}
+inline bool operator!=(Vec2 lhs, Vec2 rhs) { return !(lhs == rhs); }
+
+inline Vec2 operator+(Vec2 lhs, Vec2 rhs) {
+ return {.x = lhs.x + rhs.x, .y = lhs.y + rhs.y};
+}
+inline Vec2 operator-(Vec2 lhs, Vec2 rhs) {
+ return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
+}
+inline Vec2 operator*(float scalar, Vec2 v) {
+ return {.x = scalar * v.x, .y = scalar * v.y};
+}
+inline Vec2 operator*(Vec2 v, float scalar) { return scalar * v; }
+inline Vec2 operator/(Vec2 v, float scalar) {
+ return {.x = v.x / scalar, .y = v.y / scalar};
+}
+
+inline Vec2 &operator+=(Vec2 &lhs, Vec2 rhs) {
+ lhs.x += rhs.x;
+ lhs.y += rhs.y;
+ return lhs;
+}
+inline Vec2 &operator-=(Vec2 &lhs, Vec2 rhs) {
+ lhs.x -= rhs.x;
+ lhs.y -= rhs.y;
+ return lhs;
+}
+inline Vec2 &operator*=(Vec2 &lhs, float scalar) {
+ lhs.x *= scalar;
+ lhs.y *= scalar;
+ return lhs;
+}
+inline Vec2 &operator/=(Vec2 &lhs, float scalar) {
+ lhs.x /= scalar;
+ lhs.y /= scalar;
+ return lhs;
+}
+
+inline std::ostream &operator<<(std::ostream &stream, Vec2 v) {
+ return stream << "(" << v.x << ", " << v.y << ")";
+}
+
+inline Duration operator+(Duration lhs, Duration rhs) {
+ return Duration(lhs.Value() + rhs.Value());
+}
+inline Duration operator-(Duration lhs, Duration rhs) {
+ return Duration(lhs.Value() - rhs.Value());
+}
+inline Duration operator*(Duration duration, double scalar) {
+ return Duration(duration.Value() * scalar);
+}
+inline Duration operator*(double scalar, Duration duration) {
+ return Duration(scalar * duration.Value());
+}
+inline Duration operator/(Duration duration, double scalar) {
+ return Duration(duration.Value() / scalar);
+}
+
+inline Duration &operator+=(Duration &lhs, Duration rhs) {
+ lhs = lhs + rhs;
+ return lhs;
+}
+inline Duration &operator-=(Duration &lhs, Duration rhs) {
+ lhs = lhs - rhs;
+ return lhs;
+}
+inline Duration &operator*=(Duration &duration, double scalar) {
+ duration = duration * scalar;
+ return duration;
+}
+inline Duration &operator/=(Duration &duration, double scalar) {
+ duration = duration / scalar;
+ return duration;
+}
+
+inline bool operator==(Duration lhs, Duration rhs) {
+ return lhs.Value() == rhs.Value();
+}
+inline bool operator!=(Duration lhs, Duration rhs) {
+ return lhs.Value() != rhs.Value();
+}
+inline bool operator<(Duration lhs, Duration rhs) {
+ return lhs.Value() < rhs.Value();
+}
+inline bool operator>(Duration lhs, Duration rhs) {
+ return lhs.Value() > rhs.Value();
+}
+inline bool operator<=(Duration lhs, Duration rhs) {
+ return lhs.Value() <= rhs.Value();
+}
+inline bool operator>=(Duration lhs, Duration rhs) {
+ return lhs.Value() >= rhs.Value();
+}
+
+inline Time operator+(Time time, Duration duration) {
+ return Time(time.Value() + duration.Value());
+}
+inline Time operator+(Duration duration, Time time) {
+ return Time(time.Value() + duration.Value());
+}
+inline Time operator-(Time time, Duration duration) {
+ return Time(time.Value() - duration.Value());
+}
+inline Duration operator-(Time lhs, Time rhs) {
+ return Duration(lhs.Value() - rhs.Value());
+}
+
+inline Time &operator+=(Time &time, Duration duration) {
+ time = time + duration;
+ return time;
+}
+inline Time &operator-=(Time &time, Duration duration) {
+ time = time - duration;
+ return time;
+}
+
+inline bool operator==(Time lhs, Time rhs) {
+ return lhs.Value() == rhs.Value();
+}
+inline bool operator!=(Time lhs, Time rhs) {
+ return lhs.Value() != rhs.Value();
+}
+inline bool operator<(Time lhs, Time rhs) { return lhs.Value() < rhs.Value(); }
+inline bool operator>(Time lhs, Time rhs) { return lhs.Value() > rhs.Value(); }
+inline bool operator<=(Time lhs, Time rhs) {
+ return lhs.Value() <= rhs.Value();
+}
+inline bool operator>=(Time lhs, Time rhs) {
+ return lhs.Value() >= rhs.Value();
+}
+
+inline bool operator==(const Input &lhs, const Input &rhs) {
+ return lhs.event_type == rhs.event_type && lhs.position == rhs.position &&
+ lhs.time == rhs.time && lhs.pressure == rhs.pressure &&
+ lhs.tilt == rhs.tilt && lhs.orientation == rhs.orientation;
+}
+inline bool operator!=(const Input &lhs, const Input &rhs) {
+ return !(lhs == rhs);
+}
+
+inline std::ostream &operator<<(std::ostream &s, Duration duration) {
+ return s << duration.Value();
+}
+
+inline std::ostream &operator<<(std::ostream &s, Time time) {
+ return s << time.Value();
+}
+
+inline std::ostream &operator<<(std::ostream &s, Input::EventType event_type) {
+ switch (event_type) {
+ case Input::EventType::kDown:
+ return s << "Down";
+ case Input::EventType::kMove:
+ return s << "Move";
+ case Input::EventType::kUp:
+ return s << "Up";
+ }
+ return s << "UnknownEventType<" << event_type << ">";
+}
+
+inline std::ostream &operator<<(std::ostream &s, const Input &input) {
+ return s << "<Input: " << input.event_type << ", pos: " << input.position
+ << ", time: " << input.time << ", pressure: " << input.pressure
+ << ", tilt: " << input.tilt << ", orientation: " << input.orientation
+ << ">";
+}
+
+inline std::ostream &operator<<(std::ostream &s, const Result &result) {
+ return s << "<Result: pos: " << result.position
+ << ", velocity: " << result.velocity << ", time: " << result.time
+ << ", pressure: " << result.pressure << ", tilt: " << result.tilt
+ << ", orientation: " << result.orientation << ">";
+}
+
+} // namespace stroke_model
+} // namespace ink
+
+#endif // INK_STROKE_MODELER_TYPES_H_
diff --git a/ink_stroke_modeler/types_test.cc b/ink_stroke_modeler/types_test.cc
new file mode 100644
index 0000000..e812d76
--- /dev/null
+++ b/ink_stroke_modeler/types_test.cc
@@ -0,0 +1,212 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/types.h"
+
+#include <cmath>
+#include <limits>
+#include <sstream>
+#include <string>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::Not;
+
+TEST(TypesTest, Vec2Equality) {
+ EXPECT_EQ((Vec2{1, 2}), (Vec2{1, 2}));
+ EXPECT_EQ((Vec2{-.4, 17}), (Vec2{-.4, 17}));
+
+ EXPECT_NE((Vec2{1, 2}), (Vec2{1, 5}));
+ EXPECT_NE((Vec2{3, 2}), (Vec2{.7, 2}));
+ EXPECT_NE((Vec2{-4, .3}), (Vec2{.5, 12}));
+}
+
+TEST(TypesTest, Vec2EqMatcher) {
+ EXPECT_THAT((Vec2{1, 2}), Vec2Eq({1, 2}));
+ EXPECT_THAT((Vec2{3, 4}), Not(Vec2Eq({3, 5})));
+ EXPECT_THAT((Vec2{5, 6}), Not(Vec2Eq({4, 6})));
+
+ // Vec2Eq delegates to FloatEq, which uses a tolerance of 4 ULP.
+ constexpr float kEps = std::numeric_limits<float>::epsilon();
+ EXPECT_THAT((Vec2{1, 1}), Vec2Eq({1 + kEps, 1 - kEps}));
+}
+
+TEST(TypesTest, Vec2NearMatcher) {
+ EXPECT_THAT((Vec2{1, 2}), Vec2Near({1.05, 1.95}, .1));
+ EXPECT_THAT((Vec2{3, 4}), Not(Vec2Near({3, 5}, .5)));
+ EXPECT_THAT((Vec2{5, 6}), Not(Vec2Near({4, 6}, .5)));
+}
+
+TEST(TypesTest, Vec2Magnitude) {
+ EXPECT_FLOAT_EQ((Vec2{1, 1}).Magnitude(), std::sqrt(2));
+ EXPECT_FLOAT_EQ((Vec2{-3, 4}).Magnitude(), 5);
+ EXPECT_FLOAT_EQ((Vec2{0, 0}).Magnitude(), 0);
+ EXPECT_FLOAT_EQ((Vec2{0, 17}).Magnitude(), 17);
+}
+
+TEST(TypesTest, Vec2Addition) {
+ Vec2 a{3, 0};
+ Vec2 b{-1, .3};
+ Vec2 c{2.7, 4};
+
+ EXPECT_THAT(a + b, Vec2Eq({2, .3}));
+ EXPECT_THAT(a + c, Vec2Eq({5.7, 4}));
+ EXPECT_THAT(b + c, Vec2Eq({1.7, 4.3}));
+}
+
+TEST(TypesTest, Vec2Subtraction) {
+ Vec2 a{0, -2};
+ Vec2 b{.5, 19};
+ Vec2 c{1.1, -3.4};
+
+ EXPECT_THAT(a - b, Vec2Eq({-.5, -21}));
+ EXPECT_THAT(a - c, Vec2Eq({-1.1, 1.4}));
+ EXPECT_THAT(b - c, Vec2Eq({-.6, 22.4}));
+}
+
+TEST(TypesTest, Vec2ScalarMultiplication) {
+ Vec2 a{.7, -3};
+ Vec2 b{3, 5};
+
+ EXPECT_THAT(a * 2, Vec2Eq({1.4, -6}));
+ EXPECT_THAT(.1 * a, Vec2Eq({.07, -.3}));
+ EXPECT_THAT(b * -.3, Vec2Eq({-.9, -1.5}));
+ EXPECT_THAT(4 * b, Vec2({12, 20}));
+}
+
+TEST(TypesTest, Vec2ScalarDivision) {
+ Vec2 a{7, .9};
+ Vec2 b{-4.5, -2};
+
+ EXPECT_THAT(a / 2, Vec2Eq({3.5, .45}));
+ EXPECT_THAT(a / -.1, Vec2Eq({-70, -9}));
+ EXPECT_THAT(b / 5, Vec2Eq({-.9, -.4}));
+ EXPECT_THAT(b / .2, Vec2Eq({-22.5, -10}));
+}
+
+TEST(TypesTest, Vec2AddAssign) {
+ Vec2 a{1, 2};
+ a += {.x = 3, .y = -1};
+ EXPECT_THAT(a, Vec2Eq({4, 1}));
+ a += {.x = -.5, .y = 2};
+ EXPECT_THAT(a, Vec2Eq({3.5, 3}));
+}
+
+TEST(TypesTest, Vec2SubractAssign) {
+ Vec2 a{1, 2};
+ a -= {.x = 3, .y = -1};
+ EXPECT_THAT(a, Vec2Eq({-2, 3}));
+ a -= {.x = -.5, .y = 2};
+ EXPECT_THAT(a, Vec2Eq({-1.5, 1}));
+}
+
+TEST(TypesTest, Vec2ScalarMultiplyAssign) {
+ Vec2 a{1, 2};
+ a *= 2;
+ EXPECT_THAT(a, Vec2Eq({2, 4}));
+ a *= -.4;
+ EXPECT_THAT(a, Vec2Eq({-.8, -1.6}));
+}
+
+TEST(TypesTest, Vec2ScalarDivideAssign) {
+ Vec2 a{1, 2};
+ a /= 2;
+ EXPECT_THAT(a, Vec2Eq({.5, 1}));
+ a /= -.4;
+ EXPECT_THAT(a, Vec2Eq({-1.25, -2.5}));
+}
+
+TEST(TypesTest, Vec2Stream) {
+ std::stringstream s;
+ s << Vec2{3.5, -2.7};
+ EXPECT_EQ(s.str(), "(3.5, -2.7)");
+}
+
+TEST(TypesTest, DurationArithmetic) {
+ EXPECT_EQ(Duration(1) + Duration(2), Duration(3));
+ EXPECT_EQ(Duration(6) - Duration(1), Duration(5));
+ EXPECT_EQ(Duration(3) * 2, Duration(6));
+ EXPECT_EQ(.5 * Duration(7), Duration(3.5));
+ EXPECT_EQ(Duration(12) / 4, Duration(3));
+}
+
+TEST(TypesTest, DurationArithmeticAssignment) {
+ Duration d{5};
+ d += Duration(2);
+ EXPECT_EQ(d, Duration(7));
+ d -= Duration(3);
+ EXPECT_EQ(d, Duration(4));
+ d *= 5;
+ EXPECT_EQ(d, Duration(20));
+ d /= 2;
+ EXPECT_EQ(d, Duration(10));
+}
+
+TEST(TypesTest, DurationComparison) {
+ EXPECT_EQ(Duration(0), Duration(0));
+ EXPECT_NE(Duration(0), Duration(1));
+ EXPECT_LT(Duration(1), Duration(2));
+ EXPECT_LE(Duration(4), Duration(5));
+ EXPECT_LE(Duration(2), Duration(2));
+ EXPECT_GT(Duration(10), Duration(9));
+ EXPECT_GE(Duration(7), Duration(5));
+ EXPECT_GE(Duration(5), Duration(5));
+}
+
+TEST(TypesTest, TimeArithmetic) {
+ EXPECT_EQ(Time(5) + Duration(1), Time(6));
+ EXPECT_EQ(Duration(7) + Time(12), Time(19));
+ EXPECT_EQ(Time(23) - Duration(5), Time(18));
+ EXPECT_EQ(Time(35) - Time(7), Duration(28));
+}
+
+TEST(TypesTest, TimeArithmeticAssignment) {
+ Time t{20};
+ t += Duration(10);
+ EXPECT_EQ(t, Time(30));
+ t -= Duration(24);
+ EXPECT_EQ(t, Time(6));
+}
+
+TEST(TypesTest, TimeComparison) {
+ EXPECT_EQ(Time(0), Time(0));
+ EXPECT_NE(Time(0), Time(1));
+ EXPECT_LT(Time(1), Time(2));
+ EXPECT_LE(Time(4), Time(5));
+ EXPECT_LE(Time(2), Time(2));
+ EXPECT_GT(Time(10), Time(9));
+ EXPECT_GE(Time(7), Time(5));
+ EXPECT_GE(Time(5), Time(5));
+}
+
+TEST(TypesTest, InputEquality) {
+ const Input kBaseline{};
+ EXPECT_EQ(kBaseline, Input());
+ EXPECT_NE(kBaseline, Input{.event_type = Input::EventType::kMove});
+ EXPECT_NE(kBaseline, (Input{.position = {1, -1}}));
+ EXPECT_NE(kBaseline, Input{.time = Time(1)});
+ EXPECT_NE(kBaseline, Input{.pressure = .5});
+ EXPECT_NE(kBaseline, Input{.tilt = .2});
+ EXPECT_NE(kBaseline, Input{.orientation = .7});
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink