forked from vancegroup/eigen-kalman
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathkfExampleTune.cpp
157 lines (123 loc) · 4.01 KB
/
kfExampleTune.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
/**
@file kfExampleTune.cpp
@brief Grid-search tuning of the measurement and process variances for the example Kalman filter.
@date 2010
@author
Ryan Pavlik
<rpavlik@iastate.edu> and <abiryan@ryand.net>
http://academic.cleardefinition.com/
Iowa State University Virtual Reality Applications Center
Human-Computer Interaction Graduate Program
*/
// Internal Includes
#include "generateData.h"
// Library/third-party includes
#include <eigenkf/KalmanFilter.h>
#include <Eigen/Eigen>
// Standard includes
#include <algorithm>
#include <cmath>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
/// Lets the code below name members of the eigenkf namespace without qualification.
using namespace eigenkf;
/// NOTE(review): COL is not referenced anywhere in the visible part of this file - confirm before relying on it.
#define COL 10
/// When defined, runWindow() separates values with a space, emits no row or
/// column headers, and sends only the first grid it computes to std::cout.
/// When undefined, values are comma separated with headers.
#define USE_GNUPLOT
/// Divisor used by runWindow() to turn each variance range into a grid step,
/// i.e. the number of steps per axis.
const double INTERVALS = 20;
/// Time step passed to the filter's predict() call for every sample.
const double dt = 0.5;
double runSimulation(std::vector<StatePair> const& data, const double measurementVariance, const double processModelVariance) {
/// We want a simple 2d state
typedef SimpleState<2> state_t;
/// Our process model is the simplest possible: doesn't change the mean
typedef ConstantProcess<2, state_t> process_t;
/// Create a kalman filter instance with our chosen state and process types
KalmanFilter<state_t, process_t> kf;
/// Set our process model's variance
kf.processModel.sigma = state_t::VecState::Constant(processModelVariance);
double sumSquaredError = 0;
for (unsigned int i = 0; i < data.size(); ++i) {
/// Predict step: Update Kalman filter by predicting ahead by dt
kf.predict(dt);
/// "take a measurement" - in this case, noisify the actual measurement
AbsoluteMeasurement<state_t> meas;
meas.measurement = data[i].second;
meas.covariance = Eigen::Vector2d::Constant(measurementVariance).asDiagonal();
/// Correct step: incorporate information from measurement into KF's state
kf.correct(meas);
Eigen::Vector2d pos(data[i].first);
double squaredError = (pos[0] - kf.state.x[0]) * (pos[0] - kf.state.x[0]);
sumSquaredError += squaredError;
}
return sumSquaredError;
}
/**
	Grid-search one window of (measurement variance, process variance)
	pairs and optionally zoom in around the best pair found.

	Each axis is sampled at INTERVALS evenly spaced points starting at its
	low bound; the high bound itself is excluded. Sample positions are
	computed from an integer index, so every axis always yields exactly
	INTERVALS samples regardless of floating-point rounding. An axis whose
	low bound is not below its high bound yields no samples.

	runSimulation() is called for every pair and its result is written to
	the grid output; the pair with the smallest result is reported on
	std::cerr.

	Grid output destination:
	- With USE_GNUPLOT defined, only the first window evaluated by the
	  program goes to std::cout (space separated, no headers).
	- Otherwise, only a window evaluated with recursionsRemaining == 0 goes
	  to std::cout (comma separated, with row and column headers).
	- Every other grid is written to a scratch stream and discarded.

	@param data      Data set handed unchanged to runSimulation().
	@param lowMVar   Lowest measurement variance to try (inclusive).
	@param highMVar  Upper end of the measurement variance range (exclusive).
	@param lowPVar   Lowest process variance to try (inclusive).
	@param highPVar  Upper end of the process variance range (exclusive).
	@param recursionsRemaining
		Number of further, narrower windows to search. Each one spans one
		grid step either side of the best pair, with the lower bounds
		clamped at zero because a variance cannot be negative.
*/
void runWindow(std::vector<StatePair> const& data, double lowMVar, double highMVar, double lowPVar, double highPVar, int recursionsRemaining = 0) {
	/// Scratch sink for grids that should not reach stdout.
	std::stringstream ss;
#ifdef USE_GNUPLOT
	static bool doneOutput = false;
	std::ostream & output(doneOutput ? ss : std::cout);
	std::string separator(" ");
	doneOutput = true;
#else
	std::ostream & output( (recursionsRemaining == 0) ? std::cout : ss);
	std::string separator(",");
#endif

	const double dMVar = (highMVar - lowMVar) / INTERVALS;
	const double dPVar = (highPVar - lowPVar) / INTERVALS;

	/// Integer step counts: accumulating a double step can produce one
	/// sample too many (or never terminate if the step is too small to
	/// change the running value).
	const int gridSteps = static_cast<int>(INTERVALS);
	const int mSteps = (lowMVar < highMVar) ? gridSteps : 0;
	const int pSteps = (lowPVar < highPVar) ? gridSteps : 0;

	/// Start from +infinity so the first comparable error always becomes
	/// the current best, however large it is.
	double minErr = std::numeric_limits<double>::infinity();
	double bestMVar = lowMVar;
	double bestPVar = lowPVar;

#ifndef USE_GNUPLOT
	/// Column headers
	output << separator;
	for (int col = 0; col < pSteps; ++col) {
		output << (lowPVar + col * dPVar) << separator;
	}
	output << "process variance" << '\n';
#endif

	for (int row = 0; row < mSteps; ++row) {
		const double mVar = lowMVar + row * dMVar;
#ifndef USE_GNUPLOT
		/// Row header
		output << mVar;
#endif
		for (int col = 0; col < pSteps; ++col) {
			const double pVar = lowPVar + col * dPVar;
			const double err = runSimulation(data, mVar, pVar);
			output << separator << err;
			/// A NaN error compares false here and is therefore skipped.
			if (err < minErr) {
				bestMVar = mVar;
				bestPVar = pVar;
				minErr = err;
			}
		}
		output << '\n';
	}
	/// Single flush once the grid is complete, so it still reaches stdout
	/// before the diagnostics below are written to stderr.
	output << "measurement variance" << std::endl;

	std::cerr << std::endl;
	std::cerr << "Best found in the grid of parameters: " << std::endl;
	std::cerr << "Measurement Variance: " << bestMVar << std::endl;
	std::cerr << "Process variance: " << bestPVar << std::endl;
	std::cerr << "Sum squared error: " << minErr << std::endl;
	std::cerr << std::endl;

	if (recursionsRemaining > 0) {
		std::cerr << "Recursing..." << std::endl << std::endl;
		/// Zoom in one grid step either side of the best pair, never
		/// letting the lower bounds drop below zero.
		runWindow(data,
			std::max(0.0, bestMVar - dMVar),
			bestMVar + dMVar,
			std::max(0.0, bestPVar - dPVar),
			bestPVar + dPVar,
			recursionsRemaining - 1);
	}
}
/**
	Entry point: generate the line data set and run the variance grid
	search over the initial window, allowing three rounds of refinement.

	@return 0 always.
*/
int main(int /*argc*/, char * /*argv*/[]) {
	const std::vector<StatePair> data = generateLineData();

	/// Initial search window for the measurement variance...
	const double lowMVar = 0;
	const double highMVar = 9.0;
	/// ...and for the process variance.
	const double lowPVar = 0;
	const double highPVar = 11;

	/// Number of times runWindow() narrows the window around its best pair.
	const int recursions = 3;

	/// No header line is printed here; runWindow() produces all output.
	runWindow(data, lowMVar, highMVar, lowPVar, highPVar, recursions);
	return 0;
}