aboutsummaryrefslogtreecommitdiff
path: root/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc')
-rw-r--r--ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc215
1 files changed, 215 insertions, 0 deletions
diff --git a/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc
new file mode 100644
index 0000000..14babc7
--- /dev/null
+++ b/ink_stroke_modeler/internal/prediction/kalman_predictor_test.cc
@@ -0,0 +1,215 @@
+// Copyright 2022 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ink_stroke_modeler/internal/prediction/kalman_predictor.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "absl/types/optional.h"
+#include "ink_stroke_modeler/internal/prediction/input_predictor.h"
+#include "ink_stroke_modeler/internal/type_matchers.h"
+#include "ink_stroke_modeler/params.h"
+#include "ink_stroke_modeler/types.h"
+
+namespace ink {
+namespace stroke_model {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Matcher;
+using ::testing::Optional;
+
+constexpr float kTol = 1e-4;
+
+const KalmanPredictorParams kDefaultKalmanParams{
+ .process_noise = .00026458,
+ .measurement_noise = .026458,
+ .min_catchup_velocity = .01,
+ .prediction_interval{1. / 60},
+ .confidence_params{.max_estimation_distance = .04,
+ .min_travel_speed = 3,
+ .max_travel_speed = 15,
+ .max_linear_deviation = .2}};
+constexpr SamplingParams kDefaultSamplingParams{
+ .min_output_rate = 180, .end_of_stroke_stopping_distance = .001};
+
+// Matcher for the State. Note that, because each of position,
+// velocity, acceleration, and jerk are divided by increasing powers of the time
+// delta, the values grow exponentially. As such, this uses a relative tolerance
+// unless one of the arguments is exactly zero.
+MATCHER_P5(StateNearMatcher, position, velocity, acceleration, jerk,
+ relative_tol, "") {
+ auto within_relative_tol = [](float lhs, float rhs, float tol) {
+ if (lhs == 0 || rhs == 0) {
+ return std::abs(lhs - rhs) < tol;
+ }
+ return std::abs(lhs / rhs - 1.f) < tol;
+ };
+
+ if (within_relative_tol(position.x, arg.position.x, relative_tol) &&
+ within_relative_tol(position.y, arg.position.y, relative_tol) &&
+ within_relative_tol(velocity.x, arg.velocity.x, relative_tol) &&
+ within_relative_tol(velocity.y, arg.velocity.y, relative_tol) &&
+ within_relative_tol(acceleration.x, arg.acceleration.x, relative_tol) &&
+ within_relative_tol(acceleration.y, arg.acceleration.y, relative_tol) &&
+ within_relative_tol(jerk.x, arg.jerk.x, relative_tol) &&
+ within_relative_tol(jerk.y, arg.jerk.y, relative_tol)) {
+ return true;
+ }
+
+ *result_listener << "\n expected:" //
+ << "\n p = " << position //
+ << "\n v = " << velocity //
+ << "\n a = " << acceleration //
+ << "\n j = " << jerk //
+ << "\n actual:" //
+ << "\n p = " << arg.position //
+ << "\n v = " << arg.velocity //
+ << "\n a = " << arg.acceleration //
+ << "\n j = " << arg.jerk;
+ return false;
+}
+
+// Wrapping the matcher in a function allows the compiler to perform template
+// deduction, so we can specify arguments as initializer lists.
+Matcher<KalmanPredictor::State> StateNear(Vec2 position, Vec2 velocity,
+ Vec2 acceleration, Vec2 jerk,
+ float tolerance) {
+ return StateNearMatcher(position, velocity, acceleration, jerk, tolerance);
+}
+
+TEST(KalmanPredictorTest, EmptyPrediction) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty());
+
+ predictor.Update({1, 3}, Time{4});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 3}, {0, 0}, Time{3.1}}).empty());
+}
+
+TEST(KalmanPredictorTest, TypicalCase) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+
+ predictor.Update({0, 0}, Time{0});
+ predictor.Update({.1, 0}, Time{.01});
+ predictor.Update({.2, 0}, Time{.02});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{4, 3}, {2, -4}, Time{3}}).empty());
+
+ predictor.Update({.3, 0}, Time{.03});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({.30078, 0}, {13.584, 0}, {-66.806, 0},
+ {-3382.8, 0}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{.2, 0}, {10, 0}, Time{.03}}),
+ ElementsAre(TipStateNear({{.2454, 0}, {7.7094, 0}, Time{.0356}}, kTol),
+ TipStateNear({{.3008, 0}, {13.5837, 0}, Time{.0411}}, kTol),
+ TipStateNear({{.3751, 0}, {13.1604, 0}, Time{.0467}}, kTol)));
+
+ predictor.Update({.5, .1}, Time{.04});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({.49705, .097146}, {28.217, 16.732},
+ {671.91, 813.82}, {4454.3, 6998.2}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{.3, 0}, {10, 0}, Time{.04}}),
+ ElementsAre(
+ TipStateNear({{.3732, .0253}, {17.047, 8.9317}, Time{.0456}}, kTol),
+ TipStateNear({{.497, .0971}, {28.2172, 16.7319}, Time{.0511}}, kTol),
+ TipStateNear({{.6643, .2029}, {32.0188, 21.3611}, Time{.0567}},
+ kTol)));
+}
+
+TEST(KalmanPredictorTest, AlternateParams) {
+ auto kalman_params = kDefaultKalmanParams;
+ auto sampling_params = kDefaultSamplingParams;
+ kalman_params.prediction_interval = Duration(.025);
+ sampling_params.min_output_rate = 200;
+ KalmanPredictor predictor{kalman_params, sampling_params};
+
+ predictor.Update({2, 5}, Time{1});
+ predictor.Update({2.2, 4.9}, Time{1.02});
+ predictor.Update({2.3, 4.7}, Time{1.04});
+ predictor.Update({2.3, 4.4}, Time{1.06});
+ EXPECT_THAT(
+ predictor.GetEstimatedState(),
+ Optional(StateNear({2.3016, 4.3992}, {-3.9981, -24.374},
+ {-338.22, -288.12}, {-1852.9, -584.31}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{2.25, 4.75}, {1, -20}, Time{1.06}}),
+ ElementsAre(
+ TipStateNear({{2.27, 4.6417}, {5.917, -23.0547}, Time{1.065}}, kTol),
+ TipStateNear({{2.2982, 4.5221}, {4.251, -24.5126}, Time{1.07}}, kTol),
+ TipStateNear({{2.3016, 4.3992}, {-3.9981, -24.3736}, Time{1.075}},
+ kTol),
+ TipStateNear({{2.2773, 4.2738}, {-5.7123, -25.8215}, Time{1.08}},
+ kTol)));
+
+ predictor.Update({2.2, 4.2}, Time{1.08});
+ EXPECT_THAT(predictor.GetEstimatedState(),
+ Optional(StateNear({2.1987, 4.1933}, {-11.457, -11.953},
+ {-328.01, 185.32}, {-1133.8, 1569.8}, kTol)));
+ EXPECT_THAT(
+ predictor.ConstructPrediction({{2.25, 4.5}, {-1, -20}, Time{1.08}}),
+ ElementsAre(
+ TipStateNear({{2.2499, 4.407}, {.5082, -17.2661}, Time{1.085}}, kTol),
+ TipStateNear({{2.2505, 4.3265}, {-.7319, -15.0137}, Time{1.09}},
+ kTol),
+ TipStateNear({{2.238, 4.2561}, {-4.7203, -13.2427}, Time{1.095}},
+ kTol),
+ TipStateNear({{2.1987, 4.1933}, {-11.4569, -11.9531}, Time{1.1}},
+ kTol),
+ TipStateNear({{2.1373, 4.1359}, {-13.1112, -11.0068}, Time{1.105}},
+ kTol)));
+}
+
+TEST(KalmanPredictorTest, Reset) {
+ KalmanPredictor predictor{kDefaultKalmanParams, kDefaultSamplingParams};
+
+ predictor.Update({4, -4}, Time{6});
+ predictor.Update({-6, 9}, Time{6.03});
+ predictor.Update({10, 5}, Time{6.06});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty());
+
+ predictor.Update({2, 4}, Time{6.09});
+ EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_FALSE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.06}}).empty());
+
+ predictor.Reset();
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{6.09}}).empty());
+
+ predictor.Update({-9, 3}, Time{2});
+ predictor.Update({-6, -1}, Time{2.1});
+ predictor.Update({6, -6}, Time{2.2});
+ EXPECT_EQ(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_TRUE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.2}}).empty());
+
+ predictor.Update({3, 6}, Time{2.3});
+ EXPECT_NE(predictor.GetEstimatedState(), absl::nullopt);
+ EXPECT_FALSE(
+ predictor.ConstructPrediction({{1, 1}, {6, -3}, Time{2.3}}).empty());
+}
+
+} // namespace
+} // namespace stroke_model
+} // namespace ink