diff options
Diffstat (limited to 'ink_stroke_modeler')
46 files changed, 7215 insertions, 0 deletions
diff --git a/ink_stroke_modeler/BUILD.bazel b/ink_stroke_modeler/BUILD.bazel new file mode 100644 index 0000000..d48f1ea --- /dev/null +++ b/ink_stroke_modeler/BUILD.bazel @@ -0,0 +1,97 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +cc_library( + name = "params", + srcs = ["params.cc"], + hdrs = ["params.h"], + visibility = ["//visibility:public"], + deps = [ + ":types", + "//ink_stroke_modeler/internal:validation", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "params_test", + srcs = ["params_test.cc"], + deps = [ + ":params", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "types", + srcs = ["types.cc"], + hdrs = ["types.h"], + visibility = ["//visibility:public"], + deps = [ + "//ink_stroke_modeler/internal:validation", + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "types_test", + srcs = ["types_test.cc"], + deps = [ + ":types", + "//ink_stroke_modeler/internal:type_matchers", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "stroke_modeler", + srcs = ["stroke_modeler.cc"], + hdrs = ["stroke_modeler.h"], + deps = [ + ":params", + ":types", + "//ink_stroke_modeler/internal:internal_types", + "//ink_stroke_modeler/internal:position_modeler", + "//ink_stroke_modeler/internal:stylus_state_modeler", + "//ink_stroke_modeler/internal:wobble_smoother", + "//ink_stroke_modeler/internal/prediction:input_predictor", + "//ink_stroke_modeler/internal/prediction:kalman_predictor", + "//ink_stroke_modeler/internal/prediction:stroke_end_predictor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:variant", + ], +) + +cc_test( + name = "stroke_modeler_test", + srcs = ["stroke_modeler_test.cc"], + deps = [ + ":params", + ":stroke_modeler", + "//ink_stroke_modeler/internal:type_matchers", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/ink_stroke_modeler/CMakeLists.txt b/ink_stroke_modeler/CMakeLists.txt new file mode 100644 index 0000000..736663a --- /dev/null +++ b/ink_stroke_modeler/CMakeLists.txt @@ -0,0 +1,102 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(internal) + +ink_cc_library( + NAME + params + SRCS + params.cc + HDRS + params.h + DEPS + InkStrokeModeler::types + absl::status + absl::strings + absl::variant + InkStrokeModeler::validation +) + +ink_cc_test( + NAME + params_test + SRCS + params_test.cc + DEPS + InkStrokeModeler::params + absl::status + GTest::gtest_main +) + +ink_cc_library( + NAME + types + SRCS + types.cc + HDRS + types.h + DEPS + absl::status + InkStrokeModeler::validation +) + +ink_cc_test( + NAME + types_test + SRCS + types_test.cc + DEPS + InkStrokeModeler::types + InkStrokeModeler::type_matchers + GTest::gmock_main +) + +ink_cc_library( + NAME + stroke_modeler + SRCS + stroke_modeler.cc + HDRS + stroke_modeler.h + DEPS + InkStrokeModeler::params + InkStrokeModeler::types + InkStrokeModeler::internal_types + InkStrokeModeler::position_modeler + InkStrokeModeler::stylus_state_modeler + InkStrokeModeler::wobble_smoother + InkStrokeModeler::input_predictor + InkStrokeModeler::kalman_predictor + InkStrokeModeler::stroke_end_predictor + absl::core_headers + absl::memory + absl::status + absl::statusor + absl::optional + absl::variant +) + +ink_cc_test( + NAME + stroke_modeler_test + SRCS + stroke_modeler_test.cc + DEPS + InkStrokeModeler::params + InkStrokeModeler::stroke_modeler + InkStrokeModeler::type_matchers + absl::status + GTest::gmock_main +) diff --git a/ink_stroke_modeler/internal/BUILD.bazel b/ink_stroke_modeler/internal/BUILD.bazel new file mode 100644 index 0000000..f4e1d98 --- /dev/null +++ b/ink_stroke_modeler/internal/BUILD.bazel @@ -0,0 +1,143 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//ink_stroke_modeler:__subpackages__"], +) + +licenses(["notice"]) + +cc_library( + name = "internal_types", + hdrs = ["internal_types.h"], + deps = ["//ink_stroke_modeler:types"], +) + +cc_library( + name = "type_matchers", + testonly = 1, + srcs = ["type_matchers.cc"], + hdrs = ["type_matchers.h"], + deps = [ + ":internal_types", + "//:gtest_for_library_testonly", + "//ink_stroke_modeler:types", + ], +) + +cc_library( + name = "utils", + hdrs = ["utils.h"], + deps = ["//ink_stroke_modeler:types"], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":type_matchers", + ":utils", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "wobble_smoother", + srcs = ["wobble_smoother.cc"], + hdrs = ["wobble_smoother.h"], + deps = [ + ":utils", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + ], +) + +cc_test( + name = "wobble_smoother_test", + srcs = ["wobble_smoother_test.cc"], + deps = [ + ":type_matchers", + ":wobble_smoother", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "position_modeler", + hdrs = ["position_modeler.h"], + deps = [ + ":internal_types", + ":utils", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + ], +) + +cc_test( + name = "position_modeler_test", + srcs = ["position_modeler_test.cc"], + deps = [ + ":internal_types", + ":position_modeler", + ":type_matchers", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "stylus_state_modeler", + srcs = ["stylus_state_modeler.cc"], + hdrs = ["stylus_state_modeler.h"], + deps = [ + ":internal_types", + ":utils", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + ], +) + +cc_test( + name = "stylus_state_modeler_test", + srcs = ["stylus_state_modeler_test.cc"], + deps = [ + ":internal_types", + ":stylus_state_modeler", + ":type_matchers", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "validation", + hdrs = ["validation.h"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + ], +) + +cc_test( + name = "validaiton_test", + srcs = ["validation_test.cc"], + deps = [ + ":validation", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/ink_stroke_modeler/internal/CMakeLists.txt b/ink_stroke_modeler/internal/CMakeLists.txt new file mode 100644 index 0000000..b1886df --- /dev/null +++ b/ink_stroke_modeler/internal/CMakeLists.txt @@ -0,0 +1,156 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(prediction) + +ink_cc_library( + NAME + internal_types + HDRS + internal_types.h + DEPS + InkStrokeModeler::types +) + +ink_cc_library( + NAME + type_matchers + SRCS + type_matchers.cc + HDRS + type_matchers.h + DEPS + InkStrokeModeler::internal_types + InkStrokeModeler::types + GTest::gtest_main +) + +ink_cc_library( + NAME + utils + HDRS + utils.h + DEPS + InkStrokeModeler::types +) + +ink_cc_test( + NAME + utils_test + SRCS + utils_test.cc + DEPS + InkStrokeModeler::type_matchers + InkStrokeModeler::utils + GTest::gmock_main +) + +ink_cc_library( + NAME + wobble_smoother + SRCS + wobble_smoother.cc + HDRS + wobble_smoother.h + DEPS + InkStrokeModeler::utils + InkStrokeModeler::params + InkStrokeModeler::types +) + +ink_cc_test( + NAME + wobble_smoother_test + SRCS + wobble_smoother_test.cc + DEPS + InkStrokeModeler::type_matchers + InkStrokeModeler::wobble_smoother + InkStrokeModeler::params + InkStrokeModeler::types + GTest::gmock_main +) + +ink_cc_library( + NAME + position_modeler + HDRS + position_modeler.h + DEPS + InkStrokeModeler::internal_types + InkStrokeModeler::utils + InkStrokeModeler::params + InkStrokeModeler::types +) + +ink_cc_test( + NAME + position_modeler_test + SRCS + position_modeler_test.cc + DEPS + InkStrokeModeler::internal_types + InkStrokeModeler::position_modeler + InkStrokeModeler::type_matchers + InkStrokeModeler::params + InkStrokeModeler::types + GTest::gmock_main +) + +ink_cc_library( + NAME + stylus_state_modeler + SRCS + stylus_state_modeler.cc + HDRS + stylus_state_modeler.h + DEPS + InkStrokeModeler::internal_types + InkStrokeModeler::utils + InkStrokeModeler::params +) + +ink_cc_test( + NAME + stylus_state_modeler_test + SRCS + stylus_state_modeler_test.cc + DEPS + InkStrokeModeler::internal_types + InkStrokeModeler::stylus_state_modeler + InkStrokeModeler::type_matchers + InkStrokeModeler::params + InkStrokeModeler::types + GTest::gmock_main +) + +ink_cc_library( + NAME + validation + HDRS + validation.h + DEPS + absl::status + absl::strings +) + +ink_cc_test( + NAME + validation_test + SRCS + validation_test.cc + DEPS + InkStrokeModeler::validation + GTest::gtest_main +) diff --git a/ink_stroke_modeler/internal/internal_types.h b/ink_stroke_modeler/internal/internal_types.h new file mode 100644 index 0000000..1c6b671 --- /dev/null +++ b/ink_stroke_modeler/internal/internal_types.h @@ -0,0 +1,73 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_ +#define INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_ + +#include <ostream> + +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This struct contains the position and velocity of the modeled pen tip at +// the indicated time. +struct TipState { + Vec2 position{0}; + Vec2 velocity{0}; + Time time{0}; +}; + +std::ostream &operator<<(std::ostream &s, const TipState &tip_state); + +// This struct contains information about the state of the stylus. See the +// corresponding fields on the Input struct for more info. +struct StylusState { + float pressure = -1; + float tilt = -1; + float orientation = -1; +}; + +bool operator==(const StylusState &lhs, const StylusState &rhs); +std::ostream &operator<<(std::ostream &s, const StylusState &stylus_state); + +//////////////////////////////////////////////////////////////////////////////// +// Inline function definitions +//////////////////////////////////////////////////////////////////////////////// + +inline bool operator==(const StylusState &lhs, const StylusState &rhs) { + return lhs.pressure == rhs.pressure && lhs.tilt == rhs.tilt && + lhs.orientation == rhs.orientation; +} + +inline std::ostream &operator<<(std::ostream &s, const TipState &tip_state) { + return s << "TipState: pos: " << tip_state.position + << ", velocity: " << tip_state.velocity + << ", time: " << tip_state.time << ">"; +} + +inline std::ostream &operator<<(std::ostream &s, + const StylusState &stylus_state) { + return s << "<Result: pressure: " << stylus_state.pressure + << ", tilt: " << stylus_state.tilt + << ", orientation: " << stylus_state.orientation << ">"; +} + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_INTERNAL_TYPES_H_ diff --git a/ink_stroke_modeler/internal/position_modeler.h b/ink_stroke_modeler/internal/position_modeler.h new file mode 100644 index 0000000..c3276e8 --- /dev/null +++ b/ink_stroke_modeler/internal/position_modeler.h @@ -0,0 +1,138 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_ +#define INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_ + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/utils.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class models the movement of the pen tip based on the laws of motion. +// The pen tip is represented as a mass, connected by a spring to a moving +// anchor; as the anchor moves, it drags the pen tip along behind it. +class PositionModeler { + public: + void Reset(const TipState& state, PositionModelerParams params) { + state_ = state; + params_ = params; + } + + // Given the position of the anchor and the time, updates the model and + // returns the state of the pen tip. + TipState Update(Vec2 anchor_position, Time time) { + Duration delta_time = time - state_.time; + float float_delta = delta_time.Value(); + auto acceleration = + (anchor_position - state_.position) / params_.spring_mass_constant - + params_.drag_constant * state_.velocity; + state_.velocity += float_delta * acceleration; + state_.position += float_delta * state_.velocity; + state_.time = time; + + return state_; + } + + const TipState& CurrentState() const { return state_; } + const PositionModelerParams& Params() const { return params_; } + + // This helper function linearly interpolates between the between the start + // and end anchor position and time, updating the model at each step and + // storing the result in the given output iterator. + // + // NOTE: Because the expected use case is to repeatedly call this function on + // a sequence of anchor positions/times, the start position/time is not sent + // to the model. This prevents us from duplicating those inputs, but it does + // mean that the first input must be provided on its own, via either Reset() + // or Update(). This also means that the interpolation values are + // (1 ... n) / n, as opposed to (0 ... (n - 1)) / (n - 1). + // + // Template parameter OutputIt is expected to be an output iterator over + // TipState. + template <typename OutputIt> + void UpdateAlongLinearPath(Vec2 start_anchor_position, Time start_time, + Vec2 end_anchor_position, Time end_time, + int n_samples, OutputIt output) { + for (int i = 1; i <= n_samples; ++i) { + auto interp_value = static_cast<float>(i) / n_samples; + auto position = + Interp(start_anchor_position, end_anchor_position, interp_value); + auto time = Interp(start_time, end_time, interp_value); + *output++ = Update(position, time); + } + } + + // This helper function models the end of the stroke, by repeatedly updating + // with the final anchor position. It attempts to stop at the closest point to + // the anchor, by checking if it has overshot, and retrying with successively + // smaller time steps. + // + // It halts when any of these three conditions is met: + // - It has taken more than max_iterations steps (including discarded steps) + // - The distance between the current state and the anchor is less than + // stop_distance + // - The distance between the previous state and the current state is less + // than stop_distance + // + // Template parameter OutputIt is expected to be an output iterator over + // TipState. + template <typename OutputIt> + void ModelEndOfStroke(Vec2 anchor_position, Duration delta_time, + int max_iterations, float stop_distance, + OutputIt output) { + for (int i = 0; i < max_iterations; ++i) { + // The call to Update modifies the state, so we store a copy of the + // previous state so we can retry with a smaller step if necessary. + const TipState previous_state = state_; + TipState candidate = + Update(anchor_position, previous_state.time + delta_time); + if (Distance(previous_state.position, candidate.position) < + stop_distance) { + // We're no longer making any significant progress, which means that + // we're about as close as we can get without looping around. + return; + } + + float closest_t = NearestPointOnSegment( + previous_state.position, candidate.position, anchor_position); + if (closest_t < 1) { + // We're overshot the anchor, retry with a smaller step. + delta_time *= .5; + state_ = previous_state; + continue; + } + *output++ = candidate; + + if (Distance(candidate.position, anchor_position) < stop_distance) { + // We're within tolerance of the anchor. + return; + } + } + } + + private: + PositionModelerParams params_; + TipState state_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_POSITION_MODELER_H_ diff --git a/ink_stroke_modeler/internal/position_modeler_test.cc b/ink_stroke_modeler/internal/position_modeler_test.cc new file mode 100644 index 0000000..7898ea1 --- /dev/null +++ b/ink_stroke_modeler/internal/position_modeler_test.cc @@ -0,0 +1,317 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/position_modeler.h" + +#include <cmath> +#include <iterator> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::ElementsAre; + +const Duration kDefaultTimeStep(1. / 180); +constexpr float kTol = .00005; + +// The expected position values are taken directly from results the old +// TipDynamics class. The expected velocity values are from the same source, but +// a multiplier of 300 (i.e. dt / (1 - drag)) had to be applied to account for +// the fact that PositionModeler uses the time step correctly. + +TEST(PositionModelerTest, StraightLine) { + PositionModeler modeler; + Time current_time(0); + modeler.Reset({{0, 0}, {0, 0}, current_time}, PositionModelerParams()); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update({1, 0}, current_time), + TipStateNear({{.0909, 0}, {16.3636, 0}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update({2, 0}, current_time), + TipStateNear({{.319, 0}, {41.0579, 0}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update({3, 0}, current_time), + TipStateNear({{.6996, 0}, {68.5055, 0}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update({4, 0}, current_time), + TipStateNear({{1.228, 0}, {95.1099, 0}, current_time}, kTol)); +} + +TEST(PositionModelerTest, ZigZag) { + PositionModeler modeler; + Time current_time(3); + modeler.Reset({{-1, -1}, {0, 0}, current_time}, PositionModelerParams()); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update({-.5, -1}, current_time), + TipStateNear({{-.9545, -1}, {8.1818, 0}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({-.5, -.5}, current_time), + TipStateNear({{-.886, -.9545}, {12.3471, 8.1818}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({0, -.5}, current_time), + TipStateNear({{-.7643, -.886}, {21.9056, 12.3471}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({0, 0}, current_time), + TipStateNear({{-.6218, -.7643}, {25.6493, 21.9056}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({.5, 0}, current_time), + TipStateNear({{-.4343, -.6218}, {33.7456, 25.6493}, current_time}, kTol)); +} + +TEST(PositionModelerTest, SharpTurn) { + PositionModeler modeler; + Time current_time(1.6); + modeler.Reset({{0, 0}, {0, 0}, current_time}, PositionModelerParams()); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({.25, .25}, current_time), + TipStateNear({{.0227, .0227}, {4.0909, 4.0909}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({.5, .5}, current_time), + TipStateNear({{.0798, .0798}, {10.2645, 10.2645}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({.75, .75}, current_time), + TipStateNear({{.1749, .1749}, {17.1264, 17.1264}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({1, 1}, current_time), + TipStateNear({{.307, .307}, {23.7775, 23.7775}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({1.25, .75}, current_time), + TipStateNear({{.472, .4265}, {29.6975, 21.5157}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({1.5, .5}, current_time), + TipStateNear({{.6644, .5049}, {34.6406, 14.1117}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({1.75, .25}, current_time), + TipStateNear({{.8786, .5288}, {38.5482, 4.2955}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update({2, 0}, current_time), + TipStateNear({{1.109, .495}, {41.4794, -6.0756}, current_time}, kTol)); +} + +TEST(PositionModelerTest, SmoothTurn) { + auto point_on_circle = [](float theta) { + return Vec2{std::cos(theta), std::sin(theta)}; + }; + + PositionModeler modeler; + Time current_time(10.1); + modeler.Reset({point_on_circle(0), {0, 0}, current_time}, + PositionModelerParams()); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update(point_on_circle(M_PI * .125), current_time), + TipStateNear({{.9931, .0348}, {-1.2456, 6.2621}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update(point_on_circle(M_PI * .25), current_time), + TipStateNear({{0.9629, 0.1168}, {-5.4269, 14.7588}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .375), current_time), + TipStateNear( + {{0.8921, 0.2394}, {-12.7511, 22.0623}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .5), current_time), + TipStateNear( + {{0.7685, 0.3820}, {-22.2485, 25.6844}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .625), current_time), + TipStateNear( + {{0.5897, 0.5169}, {-32.1865, 24.2771}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT(modeler.Update(point_on_circle(M_PI * .75), current_time), + TipStateNear( + {{0.3645, 0.6151}, {-40.5319, 17.6785}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update(point_on_circle(M_PI * .875), current_time), + TipStateNear({{0.1123, 0.6529}, {-45.4017, 6.8034}, current_time}, kTol)); + + current_time += kDefaultTimeStep; + EXPECT_THAT( + modeler.Update(point_on_circle(M_PI), current_time), + TipStateNear({{-0.1402, 0.6162}, {-45.4417, -6.6022}, current_time}, + kTol)); +} + +TEST(PositionModelerTest, UpdateAlongLinearPath) { + PositionModeler modeler; + modeler.Reset({{5, 10}, {0, 0}, Time{3}}, PositionModelerParams()); + + std::vector<TipState> result; + modeler.UpdateAlongLinearPath({5, 10}, Time{3}, {15, 10}, Time{3.05}, 5, + std::back_inserter(result)); + EXPECT_THAT( + result, + ElementsAre( + TipStateNear({{5.5891, 10}, {58.9091, 0}, Time{3.01}}, kTol), + TipStateNear({{6.7587, 10}, {116.9613, 0}, Time{3.02}}, kTol), + TipStateNear({{8.3355, 10}, {157.6746, 0}, Time{3.03}}, kTol), + TipStateNear({{10.1509, 10}, {181.5411, 0}, Time{3.04}}, kTol), + TipStateNear({{12.0875, 10}, {193.6607, 0}, Time{3.05}}, kTol))); + + result.clear(); + modeler.UpdateAlongLinearPath({15, 10}, Time{3.05}, {15, 16}, Time{3.08}, 3, + std::back_inserter(result)); + EXPECT_THAT( + result, + ElementsAre( + TipStateNear({{13.4876, 10.5891}, {140.0123, 58.9091}, Time{3.06}}, + kTol), + TipStateNear({{14.3251, 11.7587}, {83.7508, 116.9613}, Time{3.07}}, + kTol), + TipStateNear({{14.7584, 13.3355}, {43.3291, 157.6746}, Time{3.08}}, + kTol))); +} + +TEST(PositionModelerTest, ModelEndOfStrokeStationary) { + PositionModeler modeler; + modeler.Reset({{4, -2}, {0, 0}, Time{0}}, PositionModelerParams()); + + std::vector<TipState> result; + modeler.ModelEndOfStroke({3, -1}, Duration(1. / 180), 20, .01, + std::back_inserter(result)); + EXPECT_THAT( + result, + ElementsAre( + TipStateNear({{3.9091, -1.9091}, {-16.3636, 16.3636}, Time{0.0056}}, + kTol), + TipStateNear({{3.7719, -1.7719}, {-24.6942, 24.6942}, Time{0.0111}}, + kTol), + TipStateNear({{3.6194, -1.6194}, {-27.4476, 27.4476}, Time{0.0167}}, + kTol), + TipStateNear({{3.4716, -1.4716}, {-26.6045, 26.6044}, Time{0.0222}}, + kTol), + TipStateNear({{3.3401, -1.3401}, {-23.6799, 23.6799}, Time{0.0278}}, + kTol), + TipStateNear({{3.2302, -1.2302}, {-19.7725, 19.7725}, Time{0.0333}}, + kTol), + TipStateNear({{3.1434, -1.1434}, {-15.6306, 15.6306}, Time{0.0389}}, + kTol), + TipStateNear({{3.0782, -1.0782}, {-11.7244, 11.7244}, Time{0.0444}}, + kTol), + TipStateNear({{3.0320, -1.0320}, {-8.3149, 8.3149}, Time{0.0500}}, + kTol), + TipStateNear({{3.0014, -1.0014}, {-5.5133, 5.5133}, Time{0.0556}}, + kTol))); +} + +TEST(PositionModelerTest, ModelEndOfStrokeInMotion) { + PositionModeler modeler; + modeler.Reset({{-1, 2}, {40, 10}, Time{1}}, PositionModelerParams()); + + std::vector<TipState> result; + modeler.ModelEndOfStroke({7, 2}, Duration(1. / 120), 20, .01, + std::back_inserter(result)); + EXPECT_THAT( + result, + ElementsAre( + TipStateNear({{0.7697, 2.0333}, {212.3636, 4.0000}, Time{1.0083}}, + kTol), + TipStateNear({{2.7520, 2.0398}, {237.8711, 0.7818}, Time{1.0167}}, + kTol), + TipStateNear({{4.4138, 2.0343}, {199.4186, -0.6654}, Time{1.0250}}, + kTol), + TipStateNear({{5.6075, 2.0251}, {143.2474, -1.1081}, Time{1.0333}}, + kTol), + TipStateNear({{6.3698, 2.0162}, {91.4784, -1.0586}, Time{1.0417}}, + kTol), + TipStateNear({{6.8037, 2.0094}, {52.0592, -0.8222}, Time{1.0500}}, + kTol), + TipStateNear({{6.9655, 2.0065}, {38.8512, -0.6909}, Time{1.0542}}, + kTol), + TipStateNear({{6.9850, 2.0062}, {37.4471, -0.6750}, Time{1.0547}}, + kTol))); +} + +TEST(PositionModelerTest, ModelEndOfStrokeMaxIterationsReached) { + PositionModeler modeler; + modeler.Reset({{8, -3}, {-100, -150}, Time{1}}, PositionModelerParams()); + + std::vector<TipState> result; + modeler.ModelEndOfStroke({-9, -10}, Duration(.0001), 10, .001, + std::back_inserter(result)); + EXPECT_THAT( + result, + ElementsAre( + TipStateNear( + {{7.9896, -3.0151}, {-104.2873, -150.9818}, Time{1.0001}}, kTol), + TipStateNear( + {{7.9787, -3.0303}, {-108.5406, -151.9521}, Time{1.0002}}, kTol), + TipStateNear( + {{7.9674, -3.0456}, {-112.7601, -152.9110}, Time{1.0003}}, kTol), + TipStateNear( + {{7.9557, -3.0610}, {-116.9459, -153.8584}, Time{1.0004}}, kTol), + TipStateNear( + {{7.9436, -3.0764}, {-121.0982, -154.7945}, Time{1.0005}}, kTol), + TipStateNear( + {{7.9311, -3.0920}, {-125.2169, -155.7193}, Time{1.0006}}, kTol), + TipStateNear( + {{7.9182, -3.1077}, {-129.3023, -156.6328}, Time{1.0007}}, kTol), + TipStateNear( + {{7.9048, -3.1234}, {-133.3545, -157.5351}, Time{1.0008}}, kTol), + TipStateNear( + {{7.8911, -3.1393}, {-137.3736, -158.4263}, Time{1.0009}}, kTol), + TipStateNear( + {{7.8770, -3.1552}, {-141.3597, -159.3065}, Time{1.0010}}, + kTol))); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/BUILD.bazel b/ink_stroke_modeler/internal/prediction/BUILD.bazel new file mode 100644 index 0000000..2d94289 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/BUILD.bazel @@ -0,0 +1,83 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//ink_stroke_modeler:__subpackages__"], +) + +licenses(["notice"]) + +cc_library( + name = "input_predictor", + hdrs = ["input_predictor.h"], + deps = [ + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "//ink_stroke_modeler/internal:internal_types", + ], +) + +cc_library( + name = "kalman_predictor", + srcs = ["kalman_predictor.cc"], + hdrs = ["kalman_predictor.h"], + deps = [ + ":input_predictor", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "//ink_stroke_modeler/internal:internal_types", + "//ink_stroke_modeler/internal:utils", + "//ink_stroke_modeler/internal/prediction/kalman_filter", + ], +) + +cc_test( + name = "kalman_predictor_test", + srcs = ["kalman_predictor_test.cc"], + deps = [ + ":input_predictor", + ":kalman_predictor", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "//ink_stroke_modeler/internal:type_matchers", + "@com_google_absl//absl/types:optional", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "stroke_end_predictor", + srcs = ["stroke_end_predictor.cc"], + hdrs = ["stroke_end_predictor.h"], + deps = [ + ":input_predictor", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler:types", + "//ink_stroke_modeler/internal:internal_types", + "//ink_stroke_modeler/internal:position_modeler", + "@com_google_absl//absl/types:optional", + ], +) + +cc_test( + name = "stroke_end_predictor_test", + srcs = ["stroke_end_predictor_test.cc"], + deps = [ + ":input_predictor", + ":stroke_end_predictor", + "//ink_stroke_modeler:params", + "//ink_stroke_modeler/internal:type_matchers", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/ink_stroke_modeler/internal/prediction/CMakeLists.txt b/ink_stroke_modeler/internal/prediction/CMakeLists.txt new file mode 100644 index 0000000..e54dbbe --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/CMakeLists.txt @@ -0,0 +1,86 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(kalman_filter) + +ink_cc_library( + NAME + input_predictor + HDRS + input_predictor.h + DEPS + InkStrokeModeler::params + InkStrokeModeler::types + InkStrokeModeler::internal_types +) + +ink_cc_library( + NAME + kalman_predictor + SRCS + kalman_predictor.cc + HDRS + kalman_predictor.h + DEPS + InkStrokeModeler::input_predictor + InkStrokeModeler::params + InkStrokeModeler::types + InkStrokeModeler::internal_types + InkStrokeModeler::utils + InkStrokeModeler::kalman_filter +) + +ink_cc_test( + NAME + kalman_predictor_test + SRCS + kalman_predictor_test.cc + DEPS + InkStrokeModeler::input_predictor + InkStrokeModeler::kalman_predictor + InkStrokeModeler::params + InkStrokeModeler::types + InkStrokeModeler::type_matchers + absl::optional + GTest::gmock_main +) + +ink_cc_library( + NAME + stroke_end_predictor + SRCS + stroke_end_predictor.cc + HDRS + stroke_end_predictor.h + DEPS + InkStrokeModeler::input_predictor + InkStrokeModeler::params + InkStrokeModeler::types + InkStrokeModeler::internal_types + InkStrokeModeler::position_modeler + absl::optional +) + +ink_cc_test( + NAME + stroke_end_predictor_test + SRCS + stroke_end_predictor_test.cc + DEPS + InkStrokeModeler::input_predictor + InkStrokeModeler::stroke_end_predictor + InkStrokeModeler::params + InkStrokeModeler::type_matchers + GTest::gmock_main +) diff --git a/ink_stroke_modeler/internal/prediction/input_predictor.h b/ink_stroke_modeler/internal/prediction/input_predictor.h new file mode 100644 index 0000000..4147953 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/input_predictor.h @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_ + +#include <vector> + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// Interface for input predictors that generate points based on past input. +class InputPredictor { + public: + virtual ~InputPredictor() {} + + // Resets the predictor's internal model. + virtual void Reset() = 0; + + // Updates the predictor's internal model with the given input. + virtual void Update(Vec2 position, Time time) = 0; + + // Constructs a prediction based from the given state, based on the + // predictor's internal model. The result may be empty if the predictor has + // not yet accumulated enough data, via Update(), to construct a reasonable + // prediction. + // + // Subclasses are expected to maintain the following invariants: + // - The given state must not appear in the prediction. + // - The time delta between each state in the prediction, and between the + // given state and the first predicted state, must conform to + // SamplingParams::min_output_rate. + virtual std::vector<TipState> ConstructPrediction( + const TipState &last_state) const = 0; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_INPUT_PREDICTOR_H_ diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel b/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel new file mode 100644 index 0000000..e487c16 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/BUILD.bazel @@ -0,0 +1,58 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + default_visibility = ["//ink_stroke_modeler:__subpackages__"], +) + +licenses(["notice"]) + +cc_library( + name = "matrix", + hdrs = ["matrix.h"], +) + +cc_test( + name = "matrix_test", + srcs = ["matrix_test.cc"], + deps = [ + ":matrix", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "kalman_filter", + srcs = [ + "axis_predictor.cc", + "kalman_filter.cc", + ], + hdrs = [ + "axis_predictor.h", + "kalman_filter.h", + ], + deps = [ + ":matrix", + "@com_google_absl//absl/memory", + ], +) + +cc_test( + name = "axis_predictor_test", + srcs = ["axis_predictor_test.cc"], + deps = [ + ":kalman_filter", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt b/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt new file mode 100644 index 0000000..b3c8d88 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/CMakeLists.txt @@ -0,0 +1,54 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ink_cc_library( + NAME + matrix + HDRS + matrix.h +) + +ink_cc_test( + NAME + matrix_test + SRCS + matrix_test.cc + DEPS + InkStrokeModeler::matrix + GTest::gtest_main +) + +ink_cc_library( + NAME + kalman_filter + SRCS + axis_predictor.cc + kalman_filter.cc + HDRS + axis_predictor.h + kalman_filter.h + DEPS + InkStrokeModeler::matrix + absl::memory +) + +ink_cc_test( + NAME + axis_predictor_test + SRCS + axis_predictor_test.cc + DEPS + InkStrokeModeler::kalman_filter + GTest::gtest_main +) diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc new file mode 100644 index 0000000..fab9363 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.cc @@ -0,0 +1,98 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h" + +#include "absl/memory/memory.h" +#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h" + +namespace ink { +namespace stroke_model { +namespace { + +constexpr int kPositionIndex = 0; +constexpr int kVelocityIndex = 1; +constexpr int kAccelerationIndex = 2; +constexpr int kJerkIndex = 3; + +constexpr double kDt = 1.0; +constexpr double kDtSquared = kDt * kDt; +constexpr double kDtCubed = kDt * kDt * kDt; +} // namespace + +AxisPredictor::AxisPredictor(double process_noise, double measurement_noise, + int min_stable_iteration) { + // State translation matrix is basic physics. + // new_pos = pre_pos + v * dt + 1/2 * a * dt^2 + 1/6 * J * dt^3. + // new_v = v + a * dt + 1/2 * J * dt^2. + // new_a = a + J * dt. + // new_j = J. + Matrix4 state_transition(1, kDt, .5 * kDtSquared, 1.0 / 6 * kDtCubed, // + 0, 1, kDt, .5 * kDtSquared, // + 0, 0, 1, kDt, // + 0, 0, 0, 1); + // We model the system noise as noisy force on the pen. + // The following matrix describes the impact of that noise on each state. + Vec4 process_noise_vector(1.0 / 6 * kDtCubed, 0.5 * kDtSquared, kDt, 1.0); + Matrix4 process_noise_covariance = + OuterProduct(process_noise_vector, process_noise_vector) * process_noise; + + // Sensor only detects location. Thus measurement only impact the position. + Vec4 measurement_vector(1.0, 0.0, 0.0, 0.0); + + kalman_filter_ = absl::make_unique<KalmanFilter>( + state_transition, process_noise_covariance, measurement_vector, + measurement_noise, min_stable_iteration); +} + +bool AxisPredictor::Stable() const { + return kalman_filter_ && kalman_filter_->Stable(); +} + +void AxisPredictor::Reset() { + if (kalman_filter_) kalman_filter_->Reset(); +} + +void AxisPredictor::Update(double observation) { + if (kalman_filter_) kalman_filter_->Update(observation); +} + +int AxisPredictor::NumIterations() const { + return kalman_filter_ ? kalman_filter_->NumIterations() : 0; +} + +double AxisPredictor::GetPosition() const { + if (kalman_filter_) + return kalman_filter_->GetStateEstimation()[kPositionIndex]; + else + return 0.0; +} + +double AxisPredictor::GetVelocity() const { + if (kalman_filter_) + return kalman_filter_->GetStateEstimation()[kVelocityIndex]; + else + return 0.0; +} + +double AxisPredictor::GetAcceleration() const { + return kalman_filter_->GetStateEstimation()[kAccelerationIndex]; +} + +double AxisPredictor::GetJerk() const { + return kalman_filter_->GetStateEstimation()[kJerkIndex]; +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h new file mode 100644 index 0000000..3ab02ee --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h @@ -0,0 +1,62 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_ + +#include <memory> + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h" + +namespace ink { +namespace stroke_model { + +// Class to predict on axis. +// +// This predictor use one instance of Kalman filter to predict one dimension of +// stylus movement. +class AxisPredictor { + public: + AxisPredictor(double process_noise, double measurement_noise, + int min_stable_iteration); + + // Return true if the underlying Kalman filter is stable. + bool Stable() const; + + // Reset the underlying Kalman filter. + void Reset(); + + // Update the predictor with a new observation. + void Update(double observation); + + // Returns the number of times Update() has been called since the last time + // the AxisPredictor was reset. + int NumIterations() const; + + // Get the predicted values from the underlying Kalman filter. + double GetPosition() const; + double GetVelocity() const; + double GetAcceleration() const; + double GetJerk() const; + + private: + std::unique_ptr<KalmanFilter> kalman_filter_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_AXIS_PREDICTOR_H_ diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc new file mode 100644 index 0000000..40ba9f4 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor_test.cc @@ -0,0 +1,100 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h" + +#include <vector> + +#include "gtest/gtest.h" + +namespace ink { +namespace stroke_model { +namespace { + +constexpr int kStableIterNum = 4; + +constexpr double kProcessNoise = 0.01; +constexpr double kMeasurementNoise = 1.0; + +} // namespace + +struct DataSet { + double initial_observation; + std::vector<double> observation; + std::vector<double> position; + std::vector<double> velocity; + std::vector<double> acceleration; + std::vector<double> jerk; +}; + +void ValidateAxisPredictor(AxisPredictor* predictor, const DataSet& data) { + predictor->Reset(); + predictor->Update(data.initial_observation); + for (decltype(data.observation.size()) i = 0; i < data.observation.size(); + i++) { + predictor->Update(data.observation[i]); + EXPECT_NEAR(data.position[i], predictor->GetPosition(), 0.0001); + EXPECT_NEAR(data.velocity[i], predictor->GetVelocity(), 0.0001); + EXPECT_NEAR(data.acceleration[i], predictor->GetAcceleration(), 0.0001); + EXPECT_NEAR(data.jerk[i], predictor->GetJerk(), 0.0001); + } +} + +// Test that the predictor will stable. +TEST(AxisPredictorTest, ShouldStable) { + AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum); + for (int i = 0; i < kStableIterNum; i++) { + EXPECT_FALSE(predictor.Stable()); + predictor.Update(1); + } + EXPECT_TRUE(predictor.Stable()); +} + +// Test the kalman filter behavior. The data set is generated by a "known to +// work" kalman filter. +TEST(AxisPredictorTest, PredictedValue) { + AxisPredictor predictor(kProcessNoise, kMeasurementNoise, kStableIterNum); + DataSet data; + data.initial_observation = 0; + data.observation = {1, 2, 3, 4, 5, 6}; + data.position = {0.6949411066858742, 1.8880162111305765, 3.0596776689233476, + 4.080666568886563, 5.039574058758894, 5.990101744132957}; + data.velocity = {0.48326413015846115, 1.349212968908908, 1.5150757723942188, + 1.2449353797925855, 0.9823147273054352, 0.831418084705206}; + data.acceleration = {0.20388102703160751, 0.6602537865634062, + 0.46392675203046707, 0.0691864035645362, + -0.1571001901104591, -0.2303438651979314}; + data.jerk = {0.051351580374544535, 0.17805019978769315, + 0.06592110190532013, -0.06063794909774803, + -0.10198612906906362, -0.09541445938944032}; + + ValidateAxisPredictor(&predictor, data); + + data.initial_observation = 0; + data.observation = {1, 2, 4, 8, 16, 32}; + data.position = {0.6949411066858742, 1.8880162111305765, 3.9597202826804603, + 7.9052737853848285, 15.720340533540115, 31.24662046486774}; + data.velocity = {0.48326413015846115, 1.349212968908908, 2.492271225870179, + 4.610844489557212, 8.828231877380588, 16.987494416071463}; + data.acceleration = {0.20388102703160751, 0.6602537865634062, + 1.090991623810185, 1.885675547541351, + 3.4586206593783526, 6.34082285106952}; + data.jerk = {0.051351580374544535, 0.17805019978769315, 0.25373225050247916, + 0.4023497012294069, 0.6945464157568688, 1.1947316519015612}; + + ValidateAxisPredictor(&predictor, data); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc new file mode 100644 index 0000000..45238f7 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.cc @@ -0,0 +1,79 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h" + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h" + +namespace ink { +namespace stroke_model { + +KalmanFilter::KalmanFilter(const Matrix4& state_transition, + const Matrix4& process_noise_covariance, + const Vec4& measurement_vector, + double measurement_noise_variance, + int min_stable_iteration) + : state_transition_matrix_(state_transition), + process_noise_covariance_matrix_(process_noise_covariance), + measurement_vector_(measurement_vector), + measurement_noise_variance_(measurement_noise_variance), + min_stable_iteration_(min_stable_iteration), + iter_num_(0) {} + +void KalmanFilter::Predict() { + // X = F * X + state_estimation_ = state_transition_matrix_ * state_estimation_; + // P = F * P * F' + Q + error_covariance_matrix_ = state_transition_matrix_ * + error_covariance_matrix_ * + state_transition_matrix_.Transpose() + + process_noise_covariance_matrix_; +} + +void KalmanFilter::Update(double observation) { + if (iter_num_++ == 0) { + // We only update the state estimation in the first iteration. + state_estimation_[0] = observation; + return; + } + Predict(); + // Y = z - H * X + double y = observation - DotProduct(measurement_vector_, state_estimation_); + // S = H * P * H' + R + double S = DotProduct(measurement_vector_ * error_covariance_matrix_, + measurement_vector_) + + measurement_noise_variance_; + // K = P * H' * inv(S) + Vec4 kalman_gain = measurement_vector_ * error_covariance_matrix_ / S; + + // X = X + K * Y + state_estimation_ = state_estimation_ + kalman_gain * y; + + // I_HK = eye(P) - K * H + Matrix4 I_KH = Matrix4() - OuterProduct(kalman_gain, measurement_vector_); + + // P = I_KH * P * I_KH' + K * R * K' + error_covariance_matrix_ = + I_KH * error_covariance_matrix_ * I_KH.Transpose() + + OuterProduct(kalman_gain, kalman_gain) * measurement_noise_variance_; +} + +void KalmanFilter::Reset() { + state_estimation_ = {0, 0, 0, 0}; + error_covariance_matrix_ = Matrix4(); // identity + iter_num_ = 0; +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h new file mode 100644 index 0000000..65fb55f --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h @@ -0,0 +1,96 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_ + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h" + +namespace ink { +namespace stroke_model { + +// Generates a state estimation based upon observations which can then be used +// to compute predicted values. +class KalmanFilter { + public: + KalmanFilter(const Matrix4& state_transition, + const Matrix4& process_noise_covariance, + const Vec4& measurement_vector, + double measurement_noise_variance, int min_stable_iteration); + + // Get the estimation of current state. + const Vec4& GetStateEstimation() const { return state_estimation_; } + + // Will return true only if the Kalman filter has seen enough data and is + // considered as stable. + bool Stable() const { return iter_num_ >= min_stable_iteration_; } + + // Update the observation of the system. + void Update(double observation); + + void Reset(); + + // Returns the number of times Update() has been called since the last time + // the KalmanFilter was reset. + int NumIterations() const { return iter_num_; } + + private: + void Predict(); + + // Estimate of the latent state + // Symbol: X + // Dimension: state_vector_dim_ + Vec4 state_estimation_; + + // The covariance of the difference between prior predicted latent + // state and posterior estimated latent state (the so-called "innovation". + // Symbol: P + Matrix4 error_covariance_matrix_; + + // For position, state transition matrix is derived from basic physics: + // new_x = x + v * dt + 1/2 * a * dt^2 + 1/6 * jerk * dt^3 + // new_v = v + a * dt + 1/2 * jerk * dt^2 + // ... + // Matrix that transmit current state to next state + // Symbol: F + Matrix4 state_transition_matrix_; + + // Process_noise_covariance_matrix_ is a time-varying parameter that will be + // estimated as part of the Kalman filter process. + // Symbol: Q + Matrix4 process_noise_covariance_matrix_; + + // Vector to transform estimate to measurement. + // Symbol: H + const Vec4 measurement_vector_{0, 0, 0, 0}; + + // measurement_noise_ is a time-varying parameter that will be estimated as + // part of the Kalman filter process. + // Symbol: R + double measurement_noise_variance_; + + // The first iteration at which the Kalman filter is considered stable enough + // to make a good estimate of the state. + int min_stable_iteration_; + + // Tracks the number of update iterations that have occurred. + int iter_num_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_KALMAN_FILTER_H_ diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h new file mode 100644 index 0000000..2c12bae --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h @@ -0,0 +1,270 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_ + +#include <array> +#include <cstddef> +#include <ostream> + +namespace ink { +namespace stroke_model { + +// These classes provide the matrix arithmetic needed for the Kalman filter. +// +// This is intentionally limited to just the required functions, so some +// common matrix arithmetic operations aren't present (e.g. inversion), and +// some operators' symmetric counterparts are missing (e.g. Vec4 * double is +// defined, but double * Vec4 is not). + +// A double-precision vector in 4-dimensional space. +class Vec4 { + public: + constexpr Vec4() : Vec4(0, 0, 0, 0) {} + constexpr Vec4(double x, double y, double z, double w) + : array_({x, y, z, w}) {} + + double& operator[](size_t i) { return array_[i]; } + double operator[](size_t i) const { return array_[i]; } + + private: + std::array<double, 4> array_; +}; + +// A double-precision 4x4 matrix. +class Matrix4 { + public: + // Constructs an identity matrix. + constexpr Matrix4() + : Matrix4(1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1) {} + + // Constructs a matrix with the given values, in row-major order. + constexpr Matrix4(double m00, double m01, double m02, double m03, // + double m10, double m11, double m12, double m13, // + double m20, double m21, double m22, double m23, // + double m30, double m31, double m32, double m33) + : array_{{{m00, m01, m02, m03}, + {m10, m11, m12, m13}, + {m20, m21, m22, m23}, + {m30, m31, m32, m33}}} {} + + // Constructs a matrix s.t. all values are zero. + static constexpr Matrix4 Zero() { + return {0, 0, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0}; + } + + // Returns a copy of the matrix with its rows and columns swapped, i.e. + // original.At(i, j) == transposed.At(j, i). + Matrix4 Transpose() const { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.At(i, j) = At(j, i); + } + } + return result; + } + + double& At(size_t row, size_t column) { return array_[row][column]; } + double At(size_t row, size_t column) const { return array_[row][column]; } + + private: + std::array<Vec4, 4> array_; +}; + +// Computes the dot product of two vectors. Given vectors a and b, this is +// equivalent to the matrix product: +// [a₀ a₁ a₂ a₃]⎡b₀⎤ +// ⎢b₁⎥ +// ⎢b₂⎥ +// ⎣b₃⎦ +double DotProduct(const Vec4& lhs, const Vec4& rhs); + +// Computes the outer product of two vectors. Given vectors a and b, this is +// equivalent to the matrix product: +// ⎡a₀⎤[b₀ b₁ b₂ b₃] +// ⎢a₁⎥ +// ⎢a₂⎥ +// ⎣a₃⎦ +Matrix4 OuterProduct(const Vec4& lhs, const Vec4& rhs); + +bool operator==(const Vec4& lhs, const Vec4& rhs); +bool operator!=(const Vec4& lhs, const Vec4& rhs); +Vec4 operator+(const Vec4& lhs, const Vec4& rhs); +Vec4 operator*(const Vec4& v, double k); +Vec4 operator/(const Vec4& v, double k); + +bool operator==(const Matrix4& lhs, const Matrix4& rhs); +bool operator!=(const Matrix4& lhs, const Matrix4& rhs); +Matrix4 operator*(const Matrix4& lhs, const Matrix4& rhs); +Matrix4 operator+(const Matrix4& lhs, const Matrix4& rhs); +Matrix4 operator-(const Matrix4& lhs, const Matrix4& rhs); + +Matrix4 operator*(const Matrix4& m, double k); +Vec4 operator*(const Matrix4& m, const Vec4& v); +Vec4 operator*(const Vec4& v, const Matrix4& m); + +std::ostream& operator<<(std::ostream& stream, const Vec4& v); +std::ostream& operator<<(std::ostream& stream, const Matrix4& m); + +// ============================================================================ +// Inline function implementations +// ============================================================================ + +inline double DotProduct(const Vec4& lhs, const Vec4& rhs) { + double result = 0; + for (int i = 0; i < 4; ++i) result += lhs[i] * rhs[i]; + return result; +} + +inline Matrix4 OuterProduct(const Vec4& lhs, const Vec4& rhs) { + Matrix4 result = Matrix4::Zero(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.At(i, j) = lhs[i] * rhs[j]; + } + } + return result; +} + +inline bool operator==(const Vec4& lhs, const Vec4& rhs) { + for (int i = 0; i < 4; ++i) { + if (lhs[i] != rhs[i]) return false; + } + return true; +} +inline bool operator!=(const Vec4& lhs, const Vec4& rhs) { + return !(lhs == rhs); +} + +inline Vec4 operator+(const Vec4& lhs, const Vec4& rhs) { + Vec4 result; + for (int i = 0; i < 4; ++i) result[i] = lhs[i] + rhs[i]; + return result; +} + +inline Vec4 operator*(const Vec4& v, double k) { + Vec4 result; + for (int i = 0; i < 4; ++i) result[i] = v[i] * k; + return result; +} + +inline Vec4 operator/(const Vec4& v, double k) { + Vec4 result; + for (int i = 0; i < 4; ++i) result[i] = v[i] / k; + return result; +} + +inline bool operator==(const Matrix4& lhs, const Matrix4& rhs) { + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + if (lhs.At(i, j) != rhs.At(i, j)) return false; + } + } + return true; +} + +inline bool operator!=(const Matrix4& lhs, const Matrix4& rhs) { + return !(lhs == rhs); +} + +inline Matrix4 operator*(const Matrix4& lhs, const Matrix4& rhs) { + Matrix4 result = Matrix4::Zero(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 4; ++k) { + result.At(i, j) += lhs.At(i, k) * rhs.At(k, j); + } + } + } + return result; +} + +inline Matrix4 operator+(const Matrix4& lhs, const Matrix4& rhs) { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.At(i, j) = lhs.At(i, j) + rhs.At(i, j); + } + } + return result; +} + +inline Matrix4 operator-(const Matrix4& lhs, const Matrix4& rhs) { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.At(i, j) = lhs.At(i, j) - rhs.At(i, j); + } + } + return result; +} + +inline Matrix4 operator*(const Matrix4& m, double k) { + Matrix4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result.At(i, j) = m.At(i, j) * k; + } + } + return result; +} + +inline Vec4 operator*(const Matrix4& m, const Vec4& v) { + Vec4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result[i] += v[j] * m.At(i, j); + } + } + return result; +} + +inline Vec4 operator*(const Vec4& v, const Matrix4& m) { + Vec4 result; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + result[i] += v[j] * m.At(j, i); + } + } + return result; +} + +inline std::ostream& operator<<(std::ostream& stream, const Vec4& v) { + stream << "(" << v[0]; + for (int i = 1; i < 4; ++i) stream << ", " << v[i]; + return stream << ")"; +} + +inline std::ostream& operator<<(std::ostream& stream, const Matrix4& m) { + for (int i = 0; i < 4; ++i) { + stream << '\n' << m.At(i, 0); + for (int j = 1; j < 4; ++j) stream << '\t' << m.At(i, j); + } + return stream; +} + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_FILTER_MATRIX_H_ diff --git a/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc new file mode 100644 index 0000000..b11970d --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_filter/matrix_test.cc @@ -0,0 +1,215 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h" + +#include <sstream> +#include <string> + +#include "gtest/gtest.h" + +namespace ink { +namespace stroke_model { +namespace { + +TEST(MatrixTest, Vec4Equality) { + EXPECT_EQ(Vec4(1, 2, 3, 4), Vec4(1, 2, 3, 4)); + EXPECT_NE(Vec4(1, 2, 3, 4), Vec4(1, 2, 7, 4)); + EXPECT_NE(Vec4(.1, .7, -4, 6), Vec4(-1, 64, .3, 200)); +} + +TEST(MatrixTest, Vec4Addition) { + EXPECT_EQ(Vec4(0, 1, 2, 4) + Vec4(5, -1, 6, 7), Vec4(5, 0, 8, 11)); + EXPECT_EQ(Vec4(0.25, -3, 17, 0) + Vec4(.5, -.5, -1, 2), + Vec4(.75, -3.5, 16, 2)); +} + +TEST(MatrixTest, Vec4ScalarMultiplication) { + EXPECT_EQ(Vec4(6, -12, 7, .25) * 2, Vec4(12, -24, 14, .5)); + EXPECT_EQ(Vec4(17, -3, 5.5, 0) * -.25, Vec4(-4.25, .75, -1.375, 0)); +} + +TEST(MatrixTest, Vec4ScalarDivision) { + EXPECT_EQ(Vec4(13, -8, 0, 100) / -2, Vec4(-6.5, 4, 0, -50)); + EXPECT_EQ(Vec4(0, -3, 20, 1) / .2, Vec4(0, -15, 100, 5)); +} + +TEST(MatrixTest, Matrix4Equality) { + EXPECT_EQ(Matrix4(0, 1, 2, 3, // + 4, 5, 6, 7, // + 8, 9, 10, 11, // + 12, 13, 14, 15), + Matrix4(0, 1, 2, 3, // + 4, 5, 6, 7, // + 8, 9, 10, 11, // + 12, 13, 14, 15)); + EXPECT_NE(Matrix4(0, 1, 2, 3, // + 4, 5, 6, 7, // + 8, 9, 10, 11, // + 12, 13, 14, 15), + Matrix4(1, 2, 0, 4, // + 4, 9, 6, 12, // + 9, 9, 9, 9, // + -1, -2, 14, 99)); +} + +TEST(MatrixTest, Matrix4IdentityCtor) { + EXPECT_EQ(Matrix4(), Matrix4(1, 0, 0, 0, // + 0, 1, 0, 0, // + 0, 0, 1, 0, // + 0, 0, 0, 1)); +} + +TEST(MatrixTest, Matrix4Zero) { + EXPECT_EQ(Matrix4::Zero(), Matrix4(0, 0, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0, // + 0, 0, 0, 0)); +} + +TEST(MatrixTest, Matrix4Transpose) { + Matrix4 m(0, 1, 2, 3, // + 4, 5, 6, 7, // + 8, 9, 10, 11, // + 12, 13, 14, 15); + EXPECT_EQ(m.Transpose(), Matrix4(0, 4, 8, 12, // + 1, 5, 9, 13, // + 2, 6, 10, 14, // + 3, 7, 11, 15)); +} + +TEST(MatrixTest, Matrix4Multiplication) { + Matrix4 a(-4, 4, 2, 9, // + -2, -5, 6, 1, // + -2, 7, 10, 1, // + -4, -5, 2, 6); + Matrix4 b(-1, 7, 9, 3, // + 0, 7, -3, 8, // + -9, 7, 7, -10, // + 1, -1, -3, -1); + EXPECT_EQ(a * b, Matrix4(-5, 5, -61, -9, // + -51, -8, 36, -107, // + -87, 104, 28, -51, // + -8, -55, -25, -78)); + EXPECT_EQ(b * a, Matrix4(-40, 9, 136, 25, // + -40, -96, 28, 52, // + 48, 28, 74, -127, // + 8, -7, -36, -1)); +} + +TEST(MatrixTest, Matrix4Addition) { + Matrix4 a(2, 0, -10, -1, // + -4, -4, -7, 3, // + 7, -1, 7, 3, // + -7, -4, -4, -4); + Matrix4 b(9, -6, -10, 0, // + 6, 1, -5, 9, // + -7, -4, -3, -6, // + 7, 7, -10, -9); + EXPECT_EQ(a + b, Matrix4(11, -6, -20, -1, // + 2, -3, -12, 12, // + 0, -5, 4, -3, // + 0, 3, -14, -13)); + EXPECT_EQ(b + a, Matrix4(11, -6, -20, -1, // + 2, -3, -12, 12, // + 0, -5, 4, -3, // + 0, 3, -14, -13)); +} + +TEST(MatrixTest, Matrix4Subtraction) { + Matrix4 a(-7, -9, 9, 9, // + -4, 10, -3, -1, // + 8, 9, 6, 4, // + 9, -7, 7, 4); + Matrix4 b(2, -1, 2, 6, // + -1, -8, -1, 10, // + 3, 0, -6, -1, // + -6, -3, 6, 7); + EXPECT_EQ(a - b, Matrix4(-9, -8, 7, 3, // + -3, 18, -2, -11, // + 5, 9, 12, 5, // + 15, -4, 1, -3)); + EXPECT_EQ(b - a, Matrix4(9, 8, -7, -3, // + 3, -18, 2, 11, // + -5, -9, -12, -5, // + -15, 4, -1, 3)); +} + +TEST(MatrixTest, Matrix4ScalarMultiplication) { + Matrix4 m(6, 8, 6, 7, // + 2, -4, 10, 8, // + -1, 4, 9, -7, // + -2, -9, 10, 10); + EXPECT_EQ(m * 3, Matrix4(18, 24, 18, 21, // + 6, -12, 30, 24, // + -3, 12, 27, -21, // + -6, -27, 30, 30)); + EXPECT_EQ(m * .5, Matrix4(3, 4, 3, 3.5, // + 1, -2, 5, 4, // + -.5, 2, 4.5, -3.5, // + -1, -4.5, 5, 5)); +} + +TEST(MatrixTest, Matrix4VectorMultiplication) { + Matrix4 m(3, 0, 4, 3, // + 7, -6, 7, -10, // + 6, -2, -10, -5, // + -3, 9, 1, -5); + Vec4 v(-6, 9, -7, -10); + EXPECT_EQ(m * v, Vec4(-76, -45, 66, 142)); + EXPECT_EQ(v * m, Vec4(33, -130, 99, -23)); +} + +TEST(MatrixTest, DotProduct) { + Vec4 a(0, -3, 0, -5); + Vec4 b(6, 4, -4, 6); + EXPECT_EQ(DotProduct(a, b), -42); + EXPECT_EQ(DotProduct(b, a), -42); +} + +TEST(MatrixTest, OuterProduct) { + Vec4 a(-8, -3, 6, -8); + Vec4 b(4, -9, -1, 10); + EXPECT_EQ(OuterProduct(a, b), Matrix4(-32, 72, 8, -80, // + -12, 27, 3, -30, // + 24, -54, -6, 60, // + -32, 72, 8, -80)); + EXPECT_EQ(OuterProduct(b, a), Matrix4(-32, -12, 24, -32, // + 72, 27, -54, 72, // + 8, 3, -6, 8, // + -80, -30, 60, -80)); +} + +TEST(MatrixTest, Vec4Stream) { + std::stringstream s; + s << Vec4(1.28, -9, .9, 2.7); + EXPECT_EQ(s.str(), "(1.28, -9, 0.9, 2.7)"); +} + +TEST(MatrixTest, Matrix4Stream) { + std::stringstream s; + s << Matrix4(7.5, -7.7, -4.6, 8, // + 6.4, -8.52, 0, 8.8, // + -3.5, -5.2, -.5, 9, // + -2.6, -3.4, 5.5, 8.3); + EXPECT_EQ(s.str(), + "\n7.5\t-7.7\t-4.6\t8" + "\n6.4\t-8.52\t0\t8.8" + "\n-3.5\t-5.2\t-0.5\t9" + "\n-2.6\t-3.4\t5.5\t8.3"); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor.cc b/ink_stroke_modeler/internal/prediction/kalman_predictor.cc new file mode 100644 index 0000000..cc8949c --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_predictor.cc @@ -0,0 +1,265 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h" + +#include <algorithm> +#include <cmath> +#include <limits> +#include <vector> + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/utils.h" +#include "ink_stroke_modeler/params.h" + +namespace ink { +namespace stroke_model { +namespace { + +KalmanPredictor::State EvaluateCubic(const KalmanPredictor::State &start_state, + Duration delta_time) { + float dt = delta_time.Value(); + auto dt_squared = dt * dt; + auto dt_cubed = dt_squared * dt; + + KalmanPredictor::State end_state; + end_state.position = start_state.position + start_state.velocity * dt + + start_state.acceleration * dt_squared / 2.f + + start_state.jerk * dt_cubed / 6.f; + end_state.velocity = start_state.velocity + start_state.acceleration * dt + + start_state.jerk * dt_squared / 2.f; + end_state.acceleration = start_state.acceleration + start_state.jerk * dt; + end_state.jerk = start_state.jerk; + + return end_state; +} + +} // namespace + +void KalmanPredictor::Reset() { + x_predictor_.Reset(); + y_predictor_.Reset(); + sample_times_.clear(); + last_position_received_ = absl::nullopt; +} + +void KalmanPredictor::Update(Vec2 position, Time time) { + last_position_received_ = position; + sample_times_.push_back(time); + if (predictor_params_.max_time_samples < 0 || + sample_times_.size() > (uint)predictor_params_.max_time_samples) + sample_times_.pop_front(); + + x_predictor_.Update(position.x); + y_predictor_.Update(position.y); +} + +absl::optional<KalmanPredictor::State> KalmanPredictor::GetEstimatedState() + const { + if (!IsStable() || sample_times_.empty()) return absl::nullopt; + + State estimated_state; + estimated_state.position = {static_cast<float>(x_predictor_.GetPosition()), + static_cast<float>(y_predictor_.GetPosition())}; + estimated_state.velocity = {static_cast<float>(x_predictor_.GetVelocity()), + static_cast<float>(y_predictor_.GetVelocity())}; + estimated_state.acceleration = { + static_cast<float>(x_predictor_.GetAcceleration()), + static_cast<float>(y_predictor_.GetAcceleration())}; + estimated_state.jerk = {static_cast<float>(x_predictor_.GetJerk()), + static_cast<float>(y_predictor_.GetJerk())}; + + // The axis predictors are not time-aware, assuming that the time delta + // between measurements is always 1. To correct for this, we divide the + // velocity, acceleration, and jerk by the average observed time delta, raised + // to the appropriate power. + auto dt = static_cast<float>( + (sample_times_.back() - sample_times_.front()).Value()) / + sample_times_.size(); + auto dt_squared = dt * dt; + auto dt_cubed = dt_squared * dt; + estimated_state.velocity /= dt; + estimated_state.acceleration /= dt_squared; + estimated_state.jerk /= dt_cubed; + + // We want our predictions to tend more towards linearity -- to achieve this, + // we reduce the acceleration and jerk. + estimated_state.acceleration *= predictor_params_.acceleration_weight; + estimated_state.jerk *= predictor_params_.jerk_weight; + + return estimated_state; +} + +std::vector<TipState> KalmanPredictor::ConstructPrediction( + const TipState &last_state) const { + auto estimated_state = GetEstimatedState(); + if (!estimated_state || !last_position_received_) { + // We don't yet have enough data to construct a prediction. + return {}; + } + + Duration sample_dt{1. / sampling_params_.min_output_rate}; + std::vector<TipState> prediction; + ConstructCubicConnector(last_state, *estimated_state, predictor_params_, + sample_dt, &prediction); + auto start_time = + prediction.empty() ? last_state.time : prediction.back().time; + ConstructCubicPrediction(*estimated_state, predictor_params_, start_time, + sample_dt, NumberOfPointsToPredict(*estimated_state), + &prediction); + return prediction; +} + +void KalmanPredictor::ConstructCubicPrediction( + const State &estimated_state, const KalmanPredictorParams ¶ms, + Time start_time, Duration sample_dt, int n_samples, + std::vector<TipState> *output) { + auto current_state = estimated_state; + auto current_time = start_time; + for (int i = 0; i < n_samples; ++i) { + auto next_state = EvaluateCubic(current_state, sample_dt); + current_time += sample_dt; + output->push_back({next_state.position, next_state.velocity, current_time}); + current_state = next_state; + } +} + +void KalmanPredictor::ConstructCubicConnector( + const TipState &last_tip_state, const State &estimated_state, + const KalmanPredictorParams ¶ms, Duration sample_dt, + std::vector<TipState> *output) { + // Estimate how long it will take for the tip to travel from its last position + // to the estimated position, based on the start and end velocities. We define + // a minimum "reasonable" velocity to avoid division by zero. + auto distance_traveled = + Distance(last_tip_state.position, estimated_state.position); + auto max_velocity_at_ends = std::max(last_tip_state.velocity.Magnitude(), + estimated_state.velocity.Magnitude()); + Duration target_duration{ + distance_traveled / + std::max(max_velocity_at_ends, params.min_catchup_velocity)}; + + // Determine how many samples this will give us, ensuring that there's always + // at least one. Then, pick a duration that's a multiple of the sample dt. + int n_points = std::max(std::ceil(static_cast<float>(target_duration.Value() / + sample_dt.Value())), + 1.f); + auto duration = n_points * sample_dt; + + // We want to construct a cubic curve connecting the last tip state and the + // estimated state. Given positions p₀ and p₁, velocities v₀ and v₁, and times + // t₀ and t₁ at the start and end of the curve, we define a pair of functions, + // f and g, such that the curve is described by the composite function + // f(g(t)): + // f(x) = ax³ + bx² + cx + d + // g(t) = (t - t₀) / (t₁ - t₀) + // We then find the derivatives: + // f'(x) = 3ax² + 2bx + c + // g'(t) = 1 / (t₁ - t₀) + // (f∘g)'(t) = f'(g(t)) ⋅ g'(t) = (3ax² + 2bx + c) / (t₁ - t₀) + // We then plug in the given values: + // f(g(t₀)) = f(0) = p₀ + // ax³ + bx² + cx + d + // f(g(t₁)) = f(1) = p₁ + // (f∘g)'(t₀) = f'(0) ⋅ g'(t₀) = v₀ + // (f∘g)'(t₁) = f'(1) ⋅ g'(t₁) = v₁ + // This gives us four linear equations: + // a⋅0³ + b⋅0² + c⋅0 + d = p₀ + // a⋅1³ + b⋅1² + c⋅1 + d = p₁ + // (3a⋅0² + 2b⋅0 + c) / (t₁ - t₀) = v₀ + // (3a⋅1² + 2b⋅1 + c) / (t₁ - t₀) = v₁ + // Finally, we can solve for a, b, c, and d: + // a = 2p₀ - 2p₁ + (v₀ + v₁)(t₁ - t₀) + // b = -3p₀ + 3p₁ - (2v₀ + v₁)(t₁ - t₀) + // c = v₀(t₁ - t₀) + // d = p₀ + float float_duration = duration.Value(); + auto a = + 2.f * last_tip_state.position - 2.f * estimated_state.position + + (last_tip_state.velocity + estimated_state.velocity) * float_duration; + auto b = -3.f * last_tip_state.position + 3.f * estimated_state.position - + (2.f * last_tip_state.velocity + estimated_state.velocity) * + float_duration; + auto c = last_tip_state.velocity * float_duration; + auto d = last_tip_state.position; + + output->reserve(output->size() + n_points); + for (int i = 1; i <= n_points; ++i) { + float t = static_cast<float>(i) / n_points; + float t_squared = t * t; + float t_cubed = t_squared * t; + auto position = a * t_cubed + b * t_squared + c * t + d; + auto velocity = 3.f * a * t_squared + 2.f * b * t + c; + auto time = last_tip_state.time + duration * t; + output->push_back({position, velocity / float_duration, time}); + } +} + +int KalmanPredictor::NumberOfPointsToPredict( + const State &estimated_state) const { + const KalmanPredictorParams::ConfidenceParams &confidence_params = + predictor_params_.confidence_params; + + auto target_number = + static_cast<float>(predictor_params_.prediction_interval.Value() * + sampling_params_.min_output_rate); + + // The more samples we've received, the less effect the noise from each + // individual input affects the result. + float sample_ratio = + std::min(1.f, static_cast<float>(x_predictor_.NumIterations()) / + confidence_params.desired_number_of_samples); + + // The further the last given position is from the estimated position, the + // less confidence we have in the result. + float estimated_error = + Distance(*last_position_received_, estimated_state.position); + float normalized_error = + 1.f - Normalize01(0.f, confidence_params.max_estimation_distance, + estimated_error); + + // This is the state that the prediction would end at if we predicted the full + // interval (i.e. if confidence == 1). + auto end_state = + EvaluateCubic(estimated_state, predictor_params_.prediction_interval); + + // If the prediction is not traveling quickly, then changes in direction + // become more apparent, making the prediction appear wobbly. + float travel_speed = + Distance(estimated_state.position, end_state.position) / + static_cast<float>(predictor_params_.prediction_interval.Value()); + float normalized_distance = + Normalize01(confidence_params.min_travel_speed, + confidence_params.max_travel_speed, travel_speed); + + // If the actual prediction differs too much from the linear prediction, it + // suggests that the acceleration and jerk components overtake the velocity, + // resulting in a prediction that flies far off from the stroke. + float deviation_from_linear_prediction = Distance( + end_state.position, + estimated_state.position + + static_cast<float>(predictor_params_.prediction_interval.Value()) * + estimated_state.velocity); + float linearity = + Interp(confidence_params.baseline_linearity_confidence, 1.f, + 1.f - Normalize01(0.f, confidence_params.max_linear_deviation, + deviation_from_linear_prediction)); + + auto confidence = + sample_ratio * normalized_error * normalized_distance * linearity; + return std::ceil(target_number * confidence); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor.h b/ink_stroke_modeler/internal/prediction/kalman_predictor.h new file mode 100644 index 0000000..b848770 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_predictor.h @@ -0,0 +1,103 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_ + +#include <deque> +#include <vector> + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/internal/prediction/kalman_filter/axis_predictor.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class constructs a prediction by using a pair of Kalman filters (one +// each for the x- and y-dimension) to model the true state of the tip, assuming +// that the data we receive contains some noise. +// To construct a prediction, we first fetch the estimation of the position, +// velocity, acceleration, and jerk from the Kalman filters. The prediction is +// then constructed in two parts: one cubic spline that connects the last tip +// state to the estimation, constructed from the positions and velocities at the +// endpoints; and one cubic spline that extends into the future, constructed +// from the estimated position, velocity, acceleration, and jerk. +class KalmanPredictor : public InputPredictor { + public: + explicit KalmanPredictor(const KalmanPredictorParams &predictor_params, + const SamplingParams &sampling_params) + : predictor_params_(predictor_params), + sampling_params_(sampling_params), + x_predictor_(predictor_params_.process_noise, + predictor_params_.measurement_noise, + predictor_params_.min_stable_iteration), + y_predictor_(predictor_params_.process_noise, + predictor_params_.measurement_noise, + predictor_params_.min_stable_iteration) {} + + void Reset() override; + void Update(Vec2 position, Time time) override; + std::vector<TipState> ConstructPrediction( + const TipState &last_state) const override; + + struct State { + Vec2 position{0}; + Vec2 velocity{0}; + Vec2 acceleration{0}; + Vec2 jerk{0}; + }; + + // Returns the current estimate of the tip's true state, as modeled by the + // Kalman filters, or absl::nullopt if the predictor does not yet have enough + // data to make a reasonable estimate. + absl::optional<State> GetEstimatedState() const; + + private: + bool IsStable() const { + return x_predictor_.Stable() && y_predictor_.Stable(); + } + + static void ConstructCubicConnector(const TipState &last_tip_state, + const State &estimated_state, + const KalmanPredictorParams ¶ms, + Duration sample_dt, + std::vector<TipState> *output); + + static void ConstructCubicPrediction(const State &estimated_state, + const KalmanPredictorParams ¶ms, + Time start_time, Duration sample_dt, + int n_samples, + std::vector<TipState> *output); + + int NumberOfPointsToPredict(const State &estimated_state) const; + + KalmanPredictorParams predictor_params_; + SamplingParams sampling_params_; + + absl::optional<Vec2> last_position_received_; + + std::deque<Time> sample_times_; + + AxisPredictor x_predictor_; + AxisPredictor y_predictor_; +}; + +} // namespace stroke_model +} // namespace ink +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_KALMAN_PREDICTOR_H_ diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc new file mode 100644 index 0000000..66bac66 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc @@ -0,0 +1,215 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/types/optional.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::ElementsAre; +using ::testing::Matcher; +using ::testing::Optional; + +constexpr float kTol = 1e-4; + +const KalmanPredictorParams kDefaultKalmanParams{ + .process_noise = .00026458, + .measurement_noise = .026458, + .min_catchup_velocity = .01, + .prediction_interval = Duration(1. / 60), + .confidence_params{.max_estimation_distance = .04, + .min_travel_speed = 3, + .max_travel_speed = 15, + .max_linear_deviation = .2}}; +constexpr SamplingParams kDefaultSamplingParams{ + .min_output_rate = 180, .end_of_stroke_stopping_distance = .001}; + +// Matcher for the State. Note that, because each of position, +// velocity, acceleration, and jerk are divided by increasing powers of the time +// delta, the values grow exponentially. As such, this uses a relative tolerance +// unless one of the arguments is exactly zero. +MATCHER_P5(StateNearMatcher, position, velocity, acceleration, jerk, + relative_tol, "") { + auto within_relative_tol = [](float lhs, float rhs, float tol) { + if (lhs == 0 || rhs == 0) { + return std::abs(lhs - rhs) < tol; + } + return std::abs(lhs / rhs - 1.f) < tol; + }; + + if (within_relative_tol(position.x, arg.position.x, relative_tol) && + within_relative_tol(position.y, arg.position.y, relative_tol) && + within_relative_tol(velocity.x, arg.velocity.x, relative_tol) && + within_relative_tol(velocity.y, arg.velocity.y, relative_tol) && + within_relative_tol(acceleration.x, arg.acceleration.x, relative_tol) && + within_relative_tol(acceleration.y, arg.acceleration.y, relative_tol) && + within_relative_tol(jerk.x, arg.jerk.x, relative_tol) && + within_relative_tol(jerk.y, arg.jerk.y, relative_tol)) { + return true; + } + + *result_listener << "\n expected:" // + << "\n p = " << position // + << "\n v = " << velocity // + << "\n a = " << acceleration // + << "\n j = " << jerk // + << "\n actual:" // + << "\n p = " << arg.position // + << "\n v = " << arg.velocity // + << "\n a = " << arg.acceleration // + << "\n j = " << arg.jerk; + return false; +} + +// Wrapping the matcher in a function allows the compiler to perform template +// deduction, so we can specify arguments as initializer lists. +Matcher<KalmanPredictor::State> StateNear(Vec2 position, Vec2 velocity, + Vec2 acceleration, Vec2 jerk, + float tolerance) { + return StateNearMatcher(position, velocity, acceleration, jerk, tolerance); +} + +TEST(KalmanPredictorTest, EmptyPrediction) { + KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams}; + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty()); + + predictor.Update({1, 3}, Time{4}); + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{1, 3}, {0, 0}, Time{3.1}}).empty()); +} + +TEST(KalmanPredictorTest, TypicalCase) { + KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams}; + + predictor.Update({0, 0}, Time{0}); + predictor.Update({.1, 0}, Time{.01}); + predictor.Update({.2, 0}, Time{.02}); + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty()); + + predictor.Update({.3, 0}, Time{.03}); + EXPECT_THAT(predictor.GetEstimatedState(), + Optional(StateNear({.30078, 0}, {13.584, 0}, {-66.806, 0}, + {-3382.8, 0}, kTol))); + EXPECT_THAT( + predictor.ConstructPrediction({{.2, 0}, {10, 0}, Time{.03}}), + ElementsAre(TipStateNear({{.2454, 0}, {7.7094, 0}, Time{.0356}}, kTol), + TipStateNear({{.3008, 0}, {13.5837, 0}, Time{.0411}}, kTol), + TipStateNear({{.3751, 0}, {13.1604, 0}, Time{.0467}}, kTol))); + + predictor.Update({.5, .1}, Time{.04}); + EXPECT_THAT(predictor.GetEstimatedState(), + Optional(StateNear({.49705, .097146}, {28.217, 16.732}, + {671.91, 813.82}, {4454.3, 6998.2}, kTol))); + EXPECT_THAT( + predictor.ConstructPrediction({{.3, 0}, {10, 0}, Time{.04}}), + ElementsAre( + TipStateNear({{.3732, .0253}, {17.047, 8.9317}, Time{.0456}}, kTol), + TipStateNear({{.497, .0971}, {28.2172, 16.7319}, Time{.0511}}, kTol), + TipStateNear({{.6643, .2029}, {32.0188, 21.3611}, Time{.0567}}, + kTol))); +} + +TEST(KalmanPredictorTest, AlternateParams) { + auto kalman_params = kDefaultKalmanParams; + auto sampling_params = kDefaultSamplingParams; + kalman_params.prediction_interval = Duration(.025); + sampling_params.min_output_rate = 200; + KalmanPredictor predictor{kalman_params, sampling_params}; + + predictor.Update({2, 5}, Time{1}); + predictor.Update({2.2, 4.9}, Time{1.02}); + predictor.Update({2.3, 4.7}, Time{1.04}); + predictor.Update({2.3, 4.4}, Time{1.06}); + EXPECT_THAT( + predictor.GetEstimatedState(), + Optional(StateNear({2.3016, 4.3992}, {-3.9981, -24.374}, + {-338.22, -288.12}, {-1852.9, -584.31}, kTol))); + EXPECT_THAT( + predictor.ConstructPrediction({{2.25, 4.75}, {1, -20}, Time{1.06}}), + ElementsAre( + TipStateNear({{2.27, 4.6417}, {5.917, -23.0547}, Time{1.065}}, kTol), + TipStateNear({{2.2982, 4.5221}, {4.251, -24.5126}, Time{1.07}}, kTol), + TipStateNear({{2.3016, 4.3992}, {-3.9981, -24.3736}, Time{1.075}}, + kTol), + TipStateNear({{2.2773, 4.2738}, {-5.7123, -25.8215}, Time{1.08}}, + kTol))); + + predictor.Update({2.2, 4.2}, Time{1.08}); + EXPECT_THAT(predictor.GetEstimatedState(), + Optional(StateNear({2.1987, 4.1933}, {-11.457, -11.953}, + {-328.01, 185.32}, {-1133.8, 1569.8}, kTol))); + EXPECT_THAT( + predictor.ConstructPrediction({{2.25, 4.5}, {-1, -20}, Time{1.08}}), + ElementsAre( + TipStateNear({{2.2499, 4.407}, {.5082, -17.2661}, Time{1.085}}, kTol), + TipStateNear({{2.2505, 4.3265}, {-.7319, -15.0137}, Time{1.09}}, + kTol), + TipStateNear({{2.238, 4.2561}, {-4.7203, -13.2427}, Time{1.095}}, + kTol), + TipStateNear({{2.1987, 4.1933}, {-11.4569, -11.9531}, Time{1.1}}, + kTol), + TipStateNear({{2.1373, 4.1359}, {-13.1112, -11.0068}, Time{1.105}}, + kTol))); +} + +TEST(KalmanPredictorTest, Reset) { + KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams}; + + predictor.Update({4, -4}, Time{6}); + predictor.Update({-6, 9}, Time{6.03}); + predictor.Update({10, 5}, Time{6.06}); + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty()); + + predictor.Update({2, 4}, Time{6.09}); + EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_FALSE( + predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty()); + + predictor.Reset(); + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.09}}).empty()); + + predictor.Update({-9, 3}, Time{2}); + predictor.Update({-6, -1}, Time{2.1}); + predictor.Update({6, -6}, Time{2.2}); + EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_TRUE( + predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.2}}).empty()); + + predictor.Update({3, 6}, Time{2.3}); + EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt); + EXPECT_FALSE( + predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.3}}).empty()); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc new file mode 100644 index 0000000..3e83720 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.cc @@ -0,0 +1,52 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h" + +#include <iterator> +#include <vector> + +#include "absl/types/optional.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/position_modeler.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +void StrokeEndPredictor::Update(Vec2 position, Time time) { + last_position_ = position; +} + +std::vector<TipState> StrokeEndPredictor::ConstructPrediction( + const TipState &last_state) const { + if (!last_position_) { + // We don't yet have enough data to construct a prediction. + return {}; + } + + std::vector<TipState> prediction; + prediction.reserve(sampling_params_.end_of_stroke_max_iterations); + PositionModeler modeler; + modeler.Reset(last_state, position_modeler_params_); + modeler.ModelEndOfStroke(*last_position_, + Duration(1. / sampling_params_.min_output_rate), + sampling_params_.end_of_stroke_max_iterations, + sampling_params_.end_of_stroke_stopping_distance, + std::back_inserter(prediction)); + return prediction; +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h new file mode 100644 index 0000000..420d596 --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor.h @@ -0,0 +1,58 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_ +#define INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_ + +#include <vector> + +#include "absl/types/optional.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class constructs a prediction using the same PositionModeler class as +// the SpringBasedModeler, fixing the anchor position and allowing the stroke to +// "catch up". The way the prediction is constructed is very similar to how the +// SpringBasedModeler models the end of a stroke. +class StrokeEndPredictor : public InputPredictor { + public: + explicit StrokeEndPredictor( + const PositionModelerParams &position_modeler_params, + const SamplingParams &sampling_params) + : position_modeler_params_(position_modeler_params), + sampling_params_(sampling_params) {} + + void Reset() override { last_position_ = absl::nullopt; } + void Update(Vec2 position, Time time) override; + std::vector<TipState> ConstructPrediction( + const TipState &last_state) const override; + + private: + PositionModelerParams position_modeler_params_; + SamplingParams sampling_params_; + + absl::optional<Vec2> last_position_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_PREDICTION_STROKE_END_PREDICTOR_H_ diff --git a/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc b/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc new file mode 100644 index 0000000..5a7750d --- /dev/null +++ b/ink_stroke_modeler/internal/prediction/stroke_end_predictor_test.cc @@ -0,0 +1,137 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::ElementsAre; +using ::testing::IsEmpty; +using ::testing::Not; + +constexpr float kTol = 1e-4; + +constexpr SamplingParams kDefaultSamplingParams{ + .min_output_rate = 180, + .end_of_stroke_stopping_distance = .001, + .end_of_stroke_max_iterations = 20}; + +TEST(StrokeEndPredictorTest, EmptyPrediction) { + StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams}; + EXPECT_THAT(predictor.ConstructPrediction({{4, 6}, {-1, 1}, Time{5}}), + IsEmpty()); + + predictor.Reset(); + EXPECT_THAT(predictor.ConstructPrediction({{-2, 11}, {0, 0}, Time{1}}), + IsEmpty()); +} + +TEST(StrokeEndPredictorTest, SingleInput) { + StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams}; + predictor.Update({4, 5}, Time{2}); + + EXPECT_THAT(predictor.ConstructPrediction({{4, 5}, {0, 0}, Time{2}}), + IsEmpty()); +} + +TEST(StrokeEndPredictorTest, MultipleInputs) { + StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams}; + + predictor.Update({-1, 1}, Time{1}); + EXPECT_THAT(predictor.ConstructPrediction({{-1, 1}, {0, 0}, Time{1}}), + IsEmpty()); + + predictor.Update({-1, 1.2}, Time{1.02}); + EXPECT_THAT( + predictor.ConstructPrediction({{-1, 1.1}, {0, 5}, Time{1.02}}), + ElementsAre( + TipStateNear({{-1, 1.1258}, {0, 4.6364}, Time{1.0256}}, kTol), + TipStateNear({{-1, 1.1480}, {0, 3.9967}, Time{1.0311}}, kTol), + TipStateNear({{-1, 1.1660}, {0, 3.2496}, Time{1.0367}}, kTol), + TipStateNear({{-1, 1.1799}, {0, 2.5059}, Time{1.0422}}, kTol), + TipStateNear({{-1, 1.1901}, {0, 1.8318}, Time{1.0478}}, kTol), + TipStateNear({{-1, 1.1971}, {0, 1.2609}, Time{1.0533}}, kTol), + TipStateNear({{-1, 1.2000}, {0, 1.0323}, Time{1.0561}}, kTol))); + + predictor.Update({-1, 1.4}, Time{1.04}); + EXPECT_THAT( + predictor.ConstructPrediction({{-1, 1.2}, {0, 5}, Time{1.04}}), + ElementsAre( + TipStateNear({{-1, 1.2348}, {0, 6.2727}, Time{1.0455}}, kTol), + TipStateNear({{-1, 1.2708}, {0, 6.4661}, Time{1.0511}}, kTol), + TipStateNear({{-1, 1.3041}, {0, 5.9943}, Time{1.0566}}, kTol), + TipStateNear({{-1, 1.3328}, {0, 5.1663}, Time{1.0622}}, kTol), + TipStateNear({{-1, 1.3561}, {0, 4.1998}, Time{1.0677}}, kTol), + TipStateNear({{-1, 1.3741}, {0, 3.2381}, Time{1.0733}}, kTol), + TipStateNear({{-1, 1.3872}, {0, 2.3668}, Time{1.0788}}, kTol), + TipStateNear({{-1, 1.3963}, {0, 1.6288}, Time{1.0844}}, kTol), + TipStateNear({{-1, 1.4000}, {0, 1.3333}, Time{1.0872}}, kTol))); +} + +TEST(StrokeEndPredictorTest, Reset) { + StrokeEndPredictor predictor{PositionModelerParams{}, kDefaultSamplingParams}; + + predictor.Update({-9, 6}, Time{5}); + EXPECT_THAT(predictor.ConstructPrediction({{-9, 6}, {0, 0}, Time{5}}), + IsEmpty()); + predictor.Update({1, 4}, Time{7}); + EXPECT_THAT(predictor.ConstructPrediction({{-4, 5}, {5, -1}, Time{7}}), + Not(IsEmpty())); + + predictor.Reset(); + EXPECT_THAT(predictor.ConstructPrediction({{0, 1}, {0, 0}, Time{1}}), + IsEmpty()); +} + +TEST(StrokeEndPredictorTest, AlternateSamplingParams) { + StrokeEndPredictor predictor{ + PositionModelerParams{}, + SamplingParams{.min_output_rate = 200, + .end_of_stroke_stopping_distance = .005}}; + + predictor.Update({4, -7}, Time{3}); + EXPECT_THAT(predictor.ConstructPrediction({{4, -7}, {0, 0}, Time{3}}), + IsEmpty()); + + predictor.Update({4.2, -6.8}, Time{3.01}); + EXPECT_THAT( + predictor.ConstructPrediction({{4.1, -6.9}, {2, 2}, Time{3.01}}), + ElementsAre( + TipStateNear({{4.1138, -6.8862}, {2.7527, 2.7527}, Time{3.015}}, + kTol), + TipStateNear({{4.1289, -6.8711}, {3.0318, 3.0318}, Time{3.02}}, kTol), + TipStateNear({{4.1439, -6.8561}, {2.9871, 2.9871}, Time{3.025}}, + kTol), + TipStateNear({{4.1576, -6.8424}, {2.7386, 2.7386}, Time{3.03}}, kTol), + TipStateNear({{4.1694, -6.8306}, {2.3779, 2.3779}, Time{3.035}}, + kTol), + TipStateNear({{4.1793, -6.8207}, {1.9719, 1.9719}, Time{3.04}}, kTol), + TipStateNear({{4.1871, -6.8129}, {1.5669, 1.5669}, Time{3.045}}, + kTol), + TipStateNear({{4.1931, -6.8069}, {1.1923, 1.1923}, Time{3.05}}, kTol), + TipStateNear({{4.1974, -6.8026}, {0.8647, 0.8647}, Time{3.055}}, + kTol))); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/stylus_state_modeler.cc b/ink_stroke_modeler/internal/stylus_state_modeler.cc new file mode 100644 index 0000000..e7c1bbc --- /dev/null +++ b/ink_stroke_modeler/internal/stylus_state_modeler.cc @@ -0,0 +1,101 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/stylus_state_modeler.h" + +#include <limits> + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/utils.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +void StylusStateModeler::Update(Vec2 position, const StylusState &state) { + if (state.pressure < 0) received_unknown_pressure_ = true; + if (state.tilt < 0) received_unknown_tilt_ = true; + if (state.orientation < 0) received_unknown_orientation_ = true; + + if (received_unknown_pressure_ && received_unknown_tilt_ && + received_unknown_orientation_) { + // We've stopped tracking all fields, so there's no need to keep updating. + positions_and_states_.clear(); + return; + } + + positions_and_states_.push_back({position, state}); + + if (params_.max_input_samples < 0 || + positions_and_states_.size() > (uint)params_.max_input_samples) { + positions_and_states_.pop_front(); + } +} + +void StylusStateModeler::Reset(const StylusStateModelerParams ¶ms) { + params_ = params; + positions_and_states_.clear(); + received_unknown_pressure_ = false; + received_unknown_tilt_ = false; + received_unknown_orientation_ = false; +} + +StylusState StylusStateModeler::Query(Vec2 position) const { + if (positions_and_states_.empty()) + return {.pressure = -1, .tilt = -1, .orientation = -1}; + + if (positions_and_states_.size() == 1) { + const auto &state = positions_and_states_.front().state; + return { + .pressure = received_unknown_pressure_ ? -1 : state.pressure, + .tilt = received_unknown_tilt_ ? -1 : state.tilt, + .orientation = received_unknown_orientation_ ? -1 : state.orientation}; + } + + int closest_segment = -1; + float min_distance = std::numeric_limits<float>::infinity(); + float interp_value = 0; + for (decltype(positions_and_states_.size()) i = 0; + i < positions_and_states_.size() - 1; ++i) { + const Vec2 segment_start = positions_and_states_[i].position; + const Vec2 segment_end = positions_and_states_[i + 1].position; + float param = NearestPointOnSegment(segment_start, segment_end, position); + float distance = + Distance(position, Interp(segment_start, segment_end, param)); + if (distance <= min_distance) { + closest_segment = i; + min_distance = distance; + interp_value = param; + } + } + + auto from_state = positions_and_states_[closest_segment].state; + auto to_state = positions_and_states_[closest_segment + 1].state; + return StylusState{ + .pressure = + received_unknown_pressure_ + ? -1 + : Interp(from_state.pressure, to_state.pressure, interp_value), + .tilt = received_unknown_tilt_ + ? -1 + : Interp(from_state.tilt, to_state.tilt, interp_value), + .orientation = received_unknown_orientation_ + ? -1 + : InterpAngle(from_state.orientation, + to_state.orientation, interp_value)}; +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/stylus_state_modeler.h b/ink_stroke_modeler/internal/stylus_state_modeler.h new file mode 100644 index 0000000..a5b925f --- /dev/null +++ b/ink_stroke_modeler/internal/stylus_state_modeler.h @@ -0,0 +1,82 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_ +#define INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_ + +#include <deque> +#include <ostream> + +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class is used to model the state of the stylus for a given position, +// based on the state of the stylus at the original input points. +// +// The stylus is modeled by storing the last max_input_samples positions and +// states received via Update(); when queried, it treats the stored positions as +// a polyline, and finds the closest segment. The returned stylus state is a +// linear interpolation between the states associated with the endpoints of the +// segment, correcting angles to account for the "wraparound" that occurs at 0 +// and 2π. The value used for interpolation is based on how far along the +// segment the closest point lies. +// +// If Update() is called with a state in which a field (i.e. pressure, tilt, or +// orientation) has a negative value (indicating no information), then the +// results of Query() will be -1 for that field until Reset() is called. This is +// tracked independently for each field; e.g., if you pass in tilt = -1, then +// pressure and orientation will continue to be interpolated normally. +class StylusStateModeler { + public: + // Adds a position and state pair to the model. During stroke modeling, these + // values will be taken from the raw input. + void Update(Vec2 position, const StylusState &state); + + // Clear the model and reset. + void Reset(const StylusStateModelerParams ¶ms); + + // Query the model for the state at the given position. During stroke + // modeling, the position will be taken from the modeled input. + // + // If no Update() calls have been received since the last Reset(), this will + // return {.pressure = -1, .tilt = -1, .orientation = -1}. + StylusState Query(Vec2 position) const; + + private: + struct PositionAndState { + Vec2 position{0}; + StylusState state; + + PositionAndState(Vec2 position_in, const StylusState &state_in) + : position(position_in), state(state_in) {} + }; + + bool received_unknown_pressure_ = false; + bool received_unknown_tilt_ = false; + bool received_unknown_orientation_ = false; + + std::deque<PositionAndState> positions_and_states_; + StylusStateModelerParams params_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_STYLUS_STATE_MODELER_H_ diff --git a/ink_stroke_modeler/internal/stylus_state_modeler_test.cc b/ink_stroke_modeler/internal/stylus_state_modeler_test.cc new file mode 100644 index 0000000..c905ec4 --- /dev/null +++ b/ink_stroke_modeler/internal/stylus_state_modeler_test.cc @@ -0,0 +1,269 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/stylus_state_modeler.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +constexpr float kTol = 1e-5; +constexpr StylusState kUnknown{.pressure = -1, .tilt = -1, .orientation = -1}; + +TEST(StylusStateModelerTest, QueryEmpty) { + StylusStateModeler modeler; + EXPECT_EQ(modeler.Query({0, 0}), kUnknown); + EXPECT_EQ(modeler.Query({-5, 3}), kUnknown); +} + +TEST(StylusStateModelerTest, QuerySingleInput) { + StylusStateModeler modeler; + modeler.Update({0, 0}, {.pressure = 0.75, .tilt = 0.75, .orientation = 0.75}); + EXPECT_THAT(modeler.Query({0, 0}), + StylusStateNear( + {.pressure = .75, .tilt = .75, .orientation = .75}, kTol)); + EXPECT_THAT(modeler.Query({1, 1}), + StylusStateNear( + {.pressure = .75, .tilt = .75, .orientation = .75}, kTol)); +} + +TEST(StylusStateModelerTest, QueryMultipleInputs) { + StylusStateModeler modeler; + modeler.Update({.5, 1.5}, {.pressure = .3, .tilt = .8, .orientation = .1}); + modeler.Update({2, 1.5}, {.pressure = .6, .tilt = .5, .orientation = .7}); + modeler.Update({3, 3.5}, {.pressure = .8, .tilt = .1, .orientation = .3}); + modeler.Update({3.5, 4}, {.pressure = .2, .tilt = .2, .orientation = .2}); + + EXPECT_THAT( + modeler.Query({0, 2}), + StylusStateNear({.pressure = .3, .tilt = .8, .orientation = .1}, kTol)); + EXPECT_THAT( + modeler.Query({1, 2}), + StylusStateNear({.pressure = .4, .tilt = .7, .orientation = .3}, kTol)); + EXPECT_THAT( + modeler.Query({2, 1.5}), + StylusStateNear({.pressure = .6, .tilt = .5, .orientation = .7}, kTol)); + EXPECT_THAT( + modeler.Query({2.5, 1.875}), + StylusStateNear({.pressure = .65, .tilt = .4, .orientation = .6}, kTol)); + EXPECT_THAT( + modeler.Query({2.5, 3.125}), + StylusStateNear({.pressure = .75, .tilt = .2, .orientation = .4}, kTol)); + EXPECT_THAT( + modeler.Query({2.5, 4}), + StylusStateNear({.pressure = .8, .tilt = .1, .orientation = .3}, kTol)); + EXPECT_THAT( + modeler.Query({3, 4}), + StylusStateNear({.pressure = .5, .tilt = .15, .orientation = .25}, kTol)); + EXPECT_THAT( + modeler.Query({4, 4}), + StylusStateNear({.pressure = .2, .tilt = .2, .orientation = .2}, kTol)); +} + +TEST(StylusStateModelerTest, QueryStaleInputsAreDiscarded) { + StylusStateModeler modeler; + modeler.Update({1, 1}, {.pressure = .6, .tilt = .5, .orientation = .4}); + modeler.Update({-1, 2}, {.pressure = .3, .tilt = .7, .orientation = .6}); + modeler.Update({-4, 0}, {.pressure = .9, .tilt = .7, .orientation = .3}); + modeler.Update({-6, -3}, {.pressure = .4, .tilt = .3, .orientation = .5}); + modeler.Update({-5, -5}, {.pressure = .3, .tilt = .3, .orientation = .1}); + modeler.Update({-3, -4}, {.pressure = .6, .tilt = .8, .orientation = .3}); + modeler.Update({-6, -7}, {.pressure = .9, .tilt = .8, .orientation = .1}); + modeler.Update({-9, -8}, {.pressure = .8, .tilt = .2, .orientation = .2}); + modeler.Update({-11, -5}, {.pressure = .2, .tilt = .4, .orientation = .7}); + modeler.Update({-10, -2}, {.pressure = .7, .tilt = .3, .orientation = .2}); + + EXPECT_THAT( + modeler.Query({2, 0}), + StylusStateNear({.pressure = .6, .tilt = .5, .orientation = .4}, kTol)); + EXPECT_THAT( + modeler.Query({1, 3.5}), + StylusStateNear({.pressure = .45, .tilt = .6, .orientation = .5}, kTol)); + EXPECT_THAT( + modeler.Query({-3, 17. / 6}), + StylusStateNear({.pressure = .5, .tilt = .7, .orientation = .5}, kTol)); + + // This causes the point at {1, 1} to be discarded. + modeler.Update({-8, 0}, {.pressure = .6, .tilt = .8, .orientation = .9}); + EXPECT_THAT( + modeler.Query({2, 0}), + StylusStateNear({.pressure = .3, .tilt = .7, .orientation = .6}, kTol)); + EXPECT_THAT( + modeler.Query({1, 3.5}), + StylusStateNear({.pressure = .3, .tilt = .7, .orientation = .6}, kTol)); + EXPECT_THAT( + modeler.Query({-3, 17. / 6}), + StylusStateNear({.pressure = .5, .tilt = .7, .orientation = .5}, kTol)); + + // This causes the point at {-1, 2} to be discarded. + modeler.Update({-8, 0}, {.pressure = .6, .tilt = .8, .orientation = .9}); + EXPECT_THAT( + modeler.Query({2, 0}), + StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol)); + EXPECT_THAT( + modeler.Query({1, 3.5}), + StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol)); + EXPECT_THAT( + modeler.Query({-3, 17. / 6}), + StylusStateNear({.pressure = .9, .tilt = .7, .orientation = .3}, kTol)); +} + +TEST(StylusStateModelerTest, QueryCyclicOrientationInterpolation) { + StylusStateModeler modeler; + modeler.Update({0, 0}, {.pressure = 0, .tilt = 0, .orientation = 1.8 * M_PI}); + modeler.Update({0, 1}, {.pressure = 0, .tilt = 0, .orientation = .2 * M_PI}); + modeler.Update({0, 2}, {.pressure = 0, .tilt = 0, .orientation = 1.6 * M_PI}); + + EXPECT_NEAR(modeler.Query({0, .25}).orientation, 1.9 * M_PI, 1e-5); + EXPECT_NEAR(modeler.Query({0, .75}).orientation, .1 * M_PI, 1e-5); + EXPECT_NEAR(modeler.Query({0, 1.25}).orientation, .05 * M_PI, 1e-5); + EXPECT_NEAR(modeler.Query({0, 1.75}).orientation, 1.75 * M_PI, 1e-5); +} + +TEST(StylusStateModelerTest, QueryAndReset) { + StylusStateModeler modeler; + + modeler.Update({4, 5}, {.pressure = .4, .tilt = .9, .orientation = .1}); + modeler.Update({7, 8}, {.pressure = .1, .tilt = .2, .orientation = .5}); + EXPECT_THAT( + modeler.Query({10, 12}), + StylusStateNear({.pressure = .1, .tilt = .2, .orientation = .5}, kTol)); + + modeler.Reset(StylusStateModelerParams{}); + EXPECT_EQ(modeler.Query({10, 12}), kUnknown); + + modeler.Update({-1, 4}, {.pressure = .4, .tilt = .6, .orientation = .8}); + EXPECT_THAT( + modeler.Query({6, 7}), + StylusStateNear({.pressure = .4, .tilt = .6, .orientation = .8}, kTol)); + + modeler.Update({-3, 0}, {.pressure = .7, .tilt = .2, .orientation = .5}); + EXPECT_THAT( + modeler.Query({-2, 2}), + StylusStateNear({.pressure = .55, .tilt = .4, .orientation = .65}, kTol)); + EXPECT_THAT( + modeler.Query({0, 5}), + StylusStateNear({.pressure = .4, .tilt = .6, .orientation = .8}, kTol)); +} + +TEST(StylusStateModelerTest, UpdateWithUnknownState) { + StylusStateModeler modeler; + + modeler.Update({1, 2}, {.pressure = .1, .tilt = .2, .orientation = .3}); + modeler.Update({2, 3}, {.pressure = .3, .tilt = .4, .orientation = .5}); + EXPECT_THAT( + modeler.Query({2, 2}), + StylusStateNear({.pressure = .2, .tilt = .3, .orientation = .4}, kTol)); + + modeler.Update({5, 5}, kUnknown); + EXPECT_EQ(modeler.Query({5, 5}), kUnknown); + + modeler.Update({2, 3}, {.pressure = .3, .tilt = .4, .orientation = .5}); + EXPECT_EQ(modeler.Query({1, 2}), kUnknown); + + modeler.Update({-1, 3}, kUnknown); + EXPECT_EQ(modeler.Query({7, 9}), kUnknown); + + modeler.Reset(StylusStateModelerParams{}); + modeler.Update({3, 3}, {.pressure = .7, .tilt = .6, .orientation = .5}); + EXPECT_THAT( + modeler.Query({3, 3}), + StylusStateNear({.pressure = .7, .tilt = .6, .orientation = .5}, kTol)); +} + +TEST(StylusStateModelerTest, ModelPressureOnly) { + StylusStateModeler modeler; + + modeler.Update({0, 0}, {.pressure = .5, .tilt = -2, .orientation = -.1}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = .5, .tilt = -1, .orientation = -1}, kTol)); + + modeler.Update({2, 0}, {.pressure = .7, .tilt = -2, .orientation = -.1}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = .6, .tilt = -1, .orientation = -1}, kTol)); +} + +TEST(StylusStateModelerTest, ModelTiltOnly) { + StylusStateModeler modeler; + + modeler.Update({0, 0}, {.pressure = -2, .tilt = .5, .orientation = -.1}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = -1, .tilt = .5, .orientation = -1}, kTol)); + + modeler.Update({2, 0}, {.pressure = -2, .tilt = .3, .orientation = -.1}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = -1, .tilt = .4, .orientation = -1}, kTol)); +} + +TEST(StylusStateModelerTest, ModelOrientationOnly) { + StylusStateModeler modeler; + + modeler.Update({0, 0}, {.pressure = -2, .tilt = -.1, .orientation = 1}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = -1, .tilt = -1, .orientation = 1}, kTol)); + + modeler.Update({2, 0}, {.pressure = -2, .tilt = -.3, .orientation = 2}); + EXPECT_THAT( + modeler.Query({1, 1}), + StylusStateNear({.pressure = -1, .tilt = -1, .orientation = 1.5}, kTol)); +} + +TEST(StylusStateModelerTest, DropFieldsOneByOne) { + StylusStateModeler modeler; + + modeler.Update({0, 0}, {.pressure = .5, .tilt = .5, .orientation = .5}); + EXPECT_THAT( + modeler.Query({1, 0}), + StylusStateNear({.pressure = .5, .tilt = .5, .orientation = .5}, kTol)); + + modeler.Update({2, 0}, {.pressure = .3, .tilt = .7, .orientation = -1}); + EXPECT_THAT( + modeler.Query({1, 0}), + StylusStateNear({.pressure = .4, .tilt = .6, .orientation = -1}, kTol)); + + modeler.Update({4, 0}, {.pressure = .1, .tilt = -1, .orientation = 1}); + EXPECT_THAT( + modeler.Query({3, 0}), + StylusStateNear({.pressure = .2, .tilt = -1, .orientation = -1}, kTol)); + + modeler.Update({6, 0}, {.pressure = -1, .tilt = .2, .orientation = 0}); + EXPECT_THAT(modeler.Query({5, 0}), StylusStateNear(kUnknown, kTol)); + + modeler.Update({8, 0}, {.pressure = .3, .tilt = .4, .orientation = .5}); + EXPECT_THAT(modeler.Query({7, 0}), StylusStateNear(kUnknown, kTol)); + + modeler.Reset(StylusStateModelerParams{}); + EXPECT_THAT(modeler.Query({1, 0}), StylusStateNear(kUnknown, kTol)); + + modeler.Update({0, 0}, {.pressure = .1, .tilt = .8, .orientation = .3}); + EXPECT_THAT( + modeler.Query({1, 0}), + StylusStateNear({.pressure = .1, .tilt = .8, .orientation = .3}, kTol)); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/type_matchers.cc b/ink_stroke_modeler/internal/type_matchers.cc new file mode 100644 index 0000000..f974b61 --- /dev/null +++ b/ink_stroke_modeler/internal/type_matchers.cc @@ -0,0 +1,69 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/type_matchers.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::DoubleNear; +using ::testing::FloatEq; +using ::testing::FloatNear; +using ::testing::Matcher; +using ::testing::Matches; + +MATCHER_P(Vec2EqMatcher, expected, "") { + return Matches(FloatEq(expected.x))(arg.x) && + Matches(FloatEq(expected.y))(arg.y); +} + +MATCHER_P2(Vec2NearMatcher, expected, tolerance, "") { + return Matches(FloatNear(expected.x, tolerance))(arg.x) && + Matches(FloatNear(expected.y, tolerance))(arg.y); +} +MATCHER_P2(TipStateNearMatcher, expected, tolerance, "") { + return Matches(Vec2Near(expected.position, tolerance))(arg.position) && + Matches(Vec2Near(expected.velocity, tolerance))(arg.velocity) && + Matches(DoubleNear(expected.time.Value(), tolerance))( + arg.time.Value()); +} + +MATCHER_P2(StylusStateNearMatcher, expected, tolerance, "") { + return Matches(FloatNear(expected.pressure, tolerance))(arg.pressure) && + Matches(FloatNear(expected.tilt, tolerance))(arg.tilt) && + Matches(FloatNear(expected.orientation, tolerance))(arg.orientation); +} + +} // namespace + +Matcher<Vec2> Vec2Eq(const Vec2 v) { return Vec2EqMatcher(v); } +Matcher<Vec2> Vec2Near(const Vec2 v, float tolerance) { + return Vec2NearMatcher(v, tolerance); +} +Matcher<TipState> TipStateNear(const TipState &expected, float tolerance) { + return TipStateNearMatcher(expected, tolerance); +} +Matcher<StylusState> StylusStateNear(const StylusState &expected, + float tolerance) { + return StylusStateNearMatcher(expected, tolerance); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/type_matchers.h b/ink_stroke_modeler/internal/type_matchers.h new file mode 100644 index 0000000..97f8876 --- /dev/null +++ b/ink_stroke_modeler/internal/type_matchers.h @@ -0,0 +1,42 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_ +#define INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_ + +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// These matchers compare Vec2s component-wise, delegating to +// ::testing::FloatEq() and ::testing::FloatNear(), respectively. +::testing::Matcher<Vec2> Vec2Eq(const Vec2 v); +::testing::Matcher<Vec2> Vec2Near(const Vec2 v, float tolerance); + +// These convenience matchers perform comparisons using ::testing::FloatNear(), +// ::testing::DoubleNear(), and Vec2Near(). +::testing::Matcher<TipState> TipStateNear(const TipState &expected, + float tolerance); +::testing::Matcher<StylusState> StylusStateNear(const StylusState &expected, + float tolerance); + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_TYPE_MATCHERS_H_ diff --git a/ink_stroke_modeler/internal/utils.h b/ink_stroke_modeler/internal/utils.h new file mode 100644 index 0000000..6bf2a27 --- /dev/null +++ b/ink_stroke_modeler/internal/utils.h @@ -0,0 +1,99 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_UTILS_H_ +#define INK_STROKE_MODELER_INTERNAL_UTILS_H_ + +#include <cmath> + +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// General utility functions for use within the stroke model. + +// Clamps the given value to the range [0, 1]. +inline float Clamp01(float value) { + if (value < 0.f) return 0.f; + if (value > 1.f) return 1.f; + return value; +} + +// Returns the ratio of the difference from `start` to `value` and the +// difference from `start` to `end`, clamped to the range [0, 1]. If +// `start` == `end`, returns 1 if `value` > `start`, 0 otherwise. +inline float Normalize01(float start, float end, float value) { + if (start == end) { + return value > start ? 1 : 0; + } + return Clamp01((value - start) / (end - start)); +} + +// Linearly interpolates between `start` and `end`, clamping the interpolation +// value to the range [0, 1]. +template <typename ValueType> +inline ValueType Interp(ValueType start, ValueType end, float interp_amount) { + return start + (end - start) * Clamp01(interp_amount); +} + +// Linearly interpolates from `start` to `end`, traveling around the shorter +// path (e.g. interpolating from π/4 to 7π/4 is equivalent to interpolating from +// π/4 to 0, then 2π to 7π/4). The returned angle will be normalized to the +// interval [0, 2π). All angles are measured in radians. +inline float InterpAngle(float start, float end, float interp_amount) { + auto normalize_angle = [](float angle) { + while (angle < 0) angle += 2 * M_PI; + while (angle > 2 * M_PI) angle -= 2 * M_PI; + return angle; + }; + + start = normalize_angle(start); + end = normalize_angle(end); + float delta = end - start; + if (delta < -M_PI) { + end += 2 * M_PI; + } else if (delta > M_PI) { + end -= 2 * M_PI; + } + return normalize_angle(Interp(start, end, interp_amount)); +} + +// Returns the distance between two points. +inline float Distance(Vec2 start, Vec2 end) { + return (end - start).Magnitude(); +} + +// Returns the point on the line segment from `segment_start` to `segment_end` +// that is closest to `point`, represented as the ratio of the length along the +// segment. +inline float NearestPointOnSegment(Vec2 segment_start, Vec2 segment_end, + Vec2 point) { + if (segment_start == segment_end) return 0; + + auto dot_product = [](Vec2 lhs, Vec2 rhs) { + return lhs.x * rhs.x + lhs.y * rhs.y; + }; + Vec2 segment_vector = segment_end - segment_start; + Vec2 projection_vector = point - segment_start; + return Clamp01(dot_product(projection_vector, segment_vector) / + dot_product(segment_vector, segment_vector)); +} + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_UTILS_H_ diff --git a/ink_stroke_modeler/internal/utils_test.cc b/ink_stroke_modeler/internal/utils_test.cc new file mode 100644 index 0000000..b50d21a --- /dev/null +++ b/ink_stroke_modeler/internal/utils_test.cc @@ -0,0 +1,87 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/utils.h" + +#include <cmath> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/type_matchers.h" + +namespace ink { +namespace stroke_model { +namespace { + +TEST(UtilsTest, Clamp01) { + EXPECT_FLOAT_EQ(Clamp01(-2), 0); + EXPECT_FLOAT_EQ(Clamp01(0), 0); + EXPECT_FLOAT_EQ(Clamp01(.3), .3); + EXPECT_FLOAT_EQ(Clamp01(.7), .7); + EXPECT_FLOAT_EQ(Clamp01(1), 1); + EXPECT_FLOAT_EQ(Clamp01(1.1), 1); +} + +TEST(UtilsTest, Normalize01) { + EXPECT_FLOAT_EQ(Normalize01(1, 2, 1.5), .5); + EXPECT_FLOAT_EQ(Normalize01(7, 3, 4), .75); + EXPECT_FLOAT_EQ(Normalize01(-1, 1, 2), 1); + EXPECT_FLOAT_EQ(Normalize01(1, 1, 1), 0); + EXPECT_FLOAT_EQ(Normalize01(1, 1, 0), 0); + EXPECT_FLOAT_EQ(Normalize01(1, 1, 2), 1); +} + +TEST(UtilsTest, InterpFloat) { + EXPECT_FLOAT_EQ(Interp(5, 10, .2), 6); + EXPECT_FLOAT_EQ(Interp(10, -2, .75), 1); + EXPECT_FLOAT_EQ(Interp(-1, 2, -3), -1); + EXPECT_FLOAT_EQ(Interp(5, 7, 20), 7); +} + +TEST(UtilsTest, InterpVec2) { + EXPECT_THAT(Interp(Vec2{1, 2}, {3, 5}, .5), Vec2Eq({2, 3.5})); + EXPECT_THAT(Interp(Vec2{-5, 5}, {-15, 0}, .4), Vec2Eq({-9, 3})); + EXPECT_THAT(Interp(Vec2{7, 9}, {25, 30}, -.1), Vec2Eq({7, 9})); + EXPECT_THAT(Interp(Vec2{12, 5}, {13, 14}, 3.2), Vec2Eq({13, 14})); +} + +TEST(UtilsTest, InterpAngle) { + EXPECT_NEAR(InterpAngle(.25 * M_PI, .5 * M_PI, .4), .35 * M_PI, 1e-6); + EXPECT_NEAR(InterpAngle(1.05 * M_PI, .25 * M_PI, .5), .65 * M_PI, 1e-6); + EXPECT_NEAR(InterpAngle(.25 * M_PI, 1.75 * M_PI, .1), .2 * M_PI, 1e-6); + EXPECT_NEAR(InterpAngle(.25 * M_PI, 1.75 * M_PI, .7), 1.9 * M_PI, 1e-6); + EXPECT_NEAR(InterpAngle(1.6 * M_PI, .4 * M_PI, .25), 1.8 * M_PI, 1e-6); + EXPECT_NEAR(InterpAngle(1.6 * M_PI, .4 * M_PI, .625), .1 * M_PI, 1e-6); +} + +TEST(UtilsTest, Distance) { + EXPECT_FLOAT_EQ(Distance({0, 0}, {1, 0}), 1); + EXPECT_FLOAT_EQ(Distance({1, 1}, {-2, 5}), 5); +} + +TEST(UtilsTest, NearestPointOnSegment) { + EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 0}, {1, 0}, {.25, .5}), .25); + EXPECT_FLOAT_EQ(NearestPointOnSegment({3, 4}, {5, 6}, {-1, -1}), 0); + EXPECT_FLOAT_EQ(NearestPointOnSegment({20, 10}, {10, 5}, {2, 2}), 1); + EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 5}, {5, 0}, {3, 3}), .5); +} + +TEST(UtilsTest, NearestPointOnSegmentDegenerateCase) { + EXPECT_FLOAT_EQ(NearestPointOnSegment({0, 0}, {0, 0}, {5, 10}), 0); + EXPECT_FLOAT_EQ(NearestPointOnSegment({3, 7}, {3, 7}, {0, -20}), 0); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/validation.h b/ink_stroke_modeler/internal/validation.h new file mode 100644 index 0000000..e01aa1a --- /dev/null +++ b/ink_stroke_modeler/internal/validation.h @@ -0,0 +1,48 @@ +#ifndef INK_STROKE_MODELER_INTERNAL_VALIDATION_H_ +#define INK_STROKE_MODELER_INTERNAL_VALIDATION_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" + +template <typename T> +absl::Status ValidateIsFiniteNumber(T value, absl::string_view label) { + if (std::isnan(value)) { + return absl::InvalidArgumentError(absl::Substitute("$0 is NaN", label)); + } + if (std::isinf(value)) { + return absl::InvalidArgumentError( + absl::Substitute("$0 is infinite", label)); + } + return absl::OkStatus(); +} + +template <typename T> +absl::Status ValidateGreaterThanZero(T value, absl::string_view label) { + if (absl::Status status = ValidateIsFiniteNumber(value, label); + !status.ok()) { + return status; + } + if (value <= 0) { + return absl::InvalidArgumentError(absl::Substitute( + "$0 must be greater than zero. Actual value: $1", label, value)); + } + return absl::OkStatus(); +} + +template <typename T> +absl::Status ValidateGreaterThanOrEqualToZero(T value, + absl::string_view label) { + if (absl::Status status = ValidateIsFiniteNumber(value, label); + !status.ok()) { + return status; + } + if (value < 0) { + return absl::InvalidArgumentError(absl::Substitute( + "$0 must be greater than or equal to zero. Actual value: $1", label, + value)); + } + return absl::OkStatus(); +} + +#endif // INK_STROKE_MODELER_INTERNAL_VALIDATION_H_ diff --git a/ink_stroke_modeler/internal/validation_test.cc b/ink_stroke_modeler/internal/validation_test.cc new file mode 100644 index 0000000..51482ff --- /dev/null +++ b/ink_stroke_modeler/internal/validation_test.cc @@ -0,0 +1,82 @@ +#include "ink_stroke_modeler/internal/validation.h" + +#include <cmath> + +#include "gtest/gtest.h" + +namespace { + +TEST(ValidateIsFiniteNumberTest, AcceptFinite) { + ASSERT_TRUE(ValidateIsFiniteNumber(1, "foo").ok()); +} + +TEST(ValidateIsFiniteNumberTest, RejectNan) { + absl::Status status = ValidateIsFiniteNumber(NAN, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is NaN"); +} + +TEST(ValidateIsFiniteNumberTest, RejectInf) { + absl::Status status = ValidateIsFiniteNumber(INFINITY, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is infinite"); +} + +TEST(ValidateGreaterThanZeroTest, AcceptPositive) { + ASSERT_TRUE(ValidateGreaterThanZero(1, "foo").ok()); +} + +TEST(ValidateGreaterThanZeroTest, RejectZero) { + absl::Status status = ValidateGreaterThanZero(0, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo must be greater than zero. Actual value: 0"); +} + +TEST(ValidateGreaterThanZeroTest, RejectNegative) { + absl::Status status = ValidateGreaterThanZero(-1, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), + "foo must be greater than zero. Actual value: -1"); +} + +TEST(ValidateGreaterThanZeroTest, RejectNan) { + absl::Status status = ValidateGreaterThanZero(NAN, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is NaN"); +} + +TEST(ValidateGreaterThanZeroTest, RejectInf) { + absl::Status status = ValidateGreaterThanZero(INFINITY, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is infinite"); +} + +TEST(ValidateGreaterThanOrEqualToZeroTest, AcceptPositive) { + ASSERT_TRUE(ValidateGreaterThanOrEqualToZero(1, "foo").ok()); +} + +TEST(ValidateGreaterThanOrEqualToZeroTest, AcceptZero) { + absl::Status status = ValidateGreaterThanOrEqualToZero(0, "foo"); + ASSERT_TRUE(ValidateGreaterThanOrEqualToZero(0, "foo").ok()); +} + +TEST(ValidateGreaterThanOrEqualToZeroTest, RejectNegative) { + absl::Status status = ValidateGreaterThanOrEqualToZero(-1, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), + "foo must be greater than or equal to zero. Actual value: -1"); +} + +TEST(ValidateGreaterThanOrEqualToZeroTest, RejectNan) { + absl::Status status = ValidateGreaterThanOrEqualToZero(NAN, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is NaN"); +} + +TEST(ValidateGreaterThanOrEqualToZeroTest, RejectInf) { + absl::Status status = ValidateGreaterThanOrEqualToZero(INFINITY, "foo"); + ASSERT_FALSE(status.ok()); + ASSERT_EQ(status.message(), "foo is infinite"); +} + +} // namespace diff --git a/ink_stroke_modeler/internal/wobble_smoother.cc b/ink_stroke_modeler/internal/wobble_smoother.cc new file mode 100644 index 0000000..a403ecb --- /dev/null +++ b/ink_stroke_modeler/internal/wobble_smoother.cc @@ -0,0 +1,70 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/wobble_smoother.h" + +#include <algorithm> + +#include "ink_stroke_modeler/internal/utils.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +void WobbleSmoother::Reset(const WobbleSmootherParams& params, Vec2 position, + Time time) { + params_ = params; + samples_.clear(); + // Initialize with the "fast" speed -- otherwise, we'll lag behind at the + // start of the stroke. + position_sum_ = position; + speed_sum_ = params_.speed_ceiling; + samples_.push_back( + {.position = position, .speed = params_.speed_ceiling, .time = time}); +} + +Vec2 WobbleSmoother::Update(Vec2 position, Time time) { + // The moving average acts as a low-pass signal filter, removing + // high-frequency fluctuations in the position caused by the discrete nature + // of the touch digitizer. To compensate for the distance between the average + // position and the actual position, we interpolate between them, based on + // speed, to determine the position to use for the input model. + float distance = Distance(position, samples_.back().position); + Duration delta_time = time - samples_.back().time; + float speed = 0; + if (delta_time == Duration(0)) { + // We're going to assume that you're not actually moving infinitely fast. + speed = std::max(params_.speed_ceiling, speed_sum_ / samples_.size()); + } else { + speed = distance / delta_time.Value(); + } + + samples_.push_back({.position = position, .speed = speed, .time = time}); + position_sum_ += position; + speed_sum_ += speed; + while (samples_.front().time < time - params_.timeout) { + position_sum_ -= samples_.front().position; + speed_sum_ -= samples_.front().speed; + samples_.pop_front(); + } + + Vec2 avg_position = position_sum_ / samples_.size(); + float avg_speed = speed_sum_ / samples_.size(); + return Interp( + avg_position, position, + Normalize01(params_.speed_floor, params_.speed_ceiling, avg_speed)); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/internal/wobble_smoother.h b/ink_stroke_modeler/internal/wobble_smoother.h new file mode 100644 index 0000000..7eb85b8 --- /dev/null +++ b/ink_stroke_modeler/internal/wobble_smoother.h @@ -0,0 +1,57 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_ +#define INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_ + +#include <deque> + +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class smooths "wobble" in input positions from high-frequency noise. It +// does so by maintaining a moving average of the positions, and interpolating +// between the given input and the moving average based on how quickly it's +// moving. When moving at a speed above the ceiling in the WobbleSmootherParams, +// the result will be the unmodified input; when moving at a speed below the +// floor, the result will be the moving average. +class WobbleSmoother { + public: + void Reset(const WobbleSmootherParams ¶ms, Vec2 position, Time time); + + // Updates the average position and speed, and returns the smoothed position. + Vec2 Update(Vec2 position, Time time); + + private: + struct Sample { + Vec2 position{0}; + float speed{0}; + Time time{0}; + }; + std::deque<Sample> samples_; + Vec2 position_sum_{0}; + float speed_sum_{0}; + + WobbleSmootherParams params_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_INTERNAL_WOBBLE_SMOOTHER_H_ diff --git a/ink_stroke_modeler/internal/wobble_smoother_test.cc b/ink_stroke_modeler/internal/wobble_smoother_test.cc new file mode 100644 index 0000000..7bffd37 --- /dev/null +++ b/ink_stroke_modeler/internal/wobble_smoother_test.cc @@ -0,0 +1,104 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/internal/wobble_smoother.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +const WobbleSmootherParams kDefaultParams{ + .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.44}; + +TEST(WobbleSmootherTest, SlowStraightLine) { + // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s. + WobbleSmoother filter; + filter.Reset(kDefaultParams, {3, 4}, Time{1}); + EXPECT_THAT(filter.Update({3.016, 4}, Time{1.016}), Vec2Eq({3.008, 4})); + EXPECT_THAT(filter.Update({3.032, 4}, Time{1.032}), Vec2Eq({3.016, 4})); + EXPECT_THAT(filter.Update({3.048, 4}, Time{1.048}), Vec2Eq({3.032, 4})); + EXPECT_THAT(filter.Update({3.064, 4}, Time{1.064}), Vec2Eq({3.048, 4})); +} + +TEST(WobbleSmootherTest, SlowStraightLineEqualFloorAndCeiling) { + // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s. + WobbleSmootherParams equal_floor_and_ceiling_params{ + .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.31}; + WobbleSmoother filter; + filter.Reset(equal_floor_and_ceiling_params, {3, 4}, Time{1}); + EXPECT_THAT(filter.Update({3.016, 4}, Time{1.016}), Vec2Eq({3.008, 4})); + EXPECT_THAT(filter.Update({3.032, 4}, Time{1.032}), Vec2Eq({3.016, 4})); + EXPECT_THAT(filter.Update({3.048, 4}, Time{1.048}), Vec2Eq({3.032, 4})); + EXPECT_THAT(filter.Update({3.064, 4}, Time{1.064}), Vec2Eq({3.048, 4})); +} + +TEST(WobbleSmootherTest, FastStraightLine) { + // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s. + WobbleSmoother filter; + filter.Reset(kDefaultParams, {-1, 0}, Time{0}); + EXPECT_THAT(filter.Update({-1, .024}, Time{.016}), Vec2Eq({-1, .024})); + EXPECT_THAT(filter.Update({-1, .048}, Time{.032}), Vec2Eq({-1, .048})); + EXPECT_THAT(filter.Update({-1, .072}, Time{.048}), Vec2Eq({-1, .072})); +} + +TEST(WobbleSmootherTest, FastStraightLineEqualFloorAndCeiling) { + // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s. + WobbleSmoother filter; + WobbleSmootherParams equal_floor_and_ceiling_params{ + .timeout = Duration(.04), .speed_floor = 1.41, .speed_ceiling = 1.41}; + filter.Reset(equal_floor_and_ceiling_params, {-1, 0}, Time{0}); + EXPECT_THAT(filter.Update({-1, .024}, Time{.016}), Vec2Eq({-1, .024})); + EXPECT_THAT(filter.Update({-1, .048}, Time{.032}), Vec2Eq({-1, .048})); + EXPECT_THAT(filter.Update({-1, .072}, Time{.048}), Vec2Eq({-1, .072})); +} + +TEST(WobbleSmootherTest, SlowZigZag) { + // The line moves at 1 cm/s, which is below the floor of 1.31 cm/s. + WobbleSmoother filter; + filter.Reset(kDefaultParams, {1, 2}, Time{5}); + EXPECT_THAT(filter.Update({1.016, 2}, Time{5.016}), Vec2Eq({1.008, 2})); + EXPECT_THAT(filter.Update({1.016, 2.016}, Time{5.032}), + Vec2Eq({1.0106667, 2.0053333})); + EXPECT_THAT(filter.Update({1.032, 2.016}, Time{5.048}), + Vec2Eq({1.0213333, 2.0106667})); + EXPECT_THAT(filter.Update({1.032, 2.032}, Time{5.064}), + Vec2Eq({1.0266667, 2.0213333})); + EXPECT_THAT(filter.Update({1.048, 2.032}, Time{5.080}), + Vec2Eq({1.0373333, 2.0266667})); + EXPECT_THAT(filter.Update({1.048, 2.048}, Time{5.096}), + Vec2Eq({1.0426667, 2.0373333})); +} + +TEST(WobbleSmootherTest, FastZigZag) { + // The line moves at 1.5 cm/s, which is above the ceiling of 1.44 cm/s. + WobbleSmoother filter; + filter.Reset(kDefaultParams, {7, 3}, Time{8}); + EXPECT_THAT(filter.Update({7, 3.024}, Time{8.016}), Vec2Eq({7, 3.024})); + EXPECT_THAT(filter.Update({7.024, 3.024}, Time{8.032}), + Vec2Eq({7.024, 3.024})); + EXPECT_THAT(filter.Update({7.024, 3.048}, Time{8.048}), + Vec2Eq({7.024, 3.048})); + EXPECT_THAT(filter.Update({7.048, 3.048}, Time{8.064}), + Vec2Eq({7.048, 3.048})); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/params.cc b/ink_stroke_modeler/params.cc new file mode 100644 index 0000000..99542e1 --- /dev/null +++ b/ink_stroke_modeler/params.cc @@ -0,0 +1,157 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/params.h" + +#include <cmath> + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" +#include "absl/types/variant.h" +#include "ink_stroke_modeler/internal/validation.h" + +// This convenience macro evaluates the given expression, and if it does not +// return an OK status, returns and propagates the status. +#define RETURN_IF_ERROR(expr) \ + do { \ + if (auto status = (expr); !status.ok()) return status; \ + } while (false) + +namespace ink { +namespace stroke_model { + +absl::Status ValidatePositionModelerParams( + const PositionModelerParams& params) { + RETURN_IF_ERROR(ValidateGreaterThanZero(params.spring_mass_constant, + "PredictionParams::spring_mass")); + return ValidateGreaterThanZero(params.drag_constant, + "PredictionParams::drag_ratio"); +} + +absl::Status ValidateSamplingParams(const SamplingParams& params) { + RETURN_IF_ERROR(ValidateGreaterThanZero(params.min_output_rate, + "PredictionParams::min_output_rate")); + RETURN_IF_ERROR(ValidateGreaterThanZero( + params.end_of_stroke_stopping_distance, + "PredictionParams::end_of_stroke_stopping_distance")); + return ValidateGreaterThanZero( + params.end_of_stroke_max_iterations, + "PredictionParams::end_of_stroke_stopping_distance"); +} + +absl::Status ValidateStylusStateModelerParams( + const StylusStateModelerParams& params) { + return ValidateGreaterThanZero(params.max_input_samples, + "StylusStateModelerParams::max_input_samples"); +} + +absl::Status ValidateWobbleSmootherParams(const WobbleSmootherParams& params) { + RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero( + params.timeout.Value(), "WobbleSmootherParams::timeout")); + RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero( + params.speed_floor, "WobbleSmootherParams::speed_floor")); + RETURN_IF_ERROR(ValidateIsFiniteNumber( + params.speed_ceiling, "WobbleSmootherParams::speed_ceiling")); + if (params.speed_ceiling < params.speed_floor) { + return absl::InvalidArgumentError(absl::Substitute( + "WobbleSmootherParams::speed_ceiling must be greater than or " + "equal to WobbleSmootherParams::speed_floor ($0). Actual " + "value: $1", + params.speed_floor, params.speed_ceiling)); + } + return absl::OkStatus(); +} + +absl::Status ValidatePredictionParams(const PredictionParams& params) { + if (absl::holds_alternative<StrokeEndPredictorParams>(params)) { + // Nothing to validate. + return absl::OkStatus(); + } + + const KalmanPredictorParams& kalman_params = + absl::get<KalmanPredictorParams>(params); + RETURN_IF_ERROR(ValidateGreaterThanZero( + kalman_params.process_noise, "KalmanPredictorParams::process_noise")); + RETURN_IF_ERROR( + ValidateGreaterThanZero(kalman_params.measurement_noise, + "KalmanPredictorParams::measurement_noise")); + RETURN_IF_ERROR( + ValidateGreaterThanZero(kalman_params.min_stable_iteration, + "KalmanPredictorParams::min_stable_iteration")); + RETURN_IF_ERROR( + ValidateGreaterThanZero(kalman_params.max_time_samples, + "KalmanPredictorParams::max_time_samples")); + RETURN_IF_ERROR( + ValidateGreaterThanZero(kalman_params.min_catchup_velocity, + "KalmanPredictorParams::min_catchup_velocity")); + RETURN_IF_ERROR( + ValidateIsFiniteNumber(kalman_params.acceleration_weight, + "KalmanPredictorParams::acceleration_weight")); + RETURN_IF_ERROR(ValidateIsFiniteNumber(kalman_params.jerk_weight, + "KalmanPredictorParams::jerk_weight")); + RETURN_IF_ERROR( + ValidateGreaterThanZero(kalman_params.prediction_interval.Value(), + "KalmanPredictorParams::jerk_weight")); + + const KalmanPredictorParams::ConfidenceParams& confidence_params = + kalman_params.confidence_params; + RETURN_IF_ERROR(ValidateGreaterThanZero( + confidence_params.desired_number_of_samples, + "KalmanPredictorParams::ConfidenceParams::desired_number_of_samples")); + RETURN_IF_ERROR(ValidateGreaterThanZero( + confidence_params.max_estimation_distance, + "KalmanPredictorParams::ConfidenceParams::max_estimation_distance")); + RETURN_IF_ERROR(ValidateGreaterThanOrEqualToZero( + confidence_params.min_travel_speed, + "KalmanPredictorParams::ConfidenceParams::min_travel_speed")); + RETURN_IF_ERROR(ValidateIsFiniteNumber( + confidence_params.max_travel_speed, + "KalmanPredictorParams::ConfidenceParams::max_travel_speed")); + if (confidence_params.max_travel_speed < confidence_params.min_travel_speed) { + return absl::InvalidArgumentError( + absl::Substitute("KalmanPredictorParams::ConfidenceParams::max_" + "travel_speed must be greater than or equal to " + "KalmanPredictorParams::ConfidenceParams::min_" + "travel_speed ($0). Actual value: $1", + confidence_params.min_travel_speed, + confidence_params.max_travel_speed)); + } + RETURN_IF_ERROR(ValidateGreaterThanZero( + confidence_params.max_linear_deviation, + "KalmanPredictorParams::ConfidenceParams::max_linear_deviation")); + if (confidence_params.baseline_linearity_confidence < 0 || + confidence_params.baseline_linearity_confidence > 1) { + return absl::InvalidArgumentError(absl::Substitute( + "KalmanPredictorParams::ConfidenceParams::baseline_linearity_" + "confidence must lie in the interval [0, 1]. Actual value: $0", + confidence_params.baseline_linearity_confidence)); + } + return absl::OkStatus(); +} + +absl::Status ValidateStrokeModelParams(const StrokeModelParams& params) { + RETURN_IF_ERROR(ValidateWobbleSmootherParams(params.wobble_smoother_params)); + RETURN_IF_ERROR( + ValidatePositionModelerParams(params.position_modeler_params)); + RETURN_IF_ERROR(ValidateSamplingParams(params.sampling_params)); + RETURN_IF_ERROR( + ValidateStylusStateModelerParams(params.stylus_state_modeler_params)); + return ValidatePredictionParams(params.prediction_params); +} + +} // namespace stroke_model +} // namespace ink + +#undef RETURN_IF_ERROR diff --git a/ink_stroke_modeler/params.h b/ink_stroke_modeler/params.h new file mode 100644 index 0000000..102edba --- /dev/null +++ b/ink_stroke_modeler/params.h @@ -0,0 +1,224 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_PARAMS_H_ +#define INK_STROKE_MODELER_PARAMS_H_ + +#include "absl/status/status.h" +#include "absl/types/variant.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// These structs contain parameters for tuning the behavior of the stroke +// modeler. +// +// The stroke modeler is unit-agnostic, in both time and space. That is, the +// stroke modeler does not know or care whether the inputs and parameters are +// specified in feet and minutes, meters and seconds, or millimeters and years. +// As such, instead of referring to specific units, we refer to "unit distance" +// and "unit time". +// +// These parameters will need to be "tuned" to your use case. Because of this, +// and because of the modeler's unit-agnosticism, it's impossible to define +// "reasonable" default values for many of the parameters -- these parameters +// instead default to -1, which will cause the validation functions to return an +// error. +// +// Where possible, we've indicated what a good starting point for tuning might +// be, but you'll likely need to adjust these for best results. + +// These parameters are used for modeling the position of the pen. +struct PositionModelerParams { + // The mass of the "weight" being pulled along the path, multiplied by the + // spring constant. + float spring_mass_constant = 11.f / 32400; + + // The ratio of the pen's velocity that is subtracted from the pen's + // acceleration, to simulate drag. + float drag_constant = 72.f; +}; + +// These parameters are used for sampling. +struct SamplingParams { + // The minimum number of modeled inputs to output per unit time. If inputs are + // received at a lower rate, they will be upsampled to produce output of at + // least min_output_rate. If inputs are received at a higher rate, the + // output rate will match the input rate. + double min_output_rate = -1; + + // This determines stop condition for end-of-stroke modeling; if the position + // is within this distance of the final raw input, or if the last update + // iteration moved less than this distance, it stop iterating. + // + // This should be a small distance; a good starting point is 2-3 orders of + // magnitude smaller than the expected distance between input points. + float end_of_stroke_stopping_distance = -1; + + // The maximum number of iterations to perform at the end of the stroke, if it + // does not stop due to the constraints of end_of_stroke_stopping_distance. + int end_of_stroke_max_iterations = 20; +}; + +// These parameters are used modeling the state of the stylus once the position +// has been modeled. +struct StylusStateModelerParams { + // The maximum number of raw inputs to look at when finding the nearest states + // for interpolation. + int max_input_samples = 10; +}; + +// These parameters are used for applying smoothing to the input to reduce +// wobble in the prediction. +struct WobbleSmootherParams { + // The length of the window over which the moving average of speed and + // position are calculated. + // + // A good starting point is 2.5 divided by the expected number of inputs per + // unit time. + Duration timeout{-1}; + + // The range of speeds considered for wobble smoothing. At speed_floor, the + // maximum amount of smoothing is applied. At speed_ceiling, no smoothing is + // applied. + // + // Good starting points are 2% and 3% of the expected speed of the inputs. + float speed_floor = -1; + float speed_ceiling = -1; +}; + +// This struct indicates the "stroke end" prediction strategy should be used, +// which models a prediction as though the last seen input was the +// end-of-stroke. There aren't actually any tunable parameters for this; it uses +// the same PositionModelerParams and SamplingParams as the overall model. Note +// that this "prediction" doesn't actually predict substantially into the +// future, it only allows for very quickly "catching up" to the position of the +// raw input. +struct StrokeEndPredictorParams {}; + +// This struct indicates that the Kalman filter-based prediction strategy should +// be used, and provides the parameters for tuning it. +// +// Unlike the "stroke end" predictor, this strategy can predict an extension +// of the stroke beyond the last Input position, in addition to the "catch up" +// step. +struct KalmanPredictorParams { + // The variance of the noise inherent to the stroke itself. + double process_noise = -1; + + // The variance of the noise that rises from errors in measurement of the + // stroke. + double measurement_noise = -1; + + // The minimum number of inputs received before the Kalman predictor is + // considered stable enough to make a prediction. + int min_stable_iteration = 4; + + // The Kalman filter assumes that input is received in uniform time steps, but + // this is not always the case. We hold on to the most recent input timestamps + // for use in calculating the correction for this. This determines the maximum + // number of timestamps to save. + int max_time_samples = 20; + + // The minimum allowed velocity of the "catch up" portion of the prediction, + // which covers the distance between the last Result (the last corrected + // position) and the + // + // A good starting point is 3 orders of magnitude smaller than the expected + // speed of the inputs. + float min_catchup_velocity = -1; + + // These weights are applied to the acceleration (x²) and jerk (x³) terms of + // the cubic prediction polynomial. The closer they are to zero, the more + // linear the prediction will be. + float acceleration_weight = .5; + float jerk_weight = .1; + + // This value is a hint to the predictor, indicating the desired duration of + // of the portion of the prediction extending beyond the position of the last + // input. The actual duration of that portion of the prediction may be less + // than this, based on the predictor's confidence, but it will never be + // greater. + Duration prediction_interval{-1}; + + // The Kalman predictor uses several heuristics to evaluate confidence in the + // prediction. Each heuristic produces a confidence value between 0 and 1, and + // then we take their product as the total confidence. + // These parameters may be used to tune those heuristics. + struct ConfidenceParams { + // The first heuristic simply increases confidence as we receive more sample + // (i.e. input points). It evaluates to 0 at no samples, and 1 at + // desired_number_of_samples. + int desired_number_of_samples = 20; + + // The second heuristic is based on the distance between the last sample + // and the current estimate. If the distance is 0, it evaluates to 1, and if + // the distance is greater than or equal to max_estimation_distance, it + // evaluates to 0. + // + // A good starting point is 1.5 times measurement_noise. + float max_estimation_distance = -1; + + // The third heuristic is based on the speed of the prediction, which is + // approximated by measuring the from the start of the prediction to the + // projected endpoint (if it were extended for the full + // prediction_interval). It evaluates to 0 at min_travel_speed, and 1 + // at max_travel_speed. + // + // Good starting points are 5% and 25% of the expected speed of the inputs. + float min_travel_speed = -1; + float max_travel_speed = -1; + + // The fourth heuristic is based on the linearity of the prediction, which + // is approximated by comparing the endpoint of the prediction with the + // endpoint of a linear prediction (again, extended for the full + // prediction_interval). It evaluates to 1 at zero distance, and + // baseline_linearity_confidence at a distance of max_linear_deviation. + // + // A good starting point is an 10 times the measurement_noise. + float max_linear_deviation = -1; + float baseline_linearity_confidence = .4; + }; + ConfidenceParams confidence_params; +}; +using PredictionParams = + absl::variant<StrokeEndPredictorParams, KalmanPredictorParams>; + +// This convenience struct is a collection of the parameters for the individual +// parameter structs. +struct StrokeModelParams { + WobbleSmootherParams wobble_smoother_params; + PositionModelerParams position_modeler_params; + SamplingParams sampling_params; + StylusStateModelerParams stylus_state_modeler_params; + PredictionParams prediction_params = StrokeEndPredictorParams{}; +}; + +// These validation functions will return an error if the given parameters are +// invalid. +absl::Status ValidatePositionModelerParams(const PositionModelerParams& params); +absl::Status ValidateSamplingParams(const SamplingParams& params); +absl::Status ValidateStylusStateModelerParams( + const StylusStateModelerParams& params); +absl::Status ValidateWobbleSmootherParams(const WobbleSmootherParams& params); +absl::Status ValidatePredictionParams(const PredictionParams& params); +absl::Status ValidateStrokeModelParams(const StrokeModelParams& params); + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_PARAMS_H_ diff --git a/ink_stroke_modeler/params_test.cc b/ink_stroke_modeler/params_test.cc new file mode 100644 index 0000000..15bad8a --- /dev/null +++ b/ink_stroke_modeler/params_test.cc @@ -0,0 +1,259 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/params.h" + +#include <limits> + +#include "gtest/gtest.h" +#include "absl/status/status.h" + +namespace ink { +namespace stroke_model { +namespace { + +const KalmanPredictorParams kGoodKalmanParams{ + .process_noise = .01, + .measurement_noise = .1, + .min_stable_iteration = 2, + .max_time_samples = 10, + .min_catchup_velocity = 1, + .acceleration_weight = -1, + .jerk_weight = 200, + .prediction_interval{Duration(1)}, + .confidence_params{.desired_number_of_samples = 10, + .max_estimation_distance = 1, + .min_travel_speed = 6, + .max_travel_speed = 50, + .max_linear_deviation = 2, + .baseline_linearity_confidence = .5}}; + +const StrokeModelParams kGoodStrokeModelParams{ + .wobble_smoother_params{ + .timeout = Duration(.5), .speed_floor = 1, .speed_ceiling = 20}, + .position_modeler_params{.spring_mass_constant = .2, .drag_constant = 4}, + .sampling_params{.min_output_rate = 3, + .end_of_stroke_stopping_distance = 1e-6, + .end_of_stroke_max_iterations = 1}, + .stylus_state_modeler_params{.max_input_samples = 7}, + .prediction_params = StrokeEndPredictorParams{}}; + +TEST(ParamsTest, ValidatePositionModelerParams) { + EXPECT_TRUE(ValidatePositionModelerParams( + {.spring_mass_constant = 1, .drag_constant = 3}) + .ok()); + + EXPECT_EQ(ValidatePositionModelerParams( + {.spring_mass_constant = 0, .drag_constant = 1}) + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ValidatePositionModelerParams( + {.spring_mass_constant = 1, .drag_constant = 0}) + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ParamsTest, ValidateSamplingParams) { + EXPECT_TRUE(ValidateSamplingParams({.min_output_rate = 10, + .end_of_stroke_stopping_distance = .1, + .end_of_stroke_max_iterations = 3}) + .ok()); + + EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 0, + .end_of_stroke_stopping_distance = .1, + .end_of_stroke_max_iterations = 3}) + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 1, + .end_of_stroke_stopping_distance = 0, + .end_of_stroke_max_iterations = 3}) + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ValidateSamplingParams({.min_output_rate = 1, + .end_of_stroke_stopping_distance = 5, + .end_of_stroke_max_iterations = 0}) + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ParamsTest, ValidateStylusStateModelerParams) { + EXPECT_TRUE(ValidateStylusStateModelerParams({.max_input_samples = 1}).ok()); + + EXPECT_EQ(ValidateStylusStateModelerParams({.max_input_samples = 0}).code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ParamsTest, ValidateWobbleSmootherParams) { + EXPECT_TRUE( + ValidateWobbleSmootherParams( + {.timeout = Duration(1), .speed_floor = 2, .speed_ceiling = 3}) + .ok()); + EXPECT_TRUE( + ValidateWobbleSmootherParams( + {.timeout = Duration(0), .speed_floor = 0, .speed_ceiling = 0}) + .ok()); + + EXPECT_EQ(ValidateWobbleSmootherParams( + {.timeout = Duration(-1), .speed_floor = 2, .speed_ceiling = 5}) + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ValidateWobbleSmootherParams( + {.timeout = Duration(1), .speed_floor = -2, .speed_ceiling = 1}) + .code(), + absl::StatusCode::kInvalidArgument); + EXPECT_EQ(ValidateWobbleSmootherParams( + {.timeout = Duration(1), .speed_floor = 7, .speed_ceiling = 4}) + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ParamsTest, ValidateStrokeEndPredictorParams) { + EXPECT_TRUE(ValidatePredictionParams(StrokeEndPredictorParams()).ok()); +} + +TEST(ParamsTest, ValidateKalmanPredictorParams) { + EXPECT_TRUE(ValidatePredictionParams(kGoodKalmanParams).ok()); + { + auto bad_params = kGoodKalmanParams; + bad_params.process_noise = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.measurement_noise = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.min_stable_iteration = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.max_time_samples = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.prediction_interval = Duration(0); + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } +} + +TEST(ParamsTest, ValidateKalmanPredictorConfidenceParams) { + EXPECT_TRUE(ValidatePredictionParams(kGoodKalmanParams).ok()); + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.desired_number_of_samples = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.max_estimation_distance = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.min_travel_speed = -1; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.min_travel_speed = 10; + bad_params.confidence_params.max_travel_speed = 1; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.max_linear_deviation = 0; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.baseline_linearity_confidence = -.3; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodKalmanParams; + bad_params.confidence_params.baseline_linearity_confidence = 1.01; + EXPECT_EQ(ValidatePredictionParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } +} + +TEST(ParamsTest, ValidateStrokeModelParams) { + EXPECT_TRUE(ValidateStrokeModelParams(kGoodStrokeModelParams).ok()); + { + auto bad_params = kGoodStrokeModelParams; + bad_params.wobble_smoother_params.timeout = Duration(-10); + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodStrokeModelParams; + bad_params.position_modeler_params.spring_mass_constant = -1; + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodStrokeModelParams; + bad_params.stylus_state_modeler_params.max_input_samples = 0; + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodStrokeModelParams; + bad_params.sampling_params.end_of_stroke_max_iterations = -3; + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } + { + auto bad_params = kGoodStrokeModelParams; + bad_params.prediction_params = + KalmanPredictorParams{.prediction_interval = Duration(-1)}; + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); + } +} + +TEST(ParamsTest, NaNIsNotAValidValue) { + auto bad_params = kGoodStrokeModelParams; + bad_params.position_modeler_params.spring_mass_constant = + std::numeric_limits<float>::quiet_NaN(); + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(ParamsTest, InfinityIsNotAValidValue) { + auto bad_params = kGoodStrokeModelParams; + bad_params.position_modeler_params.spring_mass_constant = + std::numeric_limits<float>::infinity(); + EXPECT_EQ(ValidateStrokeModelParams(bad_params).code(), + absl::StatusCode::kInvalidArgument); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/stroke_modeler.cc b/ink_stroke_modeler/stroke_modeler.cc new file mode 100644 index 0000000..37d17f5 --- /dev/null +++ b/ink_stroke_modeler/stroke_modeler.cc @@ -0,0 +1,253 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/stroke_modeler.h" + +#include <iterator> +#include <type_traits> +#include <vector> + +#include "absl/base/attributes.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/variant.h" +#include "ink_stroke_modeler/internal/internal_types.h" +#include "ink_stroke_modeler/internal/position_modeler.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h" +#include "ink_stroke_modeler/internal/prediction/stroke_end_predictor.h" +#include "ink_stroke_modeler/internal/stylus_state_modeler.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { +namespace { + +std::vector<Result> ModelStylus( + const std::vector<TipState> &tip_states, + const StylusStateModeler &stylus_state_modeler) { + std::vector<Result> result; + result.reserve(tip_states.size()); + for (const auto &tip_state : tip_states) { + auto stylus_state = stylus_state_modeler.Query(tip_state.position); + result.push_back({.position = tip_state.position, + .velocity = tip_state.velocity, + .time = tip_state.time, + .pressure = stylus_state.pressure, + .tilt = stylus_state.tilt, + .orientation = stylus_state.orientation}); + } + return result; +} + +int GetNumberOfSteps(Time start_time, Time end_time, double min_rate) { + float float_delta = (end_time - start_time).Value(); + return std::ceil(float_delta * min_rate); +} + +template <typename> +ABSL_ATTRIBUTE_UNUSED inline constexpr bool kAlwaysFalse = false; + +} // namespace + +absl::Status StrokeModeler::Reset( + const StrokeModelParams &stroke_model_params) { + if (auto status = ValidateStrokeModelParams(stroke_model_params); + !status.ok()) { + return status; + } + + // Note that many of the sub-modelers require some knowledge about the stroke + // (e.g. start position, input type) when resetting, and as such are reset in + // ProcessTDown() instead. + stroke_model_params_ = stroke_model_params; + last_input_ = absl::nullopt; + + absl::visit( + [this](auto &¶ms) { + using ParamType = std::decay_t<decltype(params)>; + if constexpr (std::is_same_v<ParamType, KalmanPredictorParams>) { + predictor_ = absl::make_unique<KalmanPredictor>( + params, stroke_model_params_->sampling_params); + } else if constexpr (std::is_same_v<ParamType, + StrokeEndPredictorParams>) { + predictor_ = absl::make_unique<StrokeEndPredictor>( + stroke_model_params_->position_modeler_params, + stroke_model_params_->sampling_params); + } else { + static_assert(kAlwaysFalse<ParamType>, + "Unknown prediction parameter type"); + } + }, + stroke_model_params_->prediction_params); + return absl::OkStatus(); +} + +absl::StatusOr<std::vector<Result>> StrokeModeler::Update(const Input &input) { + if (!stroke_model_params_.has_value()) { + return absl::FailedPreconditionError( + "Stroke model has not yet been initialized"); + } + + if (absl::Status status = ValidateInput(input); !status.ok()) { + return status; + } + + if (last_input_) { + if (last_input_->input == input) { + return absl::InvalidArgumentError("Received duplicate input"); + } + + if (input.time < last_input_->input.time) { + return absl::InvalidArgumentError("Inputs travel backwards in time"); + } + } + + switch (input.event_type) { + case Input::EventType::kDown: + return ProcessDownEvent(input); + case Input::EventType::kMove: + return ProcessMoveEvent(input); + case Input::EventType::kUp: + return ProcessUpEvent(input); + } + return absl::InvalidArgumentError("Invalid EventType."); +} + +absl::StatusOr<std::vector<Result>> StrokeModeler::Predict() const { + if (!stroke_model_params_.has_value()) { + return absl::FailedPreconditionError( + "Stroke model has not yet been initialized"); + } + + if (last_input_ == std::nullopt) { + return absl::FailedPreconditionError( + "Cannot construct prediction when no stroke is in-progress"); + } + + return ModelStylus( + predictor_->ConstructPrediction(position_modeler_.CurrentState()), + stylus_state_modeler_); +} + +absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessDownEvent( + const Input &input) { + if (last_input_) { + return absl::FailedPreconditionError( + "Received down event while stroke is in-progress"); + } + + // Note that many of the sub-modelers require some knowledge about the stroke + // (e.g. start position, input type) when resetting, and as such are reset + // here instead of in Reset(). + wobble_smoother_.Reset(stroke_model_params_->wobble_smoother_params, + input.position, input.time); + position_modeler_.Reset({input.position, {0, 0}, input.time}, + stroke_model_params_->position_modeler_params); + stylus_state_modeler_.Reset( + stroke_model_params_->stylus_state_modeler_params); + stylus_state_modeler_.Update(input.position, + {.pressure = input.pressure, + .tilt = input.tilt, + .orientation = input.orientation}); + + const TipState &tip_state = position_modeler_.CurrentState(); + predictor_->Reset(); + predictor_->Update(input.position, input.time); + + // We don't correct the position on the down event, so we set + // corrected_position to use the input position. + last_input_ = {.input = input, .corrected_position = input.position}; + return {{{.position = tip_state.position, + .velocity = tip_state.velocity, + .time = tip_state.time, + .pressure = input.pressure, + .tilt = input.tilt, + .orientation = input.orientation}}}; +} + +absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessUpEvent( + const Input &input) { + if (!last_input_) { + return absl::FailedPreconditionError( + "Received up event while no stroke is in-progress"); + } + + int n_steps = + GetNumberOfSteps(last_input_->input.time, input.time, + stroke_model_params_->sampling_params.min_output_rate); + std::vector<TipState> tip_states; + tip_states.reserve( + n_steps + + stroke_model_params_->sampling_params.end_of_stroke_max_iterations); + position_modeler_.UpdateAlongLinearPath( + last_input_->corrected_position, last_input_->input.time, input.position, + input.time, n_steps, std::back_inserter(tip_states)); + + position_modeler_.ModelEndOfStroke( + input.position, + Duration(1. / stroke_model_params_->sampling_params.min_output_rate), + stroke_model_params_->sampling_params.end_of_stroke_max_iterations, + stroke_model_params_->sampling_params.end_of_stroke_stopping_distance, + std::back_inserter(tip_states)); + + if (tip_states.empty()) { + // If we haven't generated any new states, add the current state. This can + // happen if the TUp has the same timestamp as the last in-contact input. + tip_states.push_back(position_modeler_.CurrentState()); + } + + stylus_state_modeler_.Update(input.position, + {.pressure = input.pressure, + .tilt = input.tilt, + .orientation = input.orientation}); + + // This indicates that we've finished the stroke. + last_input_ = absl::nullopt; + + return ModelStylus(tip_states, stylus_state_modeler_); +} + +absl::StatusOr<std::vector<Result>> StrokeModeler::ProcessMoveEvent( + const Input &input) { + if (!last_input_) { + return absl::FailedPreconditionError( + "Received move event while no stroke is in-progress"); + } + + Vec2 corrected_position = wobble_smoother_.Update(input.position, input.time); + stylus_state_modeler_.Update(corrected_position, + {.pressure = input.pressure, + .tilt = input.tilt, + .orientation = input.orientation}); + + int n_steps = + GetNumberOfSteps(last_input_->input.time, input.time, + stroke_model_params_->sampling_params.min_output_rate); + std::vector<TipState> tip_states; + tip_states.reserve(n_steps); + position_modeler_.UpdateAlongLinearPath( + last_input_->corrected_position, last_input_->input.time, + corrected_position, input.time, n_steps, std::back_inserter(tip_states)); + + predictor_->Update(corrected_position, input.time); + last_input_ = {.input = input, .corrected_position = corrected_position}; + return ModelStylus(tip_states, stylus_state_modeler_); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/stroke_modeler.h b/ink_stroke_modeler/stroke_modeler.h new file mode 100644 index 0000000..a25b971 --- /dev/null +++ b/ink_stroke_modeler/stroke_modeler.h @@ -0,0 +1,99 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_STROKE_MODELER_H_ +#define INK_STROKE_MODELER_STROKE_MODELER_H_ + +#include <memory> +#include <vector> + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "ink_stroke_modeler/internal/position_modeler.h" +#include "ink_stroke_modeler/internal/prediction/input_predictor.h" +#include "ink_stroke_modeler/internal/stylus_state_modeler.h" +#include "ink_stroke_modeler/internal/wobble_smoother.h" +#include "ink_stroke_modeler/params.h" +#include "ink_stroke_modeler/types.h" + +namespace ink { +namespace stroke_model { + +// This class models a stroke from a raw input stream. The modeling is performed +// in several stages, which are delegated to component classes: +// - Wobble Smoothing: Dampens high-frequency noise from quantization error. +// - Position Modeling: Models the pen tip as a mass, connected by a spring, to +// a moving anchor. +// - Stylus State Modeling: Constructs stylus states for modeled positions by +// interpolating over the raw input. +// +// Additionally, this class provides prediction of the modeled stroke. +// +// StrokeModeler is completely unit-agnostic. That is, it doesn't matter what +// units or coordinate-system the input is given in; the output will be given in +// the same coordinate-system and units. +class StrokeModeler { + public: + // Clears any in-progress stroke, and initializes (or re-initializes) the + // model with the given parameters. Returns an error if the parameters are + // invalid. + absl::Status Reset(const StrokeModelParams &stroke_model_params); + + // Updates the model with a raw input, returning the generated results. Any + // previously generated results are stable, i.e. any previously returned + // Results are still valid. + // + // Returns an error if the the model has not yet been initialized (via Reset) + // or if the input stream is malformed (e.g decreasing time, Up event before + // Down event). + // + // If this does not return an error, the result will contain at least one + // Result, and potentially more than one if the inputs are slower than + // the minimum output rate. + absl::StatusOr<std::vector<Result>> Update(const Input &input); + + // Model the given input prediction without changing the internal model state. + // + // Returns an error if the the model has not yet been initialized (via Reset), + // or if there is no stroke in progress. The output is limited to results + // where the predictor has sufficient confidence, + absl::StatusOr<std::vector<Result>> Predict() const; + + private: + absl::StatusOr<std::vector<Result>> ProcessDownEvent(const Input &input); + absl::StatusOr<std::vector<Result>> ProcessMoveEvent(const Input &input); + absl::StatusOr<std::vector<Result>> ProcessUpEvent(const Input &input); + + std::unique_ptr<InputPredictor> predictor_; + + absl::optional<StrokeModelParams> stroke_model_params_; + + WobbleSmoother wobble_smoother_; + PositionModeler position_modeler_; + StylusStateModeler stylus_state_modeler_; + + struct InputAndCorrectedPosition { + Input input; + Vec2 corrected_position{0}; + }; + absl::optional<InputAndCorrectedPosition> last_input_; +}; + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_STROKE_MODELER_H_ diff --git a/ink_stroke_modeler/stroke_modeler_test.cc b/ink_stroke_modeler/stroke_modeler_test.cc new file mode 100644 index 0000000..00f1dba --- /dev/null +++ b/ink_stroke_modeler/stroke_modeler_test.cc @@ -0,0 +1,1395 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/stroke_modeler.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "ink_stroke_modeler/internal/type_matchers.h" +#include "ink_stroke_modeler/params.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::DoubleNear; +using ::testing::ElementsAre; +using ::testing::FloatNear; +using ::testing::IsEmpty; +using ::testing::Matches; +using ::testing::Not; + +constexpr float kTol = 1e-4; + +// These parameters use cm for distance and seconds for time. +const StrokeModelParams kDefaultParams{ + .wobble_smoother_params{ + .timeout = Duration(.04), .speed_floor = 1.31, .speed_ceiling = 1.44}, + .position_modeler_params{.spring_mass_constant = 11.f / 32400, + .drag_constant = 72.f}, + .sampling_params{.min_output_rate = 180, + .end_of_stroke_stopping_distance = .001, + .end_of_stroke_max_iterations = 20}, + .stylus_state_modeler_params{.max_input_samples = 20}, + .prediction_params = StrokeEndPredictorParams()}; + +MATCHER_P2(ResultNearMatcher, expected, tolerance, "") { + if (Matches(Vec2Near(expected.position, tolerance))(arg.position) && + Matches(Vec2Near(expected.velocity, tolerance))(arg.velocity) && + Matches(DoubleNear(expected.time.Value(), tolerance))(arg.time.Value()) && + Matches(FloatNear(expected.pressure, tolerance))(arg.pressure) && + Matches(FloatNear(expected.tilt, tolerance))(arg.tilt) && + Matches(FloatNear(expected.orientation, tolerance))(arg.orientation)) { + return true; + } + + return false; +} + +::testing::Matcher<Result> ResultNear(const Result &expected, float tolerance) { + return ResultNearMatcher(expected, tolerance); +} + +TEST(StrokeModelerTest, NoPredictionUponInit) { + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + EXPECT_EQ(modeler.Predict().status().code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, InputRateSlowerThanMinOutputRate) { + const Duration kDeltaTime{1. / 30}; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + Time time{0}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {3, 4}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre(ResultNear( + {.position = {3, 4}, .velocity = {0, 0}, .time = Time(0)}, kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, IsEmpty()); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {3.2, 4.2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.0019, 4.0019}, + .velocity = {0.4007, 0.4007}, + .time = Time(0.0048)}, + kTol), + ResultNear({.position = {3.0069, 4.0069}, + .velocity = {1.0381, 1.0381}, + .time = Time(0.0095)}, + kTol), + ResultNear({.position = {3.0154, 4.0154}, + .velocity = {1.7883, 1.7883}, + .time = Time(0.0143)}, + kTol), + ResultNear({.position = {3.0276, 4.0276}, + .velocity = {2.5626, 2.5626}, + .time = Time(0.0190)}, + kTol), + ResultNear({.position = {3.0433, 4.0433}, + .velocity = {3.3010, 3.3010}, + .time = Time(0.0238)}, + kTol), + ResultNear({.position = {3.0622, 4.0622}, + .velocity = {3.9665, 3.9665}, + .time = Time(0.0286)}, + kTol), + ResultNear({.position = {3.0838, 4.0838}, + .velocity = {4.5397, 4.5397}, + .time = Time(0.0333)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.1095, 4.1095}, + .velocity = {4.6253, 4.6253}, + .time = Time(0.0389)}, + kTol), + ResultNear({.position = {3.1331, 4.1331}, + .velocity = {4.2563, 4.2563}, + .time = Time(0.0444)}, + kTol), + ResultNear({.position = {3.1534, 4.1534}, + .velocity = {3.6479, 3.6479}, + .time = Time(0.0500)}, + kTol), + ResultNear({.position = {3.1698, 4.1698}, + .velocity = {2.9512, 2.9512}, + .time = Time(0.0556)}, + kTol), + ResultNear({.position = {3.1824, 4.1824}, + .velocity = {2.2649, 2.2649}, + .time = Time(0.0611)}, + kTol), + ResultNear({.position = {3.1915, 4.1915}, + .velocity = {1.6473, 1.6473}, + .time = Time(0.0667)}, + kTol), + ResultNear({.position = {3.1978, 4.1978}, + .velocity = {1.1269, 1.1269}, + .time = Time(0.0722)}, + kTol), + ResultNear({.position = {3.1992, 4.1992}, + .velocity = {1.0232, 1.0232}, + .time = Time(0.0736)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {3.5, 4.2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.1086, 4.1058}, + .velocity = {5.2142, 4.6131}, + .time = Time(0.0381)}, + kTol), + ResultNear({.position = {3.1368, 4.1265}, + .velocity = {5.9103, 4.3532}, + .time = Time(0.0429)}, + kTol), + ResultNear({.position = {3.1681, 4.1450}, + .velocity = {6.5742, 3.8917}, + .time = Time(0.0476)}, + kTol), + ResultNear({.position = {3.2022, 4.1609}, + .velocity = {7.1724, 3.3285}, + .time = Time(0.0524)}, + kTol), + ResultNear({.position = {3.2388, 4.1739}, + .velocity = {7.6876, 2.7361}, + .time = Time(0.0571)}, + kTol), + ResultNear({.position = {3.2775, 4.1842}, + .velocity = {8.1138, 2.1640}, + .time = Time(0.0619)}, + kTol), + ResultNear({.position = {3.3177, 4.1920}, + .velocity = {8.4531, 1.6436}, + .time = Time(0.0667)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.3625, 4.1982}, + .velocity = {8.0545, 1.1165}, + .time = Time(0.0722)}, + kTol), + ResultNear({.position = {3.4018, 4.2021}, + .velocity = {7.0831, 0.6987}, + .time = Time(0.0778)}, + kTol), + ResultNear({.position = {3.4344, 4.2043}, + .velocity = {5.8564, 0.3846}, + .time = Time(0.0833)}, + kTol), + ResultNear({.position = {3.4598, 4.2052}, + .velocity = {4.5880, 0.1611}, + .time = Time(0.0889)}, + kTol), + ResultNear({.position = {3.4788, 4.2052}, + .velocity = {3.4098, 0.0124}, + .time = Time(0.0944)}, + kTol), + ResultNear({.position = {3.4921, 4.2048}, + .velocity = {2.3929, -0.0780}, + .time = Time(0.1000)}, + kTol), + ResultNear({.position = {3.4976, 4.2045}, + .velocity = {1.9791, -0.1015}, + .time = Time(0.1028)}, + kTol), + ResultNear({.position = {3.5001, 4.2044}, + .velocity = {1.7911, -0.1098}, + .time = Time(0.1042)}, + kTol))); + + time += kDeltaTime; + // We get more results at the end of the stroke as it tries to "catch up" to + // the raw input. + results = modeler.Update({.event_type = Input::EventType::kUp, + .position = {3.7, 4.4}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {3.3583, 4.1996}, + .velocity = {8.5122, 1.5925}, + .time = Time(0.0714)}, + kTol), + ResultNear({.position = {3.3982, 4.2084}, + .velocity = {8.3832, 1.8534}, + .time = Time(0.0762)}, + kTol), + ResultNear({.position = {3.4369, 4.2194}, + .velocity = {8.1393, 2.3017}, + .time = Time(0.0810)}, + kTol), + ResultNear({.position = {3.4743, 4.2329}, + .velocity = {7.8362, 2.8434}, + .time = Time(0.0857)}, + kTol), + ResultNear({.position = {3.5100, 4.2492}, + .velocity = {7.5143, 3.4101}, + .time = Time(0.0905)}, + kTol), + ResultNear({.position = {3.5443, 4.2680}, + .velocity = {7.2016, 3.9556}, + .time = Time(0.0952)}, + kTol), + ResultNear({.position = {3.5773, 4.2892}, + .velocity = {6.9159, 4.4505}, + .time = Time(0.1000)}, + kTol), + ResultNear({.position = {3.6115, 4.3141}, + .velocity = {6.1580, 4.4832}, + .time = Time(0.1056)}, + kTol), + ResultNear({.position = {3.6400, 4.3369}, + .velocity = {5.1434, 4.0953}, + .time = Time(0.1111)}, + kTol), + ResultNear({.position = {3.6626, 4.3563}, + .velocity = {4.0671, 3.4902}, + .time = Time(0.1167)}, + kTol), + ResultNear({.position = {3.6796, 4.3719}, + .velocity = {3.0515, 2.8099}, + .time = Time(0.1222)}, + kTol), + ResultNear({.position = {3.6916, 4.3838}, + .velocity = {2.1648, 2.1462}, + .time = Time(0.1278)}, + kTol), + ResultNear({.position = {3.6996, 4.3924}, + .velocity = {1.4360, 1.5529}, + .time = Time(0.1333)}, + kTol), + ResultNear({.position = {3.7028, 4.3960}, + .velocity = {1.1520, 1.3044}, + .time = Time(0.1361)}, + kTol))); + + // The stroke is finished, so there's nothing to predict anymore. + EXPECT_EQ(modeler.Predict().status().code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, InputRateFasterThanMinOutputRate) { + const Duration kDeltaTime{1. / 300}; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + Time time{2}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {5, -3}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre(ResultNear( + {.position = {5, -3}, .velocity = {0, 0}, .time = Time(2)}, kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, IsEmpty()); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {5, -3.1}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {5, -3.0033}, + .velocity = {0, -0.9818}, + .time = Time(2.0033)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {5, -3.0153}, + .velocity = {0, -2.1719}, + .time = Time(2.0089)}, + kTol), + ResultNear({.position = {5, -3.0303}, + .velocity = {0, -2.6885}, + .time = Time(2.0144)}, + kTol), + ResultNear({.position = {5, -3.0456}, + .velocity = {0, -2.7541}, + .time = Time(2.0200)}, + kTol), + ResultNear({.position = {5, -3.0597}, + .velocity = {0, -2.5430}, + .time = Time(2.0256)}, + kTol), + ResultNear({.position = {5, -3.0718}, + .velocity = {0, -2.1852}, + .time = Time(2.0311)}, + kTol), + ResultNear({.position = {5, -3.0817}, + .velocity = {0, -1.7719}, + .time = Time(2.0367)}, + kTol), + ResultNear({.position = {5, -3.0893}, + .velocity = {0, -1.3628}, + .time = Time(2.0422)}, + kTol), + ResultNear({.position = {5, -3.0948}, + .velocity = {0, -0.9934}, + .time = Time(2.0478)}, + kTol), + ResultNear({.position = {5, -3.0986}, + .velocity = {0, -0.6815}, + .time = Time(2.0533)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.975, -3.175}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9992, -3.0114}, + .velocity = {-0.2455, -2.4322}, + .time = Time(2.0067)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9962, -3.0344}, + .velocity = {-0.5430, -4.1368}, + .time = Time(2.0122)}, + kTol), + ResultNear({.position = {4.9924, -3.0609}, + .velocity = {-0.6721, -4.7834}, + .time = Time(2.0178)}, + kTol), + ResultNear({.position = {4.9886, -3.0873}, + .velocity = {-0.6885, -4.7365}, + .time = Time(2.0233)}, + kTol), + ResultNear({.position = {4.9851, -3.1110}, + .velocity = {-0.6358, -4.2778}, + .time = Time(2.0289)}, + kTol), + ResultNear({.position = {4.9820, -3.1311}, + .velocity = {-0.5463, -3.6137}, + .time = Time(2.0344)}, + kTol), + ResultNear({.position = {4.9796, -3.1471}, + .velocity = {-0.4430, -2.8867}, + .time = Time(2.0400)}, + kTol), + ResultNear({.position = {4.9777, -3.1593}, + .velocity = {-0.3407, -2.1881}, + .time = Time(2.0456)}, + kTol), + ResultNear({.position = {4.9763, -3.1680}, + .velocity = {-0.2484, -1.5700}, + .time = Time(2.0511)}, + kTol), + ResultNear({.position = {4.9754, -3.1739}, + .velocity = {-0.1704, -1.0564}, + .time = Time(2.0567)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.9, -3.2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9953, -3.0237}, + .velocity = {-1.1603, -3.7004}, + .time = Time(2.0100)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9828, -3.0521}, + .velocity = {-2.2559, -5.1049}, + .time = Time(2.0156)}, + kTol), + ResultNear({.position = {4.9677, -3.0825}, + .velocity = {-2.7081, -5.4835}, + .time = Time(2.0211)}, + kTol), + ResultNear({.position = {4.9526, -3.1115}, + .velocity = {-2.7333, -5.2122}, + .time = Time(2.0267)}, + kTol), + ResultNear({.position = {4.9387, -3.1369}, + .velocity = {-2.4999, -4.5756}, + .time = Time(2.0322)}, + kTol), + ResultNear({.position = {4.9268, -3.1579}, + .velocity = {-2.1326, -3.7776}, + .time = Time(2.0378)}, + kTol), + ResultNear({.position = {4.9173, -3.1743}, + .velocity = {-1.7184, -2.9554}, + .time = Time(2.0433)}, + kTol), + ResultNear({.position = {4.9100, -3.1865}, + .velocity = {-1.3136, -2.1935}, + .time = Time(2.0489)}, + kTol), + ResultNear({.position = {4.9047, -3.1950}, + .velocity = {-0.9513, -1.5369}, + .time = Time(2.0544)}, + kTol), + ResultNear({.position = {4.9011, -3.2006}, + .velocity = {-0.6475, -1.0032}, + .time = Time(2.0600)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.825, -3.2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9868, -3.0389}, + .velocity = {-2.5540, -4.5431}, + .time = Time(2.0133)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9636, -3.0687}, + .velocity = {-4.1801, -5.3627}, + .time = Time(2.0189)}, + kTol), + ResultNear({.position = {4.9370, -3.0985}, + .velocity = {-4.7757, -5.3670}, + .time = Time(2.0244)}, + kTol), + ResultNear({.position = {4.9109, -3.1256}, + .velocity = {-4.6989, -4.8816}, + .time = Time(2.0300)}, + kTol), + ResultNear({.position = {4.8875, -3.1486}, + .velocity = {-4.2257, -4.1466}, + .time = Time(2.0356)}, + kTol), + ResultNear({.position = {4.8677, -3.1671}, + .velocity = {-3.5576, -3.3287}, + .time = Time(2.0411)}, + kTol), + ResultNear({.position = {4.8520, -3.1812}, + .velocity = {-2.8333, -2.5353}, + .time = Time(2.0467)}, + kTol), + ResultNear({.position = {4.8401, -3.1914}, + .velocity = {-2.1411, -1.8288}, + .time = Time(2.0522)}, + kTol), + ResultNear({.position = {4.8316, -3.1982}, + .velocity = {-1.5312, -1.2386}, + .time = Time(2.0578)}, + kTol), + ResultNear({.position = {4.8280, -3.2010}, + .velocity = {-1.2786, -1.0053}, + .time = Time(2.0606)}, + kTol), + ResultNear({.position = {4.8272, -3.2017}, + .velocity = {-1.2209, -0.9529}, + .time = Time(2.0613)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.75, -3.225}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9726, -3.0565}, + .velocity = {-4.2660, -5.2803}, + .time = Time(2.0167)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9381, -3.0894}, + .velocity = {-6.2018, -5.9261}, + .time = Time(2.0222)}, + kTol), + ResultNear({.position = {4.9004, -3.1215}, + .velocity = {-6.7995, -5.7749}, + .time = Time(2.0278)}, + kTol), + ResultNear({.position = {4.8640, -3.1501}, + .velocity = {-6.5400, -5.1591}, + .time = Time(2.0333)}, + kTol), + ResultNear({.position = {4.8319, -3.1741}, + .velocity = {-5.7897, -4.3207}, + .time = Time(2.0389)}, + kTol), + ResultNear({.position = {4.8051, -3.1932}, + .velocity = {-4.8133, -3.4248}, + .time = Time(2.0444)}, + kTol), + ResultNear({.position = {4.7841, -3.2075}, + .velocity = {-3.7898, -2.5759}, + .time = Time(2.0500)}, + kTol), + ResultNear({.position = {4.7683, -3.2176}, + .velocity = {-2.8312, -1.8324}, + .time = Time(2.0556)}, + kTol), + ResultNear({.position = {4.7572, -3.2244}, + .velocity = {-1.9986, -1.2198}, + .time = Time(2.0611)}, + kTol), + ResultNear({.position = {4.7526, -3.2271}, + .velocity = {-1.6580, -0.9805}, + .time = Time(2.0639)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.7, -3.3}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9529, -3.0778}, + .velocity = {-5.9184, -6.4042}, + .time = Time(2.0200)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9101, -3.1194}, + .velocity = {-7.6886, -7.4784}, + .time = Time(2.0256)}, + kTol), + ResultNear({.position = {4.8654, -3.1607}, + .velocity = {-8.0518, -7.4431}, + .time = Time(2.0311)}, + kTol), + ResultNear({.position = {4.8235, -3.1982}, + .velocity = {-7.5377, -6.7452}, + .time = Time(2.0367)}, + kTol), + ResultNear({.position = {4.7872, -3.2299}, + .velocity = {-6.5440, -5.7133}, + .time = Time(2.0422)}, + kTol), + ResultNear({.position = {4.7574, -3.2553}, + .velocity = {-5.3529, -4.5748}, + .time = Time(2.0478)}, + kTol), + ResultNear({.position = {4.7344, -3.2746}, + .velocity = {-4.1516, -3.4758}, + .time = Time(2.0533)}, + kTol), + ResultNear({.position = {4.7174, -3.2885}, + .velocity = {-3.0534, -2.5004}, + .time = Time(2.0589)}, + kTol), + ResultNear({.position = {4.7056, -3.2979}, + .velocity = {-2.1169, -1.6879}, + .time = Time(2.0644)}, + kTol), + ResultNear({.position = {4.7030, -3.3000}, + .velocity = {-1.9283, -1.5276}, + .time = Time(2.0658)}, + kTol), + ResultNear({.position = {4.7017, -3.3010}, + .velocity = {-1.8380, -1.4512}, + .time = Time(2.0665)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.675, -3.4}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9288, -3.1046}, + .velocity = {-7.2260, -8.0305}, + .time = Time(2.0233)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8816, -3.1582}, + .velocity = {-8.4881, -9.6525}, + .time = Time(2.0289)}, + kTol), + ResultNear({.position = {4.8345, -3.2124}, + .velocity = {-8.4738, -9.7482}, + .time = Time(2.0344)}, + kTol), + ResultNear({.position = {4.7918, -3.2619}, + .velocity = {-7.6948, -8.9195}, + .time = Time(2.0400)}, + kTol), + ResultNear({.position = {4.7555, -3.3042}, + .velocity = {-6.5279, -7.6113}, + .time = Time(2.0456)}, + kTol), + ResultNear({.position = {4.7264, -3.3383}, + .velocity = {-5.2343, -6.1345}, + .time = Time(2.0511)}, + kTol), + ResultNear({.position = {4.7043, -3.3643}, + .velocity = {-3.9823, -4.6907}, + .time = Time(2.0567)}, + kTol), + ResultNear({.position = {4.6884, -3.3832}, + .velocity = {-2.8691, -3.3980}, + .time = Time(2.0622)}, + kTol), + ResultNear({.position = {4.6776, -3.3961}, + .velocity = {-1.9403, -2.3135}, + .time = Time(2.0678)}, + kTol), + ResultNear({.position = {4.6752, -3.3990}, + .velocity = {-1.7569, -2.0983}, + .time = Time(2.0692)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {4.675, -3.525}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.9022, -3.1387}, + .velocity = {-7.9833, -10.2310}, + .time = Time(2.0267)}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8549, -3.2079}, + .velocity = {-8.5070, -12.4602}, + .time = Time(2.0322)}, + kTol), + ResultNear({.position = {4.8102, -3.2783}, + .velocity = {-8.0479, -12.6650}, + .time = Time(2.0378)}, + kTol), + ResultNear({.position = {4.7711, -3.3429}, + .velocity = {-7.0408, -11.6365}, + .time = Time(2.0433)}, + kTol), + ResultNear({.position = {4.7389, -3.3983}, + .velocity = {-5.7965, -9.9616}, + .time = Time(2.0489)}, + kTol), + ResultNear({.position = {4.7137, -3.4430}, + .velocity = {-4.5230, -8.0510}, + .time = Time(2.0544)}, + kTol), + ResultNear({.position = {4.6951, -3.4773}, + .velocity = {-3.3477, -6.1727}, + .time = Time(2.0600)}, + kTol), + ResultNear({.position = {4.6821, -3.5022}, + .velocity = {-2.3381, -4.4846}, + .time = Time(2.0656)}, + kTol), + ResultNear({.position = {4.6737, -3.5192}, + .velocity = {-1.5199, -3.0641}, + .time = Time(2.0711)}, + kTol), + ResultNear({.position = {4.6718, -3.5231}, + .velocity = {-1.3626, -2.7813}, + .time = Time(2.0725)}, + kTol))); + + time += kDeltaTime; + // We get more results at the end of the stroke as it tries to "catch up" to + // the raw input. + results = modeler.Update({.event_type = Input::EventType::kUp, + .position = {4.7, -3.6}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {4.8753, -3.1797}, + .velocity = {-8.0521, -12.3049}, + .time = Time(2.0300)}, + kTol), + ResultNear({.position = {4.8325, -3.2589}, + .velocity = {-7.7000, -14.2607}, + .time = Time(2.0356)}, + kTol), + ResultNear({.position = {4.7948, -3.3375}, + .velocity = {-6.7888, -14.1377}, + .time = Time(2.0411)}, + kTol), + ResultNear({.position = {4.7636, -3.4085}, + .velocity = {-5.6249, -12.7787}, + .time = Time(2.0467)}, + kTol), + ResultNear({.position = {4.7390, -3.4685}, + .velocity = {-4.4152, -10.8015}, + .time = Time(2.0522)}, + kTol), + ResultNear({.position = {4.7208, -3.5164}, + .velocity = {-3.2880, -8.6333}, + .time = Time(2.0578)}, + kTol), + ResultNear({.position = {4.7079, -3.5528}, + .velocity = {-2.3128, -6.5475}, + .time = Time(2.0633)}, + kTol), + ResultNear({.position = {4.6995, -3.5789}, + .velocity = {-1.5174, -4.7008}, + .time = Time(2.0689)}, + kTol), + ResultNear({.position = {4.6945, -3.5965}, + .velocity = {-0.9022, -3.1655}, + .time = Time(2.0744)}, + kTol), + ResultNear({.position = {4.6942, -3.5976}, + .velocity = {-0.8740, -3.0899}, + .time = Time(2.0748)}, + kTol))); + + // The stroke is finished, so there's nothing to predict anymore. + EXPECT_EQ(modeler.Predict().status().code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, WobbleSmoothed) { + const Duration kDeltaTime{.0167}; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + Time time{4}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {-6, -2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre(ResultNear( + {.position = {-6, -2}, .velocity = {0, 0}, .time = Time(4)}, kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-6.02, -2}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0001, -2}, + .velocity = {-0.0328, 0}, + .time = Time(4.0042)}, + kTol), + ResultNear({.position = {-6.0005, -2}, + .velocity = {-0.0869, 0}, + .time = Time(4.0084)}, + kTol), + ResultNear({.position = {-6.0011, -2}, + .velocity = {-0.1531, 0}, + .time = Time(4.0125)}, + kTol), + ResultNear({.position = {-6.0021, -2}, + .velocity = {-0.2244, 0}, + .time = Time(4.0167)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-6.02, -2.02}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0032, -2.0001}, + .velocity = {-0.2709, -0.0205}, + .time = Time(4.0209)}, + kTol), + ResultNear({.position = {-6.0044, -2.0003}, + .velocity = {-0.2977, -0.0543}, + .time = Time(4.0251)}, + kTol), + ResultNear({.position = {-6.0057, -2.0007}, + .velocity = {-0.3093, -0.0956}, + .time = Time(4.0292)}, + kTol), + ResultNear({.position = {-6.0070, -2.0013}, + .velocity = {-0.3097, -0.1401}, + .time = Time(4.0334)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-6.04, -2.02}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0084, -2.0021}, + .velocity = {-0.3350, -0.1845}, + .time = Time(4.0376)}, + kTol), + ResultNear({.position = {-6.0100, -2.0030}, + .velocity = {-0.3766, -0.2266}, + .time = Time(4.0418)}, + kTol), + ResultNear({.position = {-6.0118, -2.0041}, + .velocity = {-0.4273, -0.2649}, + .time = Time(4.0459)}, + kTol), + ResultNear({.position = {-6.0138, -2.0054}, + .velocity = {-0.4818, -0.2986}, + .time = Time(4.0501)}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-6.04, -2.04}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({.position = {-6.0160, -2.0068}, + .velocity = {-0.5157, -0.3478}, + .time = Time(4.0543)}, + kTol), + ResultNear({.position = {-6.0182, -2.0085}, + .velocity = {-0.5334, -0.4054}, + .time = Time(4.0585)}, + kTol), + ResultNear({.position = {-6.0204, -2.0105}, + .velocity = {-0.5389, -0.4658}, + .time = Time(4.0626)}, + kTol), + ResultNear({.position = {-6.0227, -2.0126}, + .velocity = {-0.5356, -0.5251}, + .time = Time(4.0668)}, + kTol))); +} + +TEST(StrokeModelerTest, Reset) { + const Duration kDeltaTime{1. / 50}; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + Time time{0}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {-8, -10}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, IsEmpty()); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-10, -8}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {-11, -5}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + EXPECT_EQ(modeler.Predict().status().code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, IgnoreInputsBeforeTDown) { + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kMove, + .position = {0, 0}, + .time = Time(0)}) + .status() + .code(), + absl::StatusCode::kFailedPrecondition); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kUp, + .position = {0, 0}, + .time = Time(1)}) + .status() + .code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, IgnoreTDownWhileStrokeIsInProgress) { + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {0, 0}, + .time = Time(0)}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kDown, + .position = {1, 1}, + .time = Time(1)}) + .status() + .code(), + absl::StatusCode::kFailedPrecondition); + + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {1, 1}, + .time = Time(1)}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kDown, + .position = {2, 2}, + .time = Time(2)}) + .status() + .code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, AlternateParams) { + const Duration kDeltaTime{1. / 50}; + + StrokeModelParams stroke_model_params = kDefaultParams; + stroke_model_params.sampling_params.min_output_rate = 70; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(stroke_model_params).ok()); + + Time time{3}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {0, 0}, + .time = time, + .pressure = .5, + .tilt = .2, + .orientation = .4}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear( + {{0, 0}, {0, 0}, Time{3}, .5, .2, .4}, kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, IsEmpty()); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {0, .5}, + .time = time, + .pressure = .4, + .tilt = .3, + .orientation = .3}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre( + ResultNear( + {{0, 0.0736}, {0, 7.3636}, Time{3.0100}, 0.4853, 0.2147, 0.3853}, + kTol), + ResultNear( + {{0, 0.2198}, {0, 14.6202}, Time{3.0200}, 0.4560, 0.2440, 0.3560}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre( + ResultNear( + {{0, 0.3823}, {0, 11.3709}, Time{3.0343}, 0.4235, 0.2765, 0.3235}, + kTol), + ResultNear( + {{0, 0.4484}, {0, 4.6285}, Time{3.0486}, 0.4103, 0.2897, 0.3103}, + kTol), + ResultNear( + {{0, 0.4775}, {0, 2.0389}, Time{3.0629}, 0.4045, 0.2955, 0.3045}, + kTol), + ResultNear( + {{0, 0.4902}, {0, 0.8873}, Time{3.0771}, 0.4020, 0.2980, 0.3020}, + kTol), + ResultNear( + {{0, 0.4957}, {0, 0.3868}, Time{3.0914}, 0.4009, 0.2991, 0.3009}, + kTol), + ResultNear( + {{0, 0.4981}, {0, 0.1686}, Time{3.1057}, 0.4004, 0.2996, 0.3004}, + kTol), + ResultNear( + {{0, 0.4992}, {0, 0.0735}, Time{3.1200}, 0.4002, 0.2998, 0.3002}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {.2, 1}, + .time = time, + .pressure = .3, + .tilt = .4, + .orientation = .2}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({{0.0295, 0.4169}, + {2.9455, 19.7093}, + Time{3.0300}, + 0.4166, + 0.2834, + 0.3166}, + kTol), + ResultNear({{0.0879, 0.6439}, + {5.8481, 22.6926}, + Time{3.0400}, + 0.3691, + 0.3309, + 0.2691}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({{0.1529, 0.8487}, + {4.5484, 14.3374}, + Time{3.0543}, + 0.3293, + 0.3707, + 0.2293}, + kTol), + ResultNear({{0.1794, 0.9338}, + {1.8514, 5.9577}, + Time{3.0686}, + 0.3128, + 0.3872, + 0.2128}, + kTol), + ResultNear({{0.1910, 0.9712}, + {0.8156, 2.6159}, + Time{3.0829}, + 0.3056, + 0.3944, + 0.2056}, + kTol), + ResultNear({{0.1961, 0.9874}, + {0.3549, 1.1389}, + Time{3.0971}, + 0.3024, + 0.3976, + 0.2024}, + kTol), + ResultNear({{0.1983, 0.9945}, + {0.1547, 0.4965}, + Time{3.1114}, + 0.3011, + 0.3989, + 0.2011}, + kTol), + ResultNear({{0.1993, 0.9976}, + {0.0674, 0.2164}, + Time{3.1257}, + 0.3005, + 0.3995, + 0.2005}, + kTol), + ResultNear({{0.1997, 0.9990}, + {0.0294, 0.0943}, + Time{3.1400}, + 0.3002, + 0.3998, + 0.2002}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {.4, 1.4}, + .time = time, + .pressure = .2, + .tilt = .7, + .orientation = 0}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({{0.1668, 0.8712}, + {7.8837, 22.7349}, + Time{3.0500}, + 0.3245, + 0.3755, + 0.2245}, + kTol), + ResultNear({{0.2575, 1.0906}, + {9.0771, 21.9411}, + Time{3.0600}, + 0.2761, + 0.4716, + 0.1522}, + kTol))); + + results = modeler.Predict(); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({{0.3395, 1.2676}, + {5.7349, 12.3913}, + Time{3.0743}, + 0.2325, + 0.6024, + 0.0651}, + kTol), + ResultNear({{0.3735, 1.3421}, + {2.3831, 5.2156}, + Time{3.0886}, + 0.2142, + 0.6573, + 0.0284}, + kTol), + ResultNear({{0.3885, 1.3748}, + {1.0463, 2.2854}, + Time{3.1029}, + 0.2062, + 0.6814, + 0.0124}, + kTol), + ResultNear({{0.3950, 1.3890}, + {0.4556, 0.9954}, + Time{3.1171}, + 0.2027, + 0.6919, + 0.0054}, + kTol), + ResultNear({{0.3978, 1.3952}, + {0.1986, 0.4339}, + Time{3.1314}, + 0.2012, + 0.6965, + 0.0024}, + kTol), + ResultNear({{0.3990, 1.3979}, + {0.0866, 0.1891}, + Time{3.1457}, + 0.2005, + 0.6985, + 0.0010}, + kTol), + ResultNear({{0.3996, 1.3991}, + {0.0377, 0.0824}, + Time{3.1600}, + 0.2002, + 0.6993, + 0.0004}, + kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kUp, + .position = {.7, 1.7}, + .time = time, + .pressure = .1, + .tilt = 1, + .orientation = 0}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, ElementsAre(ResultNear({{0.3691, 1.2874}, + {11.1558, 19.6744}, + Time{3.0700}, + 0.2256, + 0.6231, + 0.0512}, + kTol), + ResultNear({{0.4978, 1.4640}, + {12.8701, 17.6629}, + Time{3.0800}, + 0.1730, + 0.7809, + 0}, + kTol), + ResultNear({{0.6141, 1.5986}, + {8.1404, 9.4261}, + Time{3.0943}, + 0.1312, + 0.9064, + 0}, + kTol), + ResultNear({{0.6624, 1.6557}, + {3.3822, 3.9953}, + Time{3.1086}, + 0.1136, + 0.9591, + 0}, + kTol), + ResultNear({{0.6836, 1.6807}, + {1.4851, 1.7488}, + Time{3.1229}, + 0.1059, + 0.9822, + 0}, + kTol), + ResultNear({{0.6929, 1.6916}, + {0.6466, 0.7618}, + Time{3.1371}, + 0.1026, + 0.9922, + 0}, + kTol), + ResultNear({{0.6969, 1.6963}, + {0.2819, 0.3321}, + Time{3.1514}, + 0.1011, + 0.9966, + 0}, + kTol), + ResultNear({{0.6986, 1.6984}, + {0.1229, 0.1447}, + Time{3.1657}, + 0.1005, + 0.9985, + 0}, + kTol), + ResultNear({{0.6994, 1.6993}, + {0.0535, 0.0631}, + Time{3.1800}, + 0.1002, + 0.9994, + 0}, + kTol))); + + EXPECT_EQ(modeler.Predict().status().code(), + absl::StatusCode::kFailedPrecondition); +} + +TEST(StrokeModelerTest, GenerateOutputOnTUpEvenIfNoTimeDelta) { + const Duration kDeltaTime{1. / 500}; + + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + Time time{0}; + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {5, 5}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre(ResultNear( + {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0)}, kTol))); + + time += kDeltaTime; + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {5, 5}, + .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, + ElementsAre(ResultNear( + {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0.002)}, + kTol))); + + results = modeler.Update( + {.event_type = Input::EventType::kUp, .position = {5, 5}, .time = time}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT( + *results, + ElementsAre(ResultNear( + {.position = {5, 5}, .velocity = {0, 0}, .time = Time(0.0076)}, + kTol))); +} + +TEST(StrokeModelerTest, RejectInputIfNegativeTimeDelta) { + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {0, 0}, + .time = Time(0)}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kMove, + .position = {1, 1}, + .time = Time(-.1)}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {1, 1}, + .time = Time(1)}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kUp, + .position = {1, 1}, + .time = Time(.9)}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +TEST(StrokeModelerTest, RejectDuplicateInput) { + StrokeModeler modeler; + ASSERT_TRUE(modeler.Reset(kDefaultParams).ok()); + + absl::StatusOr<std::vector<Result>> results = + modeler.Update({.event_type = Input::EventType::kDown, + .position = {0, 0}, + .time = Time(0), + .pressure = .2, + .tilt = .3, + .orientation = .4}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kDown, + .position = {0, 0}, + .time = Time(0), + .pressure = .2, + .tilt = .3, + .orientation = .4}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); + + results = modeler.Update({.event_type = Input::EventType::kMove, + .position = {1, 2}, + .time = Time(1), + .pressure = .1, + .tilt = .2, + .orientation = .3}); + ASSERT_TRUE(results.ok()); + EXPECT_THAT(*results, Not(IsEmpty())); + + EXPECT_EQ(modeler + .Update({.event_type = Input::EventType::kMove, + .position = {1, 2}, + .time = Time(1), + .pressure = .1, + .tilt = .2, + .orientation = .3}) + .status() + .code(), + absl::StatusCode::kInvalidArgument); +} + +} // namespace +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/types.cc b/ink_stroke_modeler/types.cc new file mode 100644 index 0000000..9864ea3 --- /dev/null +++ b/ink_stroke_modeler/types.cc @@ -0,0 +1,36 @@ +#include "ink_stroke_modeler/types.h" + +#include "absl/status/status.h" +#include "ink_stroke_modeler/internal/validation.h" + +// This convenience macro evaluates the given expression, and if it does not +// return an OK status, returns and propagates the status. +#define RETURN_IF_ERROR(expr) \ + do { \ + if (auto status = (expr); !status.ok()) return status; \ + } while (false) + +namespace ink { +namespace stroke_model { + +absl::Status ValidateInput(const Input& input) { + switch (input.event_type) { + case Input::EventType::kUp: + case Input::EventType::kMove: + case Input::EventType::kDown: + break; + default: + return absl::InvalidArgumentError("Unknown Input.event_type."); + } + RETURN_IF_ERROR(ValidateIsFiniteNumber(input.position.x, "Input.position.x")); + RETURN_IF_ERROR(ValidateIsFiniteNumber(input.position.y, "Input.position.y")); + RETURN_IF_ERROR(ValidateIsFiniteNumber(input.time.Value(), "Input.time")); + RETURN_IF_ERROR(ValidateIsFiniteNumber(input.pressure, "Input.pressure")); + RETURN_IF_ERROR(ValidateIsFiniteNumber(input.tilt, "Input.tilt")); + RETURN_IF_ERROR( + ValidateIsFiniteNumber(input.orientation, "Input.orientation")); + return absl::OkStatus(); +} + +} // namespace stroke_model +} // namespace ink diff --git a/ink_stroke_modeler/types.h b/ink_stroke_modeler/types.h new file mode 100644 index 0000000..cc77565 --- /dev/null +++ b/ink_stroke_modeler/types.h @@ -0,0 +1,356 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INK_STROKE_MODELER_TYPES_H_ +#define INK_STROKE_MODELER_TYPES_H_ + +#include <cmath> +#include <ostream> + +#include "absl/status/status.h" + +namespace ink { +namespace stroke_model { + +// A vector (or point) in 2D space. +struct Vec2 { + float x = 0; + float y = 0; + + // The length of the vector, i.e. its distance from the origin. + float Magnitude() const { return std::sqrt(x * x + y * y); } +}; + +bool operator==(Vec2 lhs, Vec2 rhs); +bool operator!=(Vec2 lhs, Vec2 rhs); + +Vec2 operator+(Vec2 lhs, Vec2 rhs); +Vec2 operator-(Vec2 lhs, Vec2 rhs); +Vec2 operator*(float scalar, Vec2 v); +Vec2 operator*(Vec2 v, float scalar); +Vec2 operator/(Vec2 v, float scalar); + +Vec2 &operator+=(Vec2 &lhs, Vec2 rhs); +Vec2 &operator-=(Vec2 &lhs, Vec2 rhs); +Vec2 &operator*=(Vec2 &lhs, float scalar); +Vec2 &operator/=(Vec2 &lhs, float scalar); + +std::ostream &operator<<(std::ostream &stream, Vec2 v); + +// This represents a duration of time, i.e. the difference between two points in +// time (as represented by class Time, below). This class is unit-agnostic; it +// could represent e.g. hours, seconds, or years. +class Duration { + public: + Duration() : Duration(0) {} + explicit Duration(double value) : value_(value) {} + double Value() const { return value_; } + + private: + double value_ = 0; +}; + +Duration operator+(Duration lhs, Duration rhs); +Duration operator-(Duration lhs, Duration rhs); +Duration operator*(Duration duration, double scalar); +Duration operator*(double scalar, Duration duration); +Duration operator/(Duration duration, double scalar); + +Duration &operator+=(Duration &lhs, Duration rhs); +Duration &operator-=(Duration &lhs, Duration rhs); +Duration &operator*=(Duration &duration, double scalar); +Duration &operator/=(Duration &duration, double scalar); + +bool operator==(Duration lhs, Duration rhs); +bool operator!=(Duration lhs, Duration rhs); +bool operator<(Duration lhs, Duration rhs); +bool operator>(Duration lhs, Duration rhs); +bool operator<=(Duration lhs, Duration rhs); +bool operator>=(Duration lhs, Duration rhs); + +std::ostream &operator<<(std::ostream &s, Duration duration); + +// This represents a point in time. This class is unit- and offset-agnostic; it +// could be measured in e.g. hours, seconds, or years, and Time(0) has no +// specific meaning outside of the context it is used in. +class Time { + public: + Time() : Time(0) {} + explicit Time(double value) : value_(value) {} + double Value() const { return value_; } + + private: + double value_ = 0; +}; + +Time operator+(Time time, Duration duration); +Time operator+(Duration duration, Time time); +Time operator-(Time time, Duration duration); +Duration operator-(Time lhs, Time rhs); + +Time &operator+=(Time &time, Duration duration); +Time &operator-=(Time &time, Duration duration); + +bool operator==(Time lhs, Time rhs); +bool operator!=(Time lhs, Time rhs); +bool operator<(Time lhs, Time rhs); +bool operator>(Time lhs, Time rhs); +bool operator<=(Time lhs, Time rhs); +bool operator>=(Time lhs, Time rhs); + +std::ostream &operator<<(std::ostream &s, Time time); + +// The input passed to the stroke modeler. +struct Input { + // The type of event represented by the input. A "kDown" event represents + // the beginning of the stroke, a "kUp" event represents the end of the + // stroke, and all events in between are "kMove" events. + enum class EventType { kDown, kMove, kUp }; + EventType event_type; + + // The position of the input. + Vec2 position{0}; + + // The time at which the input occurs. + Time time{0}; + + // The amount of pressure applied to the stylus. This is expected to lie in + // the range [0, 1]. A negative value indicates unknown pressure. + float pressure = -1; + + // The angle between the stylus and the plane of the device's screen. This + // is expected to lie in the range [0, π/2]. A value of 0 indicates that the + // stylus is perpendicular to the screen, while a value of π/2 indicates + // that it is flush with the screen. A negative value indicates unknown + // tilt. + float tilt = -1; + + // The angle between the projection of the stylus onto the screen and the + // positive x-axis, measured counter-clockwise. This is expected to lie in + // the range [0, 2π). A negative value indicates unknown orientation. + float orientation = -1; +}; + +bool operator==(const Input &lhs, const Input &rhs); +bool operator!=(const Input &lhs, const Input &rhs); + +std::ostream &operator<<(std::ostream &s, Input::EventType event_type); +std::ostream &operator<<(std::ostream &s, const Input &input); + +absl::Status ValidateInput(const Input &input); + +// A modeled input produced by the stroke modeler. +struct Result { + // The position and velocity of the stroke tip. + Vec2 position{0}; + Vec2 velocity{0}; + + // The time at which this input occurs. + Time time{0}; + + // These pressure, tilt, and orientation of the stylus. See the + // corresponding fields on the Input struct for more info. + float pressure = -1; + float tilt = -1; + float orientation = -1; +}; + +std::ostream &operator<<(std::ostream &s, const Result &result); + +//////////////////////////////////////////////////////////////////////////////// +// Inline function definitions +//////////////////////////////////////////////////////////////////////////////// + +inline bool operator==(Vec2 lhs, Vec2 rhs) { + return lhs.x == rhs.x && lhs.y == rhs.y; +} +inline bool operator!=(Vec2 lhs, Vec2 rhs) { return !(lhs == rhs); } + +inline Vec2 operator+(Vec2 lhs, Vec2 rhs) { + return {.x = lhs.x + rhs.x, .y = lhs.y + rhs.y}; +} +inline Vec2 operator-(Vec2 lhs, Vec2 rhs) { + return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y}; +} +inline Vec2 operator*(float scalar, Vec2 v) { + return {.x = scalar * v.x, .y = scalar * v.y}; +} +inline Vec2 operator*(Vec2 v, float scalar) { return scalar * v; } +inline Vec2 operator/(Vec2 v, float scalar) { + return {.x = v.x / scalar, .y = v.y / scalar}; +} + +inline Vec2 &operator+=(Vec2 &lhs, Vec2 rhs) { + lhs.x += rhs.x; + lhs.y += rhs.y; + return lhs; +} +inline Vec2 &operator-=(Vec2 &lhs, Vec2 rhs) { + lhs.x -= rhs.x; + lhs.y -= rhs.y; + return lhs; +} +inline Vec2 &operator*=(Vec2 &lhs, float scalar) { + lhs.x *= scalar; + lhs.y *= scalar; + return lhs; +} +inline Vec2 &operator/=(Vec2 &lhs, float scalar) { + lhs.x /= scalar; + lhs.y /= scalar; + return lhs; +} + +inline std::ostream &operator<<(std::ostream &stream, Vec2 v) { + return stream << "(" << v.x << ", " << v.y << ")"; +} + +inline Duration operator+(Duration lhs, Duration rhs) { + return Duration(lhs.Value() + rhs.Value()); +} +inline Duration operator-(Duration lhs, Duration rhs) { + return Duration(lhs.Value() - rhs.Value()); +} +inline Duration operator*(Duration duration, double scalar) { + return Duration(duration.Value() * scalar); +} +inline Duration operator*(double scalar, Duration duration) { + return Duration(scalar * duration.Value()); +} +inline Duration operator/(Duration duration, double scalar) { + return Duration(duration.Value() / scalar); +} + +inline Duration &operator+=(Duration &lhs, Duration rhs) { + lhs = lhs + rhs; + return lhs; +} +inline Duration &operator-=(Duration &lhs, Duration rhs) { + lhs = lhs - rhs; + return lhs; +} +inline Duration &operator*=(Duration &duration, double scalar) { + duration = duration * scalar; + return duration; +} +inline Duration &operator/=(Duration &duration, double scalar) { + duration = duration / scalar; + return duration; +} + +inline bool operator==(Duration lhs, Duration rhs) { + return lhs.Value() == rhs.Value(); +} +inline bool operator!=(Duration lhs, Duration rhs) { + return lhs.Value() != rhs.Value(); +} +inline bool operator<(Duration lhs, Duration rhs) { + return lhs.Value() < rhs.Value(); +} +inline bool operator>(Duration lhs, Duration rhs) { + return lhs.Value() > rhs.Value(); +} +inline bool operator<=(Duration lhs, Duration rhs) { + return lhs.Value() <= rhs.Value(); +} +inline bool operator>=(Duration lhs, Duration rhs) { + return lhs.Value() >= rhs.Value(); +} + +inline Time operator+(Time time, Duration duration) { + return Time(time.Value() + duration.Value()); +} +inline Time operator+(Duration duration, Time time) { + return Time(time.Value() + duration.Value()); +} +inline Time operator-(Time time, Duration duration) { + return Time(time.Value() - duration.Value()); +} +inline Duration operator-(Time lhs, Time rhs) { + return Duration(lhs.Value() - rhs.Value()); +} + +inline Time &operator+=(Time &time, Duration duration) { + time = time + duration; + return time; +} +inline Time &operator-=(Time &time, Duration duration) { + time = time - duration; + return time; +} + +inline bool operator==(Time lhs, Time rhs) { + return lhs.Value() == rhs.Value(); +} +inline bool operator!=(Time lhs, Time rhs) { + return lhs.Value() != rhs.Value(); +} +inline bool operator<(Time lhs, Time rhs) { return lhs.Value() < rhs.Value(); } +inline bool operator>(Time lhs, Time rhs) { return lhs.Value() > rhs.Value(); } +inline bool operator<=(Time lhs, Time rhs) { + return lhs.Value() <= rhs.Value(); +} +inline bool operator>=(Time lhs, Time rhs) { + return lhs.Value() >= rhs.Value(); +} + +inline bool operator==(const Input &lhs, const Input &rhs) { + return lhs.event_type == rhs.event_type && lhs.position == rhs.position && + lhs.time == rhs.time && lhs.pressure == rhs.pressure && + lhs.tilt == rhs.tilt && lhs.orientation == rhs.orientation; +} +inline bool operator!=(const Input &lhs, const Input &rhs) { + return !(lhs == rhs); +} + +inline std::ostream &operator<<(std::ostream &s, Duration duration) { + return s << duration.Value(); +} + +inline std::ostream &operator<<(std::ostream &s, Time time) { + return s << time.Value(); +} + +inline std::ostream &operator<<(std::ostream &s, Input::EventType event_type) { + switch (event_type) { + case Input::EventType::kDown: + return s << "Down"; + case Input::EventType::kMove: + return s << "Move"; + case Input::EventType::kUp: + return s << "Up"; + } + return s << "UnknownEventType<" << event_type << ">"; +} + +inline std::ostream &operator<<(std::ostream &s, const Input &input) { + return s << "<Input: " << input.event_type << ", pos: " << input.position + << ", time: " << input.time << ", pressure: " << input.pressure + << ", tilt: " << input.tilt << ", orientation: " << input.orientation + << ">"; +} + +inline std::ostream &operator<<(std::ostream &s, const Result &result) { + return s << "<Result: pos: " << result.position + << ", velocity: " << result.velocity << ", time: " << result.time + << ", pressure: " << result.pressure << ", tilt: " << result.tilt + << ", orientation: " << result.orientation << ">"; +} + +} // namespace stroke_model +} // namespace ink + +#endif // INK_STROKE_MODELER_TYPES_H_ diff --git a/ink_stroke_modeler/types_test.cc b/ink_stroke_modeler/types_test.cc new file mode 100644 index 0000000..e812d76 --- /dev/null +++ b/ink_stroke_modeler/types_test.cc @@ -0,0 +1,212 @@ +// Copyright 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ink_stroke_modeler/types.h" + +#include <cmath> +#include <limits> +#include <sstream> +#include <string> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "ink_stroke_modeler/internal/type_matchers.h" + +namespace ink { +namespace stroke_model { +namespace { + +using ::testing::Not; + +TEST(TypesTest, Vec2Equality) { + EXPECT_EQ((Vec2{1, 2}), (Vec2{1, 2})); + EXPECT_EQ((Vec2{-.4, 17}), (Vec2{-.4, 17})); + + EXPECT_NE((Vec2{1, 2}), (Vec2{1, 5})); + EXPECT_NE((Vec2{3, 2}), (Vec2{.7, 2})); + EXPECT_NE((Vec2{-4, .3}), (Vec2{.5, 12})); +} + +TEST(TypesTest, Vec2EqMatcher) { + EXPECT_THAT((Vec2{1, 2}), Vec2Eq({1, 2})); + EXPECT_THAT((Vec2{3, 4}), Not(Vec2Eq({3, 5}))); + EXPECT_THAT((Vec2{5, 6}), Not(Vec2Eq({4, 6}))); + + // Vec2Eq delegates to FloatEq, which uses a tolerance of 4 ULP. + constexpr float kEps = std::numeric_limits<float>::epsilon(); + EXPECT_THAT((Vec2{1, 1}), Vec2Eq({1 + kEps, 1 - kEps})); +} + +TEST(TypesTest, Vec2NearMatcher) { + EXPECT_THAT((Vec2{1, 2}), Vec2Near({1.05, 1.95}, .1)); + EXPECT_THAT((Vec2{3, 4}), Not(Vec2Near({3, 5}, .5))); + EXPECT_THAT((Vec2{5, 6}), Not(Vec2Near({4, 6}, .5))); +} + +TEST(TypesTest, Vec2Magnitude) { + EXPECT_FLOAT_EQ((Vec2{1, 1}).Magnitude(), std::sqrt(2)); + EXPECT_FLOAT_EQ((Vec2{-3, 4}).Magnitude(), 5); + EXPECT_FLOAT_EQ((Vec2{0, 0}).Magnitude(), 0); + EXPECT_FLOAT_EQ((Vec2{0, 17}).Magnitude(), 17); +} + +TEST(TypesTest, Vec2Addition) { + Vec2 a{3, 0}; + Vec2 b{-1, .3}; + Vec2 c{2.7, 4}; + + EXPECT_THAT(a + b, Vec2Eq({2, .3})); + EXPECT_THAT(a + c, Vec2Eq({5.7, 4})); + EXPECT_THAT(b + c, Vec2Eq({1.7, 4.3})); +} + +TEST(TypesTest, Vec2Subtraction) { + Vec2 a{0, -2}; + Vec2 b{.5, 19}; + Vec2 c{1.1, -3.4}; + + EXPECT_THAT(a - b, Vec2Eq({-.5, -21})); + EXPECT_THAT(a - c, Vec2Eq({-1.1, 1.4})); + EXPECT_THAT(b - c, Vec2Eq({-.6, 22.4})); +} + +TEST(TypesTest, Vec2ScalarMultiplication) { + Vec2 a{.7, -3}; + Vec2 b{3, 5}; + + EXPECT_THAT(a * 2, Vec2Eq({1.4, -6})); + EXPECT_THAT(.1 * a, Vec2Eq({.07, -.3})); + EXPECT_THAT(b * -.3, Vec2Eq({-.9, -1.5})); + EXPECT_THAT(4 * b, Vec2({12, 20})); +} + +TEST(TypesTest, Vec2ScalarDivision) { + Vec2 a{7, .9}; + Vec2 b{-4.5, -2}; + + EXPECT_THAT(a / 2, Vec2Eq({3.5, .45})); + EXPECT_THAT(a / -.1, Vec2Eq({-70, -9})); + EXPECT_THAT(b / 5, Vec2Eq({-.9, -.4})); + EXPECT_THAT(b / .2, Vec2Eq({-22.5, -10})); +} + +TEST(TypesTest, Vec2AddAssign) { + Vec2 a{1, 2}; + a += {.x = 3, .y = -1}; + EXPECT_THAT(a, Vec2Eq({4, 1})); + a += {.x = -.5, .y = 2}; + EXPECT_THAT(a, Vec2Eq({3.5, 3})); +} + +TEST(TypesTest, Vec2SubractAssign) { + Vec2 a{1, 2}; + a -= {.x = 3, .y = -1}; + EXPECT_THAT(a, Vec2Eq({-2, 3})); + a -= {.x = -.5, .y = 2}; + EXPECT_THAT(a, Vec2Eq({-1.5, 1})); +} + +TEST(TypesTest, Vec2ScalarMultiplyAssign) { + Vec2 a{1, 2}; + a *= 2; + EXPECT_THAT(a, Vec2Eq({2, 4})); + a *= -.4; + EXPECT_THAT(a, Vec2Eq({-.8, -1.6})); +} + +TEST(TypesTest, Vec2ScalarDivideAssign) { + Vec2 a{1, 2}; + a /= 2; + EXPECT_THAT(a, Vec2Eq({.5, 1})); + a /= -.4; + EXPECT_THAT(a, Vec2Eq({-1.25, -2.5})); +} + +TEST(TypesTest, Vec2Stream) { + std::stringstream s; + s << Vec2{3.5, -2.7}; + EXPECT_EQ(s.str(), "(3.5, -2.7)"); +} + +TEST(TypesTest, DurationArithmetic) { + EXPECT_EQ(Duration(1) + Duration(2), Duration(3)); + EXPECT_EQ(Duration(6) - Duration(1), Duration(5)); + EXPECT_EQ(Duration(3) * 2, Duration(6)); + EXPECT_EQ(.5 * Duration(7), Duration(3.5)); + EXPECT_EQ(Duration(12) / 4, Duration(3)); +} + +TEST(TypesTest, DurationArithmeticAssignment) { + Duration d{5}; + d += Duration(2); + EXPECT_EQ(d, Duration(7)); + d -= Duration(3); + EXPECT_EQ(d, Duration(4)); + d *= 5; + EXPECT_EQ(d, Duration(20)); + d /= 2; + EXPECT_EQ(d, Duration(10)); +} + +TEST(TypesTest, DurationComparison) { + EXPECT_EQ(Duration(0), Duration(0)); + EXPECT_NE(Duration(0), Duration(1)); + EXPECT_LT(Duration(1), Duration(2)); + EXPECT_LE(Duration(4), Duration(5)); + EXPECT_LE(Duration(2), Duration(2)); + EXPECT_GT(Duration(10), Duration(9)); + EXPECT_GE(Duration(7), Duration(5)); + EXPECT_GE(Duration(5), Duration(5)); +} + +TEST(TypesTest, TimeArithmetic) { + EXPECT_EQ(Time(5) + Duration(1), Time(6)); + EXPECT_EQ(Duration(7) + Time(12), Time(19)); + EXPECT_EQ(Time(23) - Duration(5), Time(18)); + EXPECT_EQ(Time(35) - Time(7), Duration(28)); +} + +TEST(TypesTest, TimeArithmeticAssignment) { + Time t{20}; + t += Duration(10); + EXPECT_EQ(t, Time(30)); + t -= Duration(24); + EXPECT_EQ(t, Time(6)); +} + +TEST(TypesTest, TimeComparison) { + EXPECT_EQ(Time(0), Time(0)); + EXPECT_NE(Time(0), Time(1)); + EXPECT_LT(Time(1), Time(2)); + EXPECT_LE(Time(4), Time(5)); + EXPECT_LE(Time(2), Time(2)); + EXPECT_GT(Time(10), Time(9)); + EXPECT_GE(Time(7), Time(5)); + EXPECT_GE(Time(5), Time(5)); +} + +TEST(TypesTest, InputEquality) { + const Input kBaseline{}; + EXPECT_EQ(kBaseline, Input()); + EXPECT_NE(kBaseline, Input{.event_type = Input::EventType::kMove}); + EXPECT_NE(kBaseline, (Input{.position = {1, -1}})); + EXPECT_NE(kBaseline, Input{.time = Time(1)}); + EXPECT_NE(kBaseline, Input{.pressure = .5}); + EXPECT_NE(kBaseline, Input{.tilt = .2}); + EXPECT_NE(kBaseline, Input{.orientation = .7}); +} + +} // namespace +} // namespace stroke_model +} // namespace ink |