Skip to content

Commit 4fd5a10

Browse files
authored
Merge pull request #55 from akleeman/ransac_testing
Add additional tests of ransac with groups.
2 parents c01c727 + 909fa5c commit 4fd5a10

File tree

5 files changed

+55
-25
lines changed

5 files changed

+55
-25
lines changed

albatross/evaluate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ negative_log_likelihood(const Eigen::VectorXd &deviation,
5656
const double rank = static_cast<double>(diag.size());
5757
const double mahalanobis = deviation.dot(ldlt.solve(deviation));
5858
const double log_det = log_sum(diag);
59-
return -0.5 * (log_det + mahalanobis + rank * log(2 * M_PI));
59+
return 0.5 * (log_det + mahalanobis + rank * log(2 * M_PI));
6060
}
6161

6262
/*

albatross/models/ransac.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,9 @@ ransac(const typename RansacFunctions<FitType>::Fitter &fitter,
101101
inliers.push_back(test_ind);
102102
}
103103
}
104-
105104
// If there is enough agreement, consider this random set of inliers
106105
// as a candidate model.
107-
if (inliers.size() > min_inliers) {
106+
if (inliers.size() + random_sample_size >= min_inliers) {
108107
const auto inlier_inds = concatenate_subset_of_groups(inliers, indexer);
109108
ref_inds.insert(ref_inds.end(), inlier_inds.begin(), inlier_inds.end());
110109
std::sort(ref_inds.begin(), ref_inds.end());

tests/test_evaluate.cc

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,17 @@ TEST(test_evaluate, test_negative_log_likelihood) {
4444
cov << 1., 0.9, 0.8, 0.9, 1., 0.9, 0.8, 0.9, 1.;
4545

4646
const auto nll = albatross::negative_log_likelihood(x, cov);
47-
EXPECT_NEAR(nll, -6.0946974293510134, 1e-6);
47+
EXPECT_NEAR(nll, 6.0946974293510134, 1e-6);
4848

4949
const auto ldlt_nll = albatross::negative_log_likelihood(x, cov.ldlt());
5050
EXPECT_NEAR(nll, ldlt_nll, 1e-6);
51+
52+
const DiagonalMatrixXd diagonal_matrix = cov.diagonal().asDiagonal();
53+
const Eigen::MatrixXd dense_diagonal = diagonal_matrix.toDenseMatrix();
54+
const auto diag_nll = albatross::negative_log_likelihood(x, diagonal_matrix);
55+
const auto dense_diag_nll =
56+
albatross::negative_log_likelihood(x, dense_diagonal);
57+
EXPECT_NEAR(diag_nll, dense_diag_nll, 1e-6);
5158
}
5259

5360
TEST_F(LinearRegressionTest, test_leave_one_out) {
@@ -67,27 +74,6 @@ TEST_F(LinearRegressionTest, test_leave_one_out) {
6774
EXPECT_LT(in_sample_rmse, out_of_sample_rmse);
6875
}
6976

70-
// Group values by interval, but return keys that once sorted won't be
71-
// in order
72-
std::string group_by_interval(const double &x) {
73-
if (x <= 3) {
74-
return "2";
75-
} else if (x <= 6) {
76-
return "3";
77-
} else {
78-
return "1";
79-
}
80-
}
81-
82-
bool is_monotonic_increasing(const Eigen::VectorXd &x) {
83-
for (Eigen::Index i = 0; i < x.size() - 1; i++) {
84-
if (x[i + 1] - x[i] <= 0.) {
85-
return false;
86-
}
87-
}
88-
return true;
89-
}
90-
9177
TEST_F(LinearRegressionTest, test_cross_validated_predict) {
9278
const auto folds = leave_one_group_out<double>(dataset_, group_by_interval);
9379

tests/test_outlier.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,30 @@ TEST(test_outlier, test_ransac) {
4040
modified.features.end());
4141
}
4242

43+
// Group values by interval, but return keys that once sorted won't be
44+
// in order
45+
std::string group_by_modulo(const double &x) {
46+
const int x_int = static_cast<int>(x);
47+
return std::to_string(x_int % 4);
48+
}
49+
50+
TEST(test_outlier, test_ransac_groups) {
51+
auto dataset = make_toy_linear_data();
52+
const auto model_ptr = toy_gaussian_process();
53+
54+
EvaluationMetric<JointDistribution> nll =
55+
albatross::evaluation_metrics::negative_log_likelihood;
56+
57+
dataset.targets.mean[5] = -300.;
58+
59+
const auto fold_indexer =
60+
leave_one_group_out_indexer<double>(dataset, group_by_modulo);
61+
const auto modified =
62+
ransac(dataset, fold_indexer, model_ptr.get(), nll, 0., 1, 1, 20);
63+
64+
EXPECT_LE(modified.features.size(), dataset.features.size());
65+
}
66+
4367
TEST(test_outlier, test_ransac_gp) {
4468
auto dataset = make_toy_linear_data();
4569

tests/test_utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,27 @@ inline auto random_spherical_points(std::size_t n = 10, double radius = 1.,
282282
return points;
283283
}
284284

285+
// Group values by interval, but return keys that once sorted won't be
286+
// in order
287+
inline std::string group_by_interval(const double &x) {
288+
if (x <= 3) {
289+
return "2";
290+
} else if (x <= 6) {
291+
return "3";
292+
} else {
293+
return "1";
294+
}
295+
}
296+
297+
inline bool is_monotonic_increasing(const Eigen::VectorXd &x) {
298+
for (Eigen::Index i = 0; i < x.size() - 1; i++) {
299+
if (x[i + 1] - x[i] <= 0.) {
300+
return false;
301+
}
302+
}
303+
return true;
304+
}
305+
285306
} // namespace albatross
286307

287308
#endif

0 commit comments

Comments
 (0)