1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ink_stroke_modeler/internal/prediction/kalman_filter/kalman_filter.h"
#include "ink_stroke_modeler/internal/prediction/kalman_filter/matrix.h"
namespace ink {
namespace stroke_model {
KalmanFilter::KalmanFilter(const Matrix4& state_transition,
const Matrix4& process_noise_covariance,
const Vec4& measurement_vector,
double measurement_noise_variance,
int min_stable_iteration)
: state_transition_matrix_(state_transition),
process_noise_covariance_matrix_(process_noise_covariance),
measurement_vector_(measurement_vector),
measurement_noise_variance_(measurement_noise_variance),
min_stable_iteration_(min_stable_iteration),
iter_num_(0) {}
void KalmanFilter::Predict() {
// X = F * X
state_estimation_ = state_transition_matrix_ * state_estimation_;
// P = F * P * F' + Q
error_covariance_matrix_ = state_transition_matrix_ *
error_covariance_matrix_ *
state_transition_matrix_.Transpose() +
process_noise_covariance_matrix_;
}
void KalmanFilter::Update(double observation) {
if (iter_num_++ == 0) {
// We only update the state estimation in the first iteration.
state_estimation_[0] = observation;
return;
}
Predict();
// Y = z - H * X
double y = observation - DotProduct(measurement_vector_, state_estimation_);
// S = H * P * H' + R
double S = DotProduct(measurement_vector_ * error_covariance_matrix_,
measurement_vector_) +
measurement_noise_variance_;
// K = P * H' * inv(S)
Vec4 kalman_gain = measurement_vector_ * error_covariance_matrix_ / S;
// X = X + K * Y
state_estimation_ = state_estimation_ + kalman_gain * y;
// I_HK = eye(P) - K * H
Matrix4 I_KH = Matrix4() - OuterProduct(kalman_gain, measurement_vector_);
// P = I_KH * P * I_KH' + K * R * K'
error_covariance_matrix_ =
I_KH * error_covariance_matrix_ * I_KH.Transpose() +
OuterProduct(kalman_gain, kalman_gain) * measurement_noise_variance_;
}
void KalmanFilter::Reset() {
state_estimation_ = {0, 0, 0, 0};
error_covariance_matrix_ = Matrix4(); // identity
iter_num_ = 0;
}
} // namespace stroke_model
} // namespace ink
|