Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion albatross/evaluate.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ negative_log_likelihood(const Eigen::VectorXd &deviation,
const double rank = static_cast<double>(diag.size());
const double mahalanobis = deviation.dot(ldlt.solve(deviation));
const double log_det = log_sum(diag);
return -0.5 * (log_det + mahalanobis + rank * log(2 * M_PI));
return 0.5 * (log_det + mahalanobis + rank * log(2 * M_PI));
}

/*
Expand Down
3 changes: 1 addition & 2 deletions albatross/models/ransac.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,9 @@ ransac(const typename RansacFunctions<FitType>::Fitter &fitter,
inliers.push_back(test_ind);
}
}

// If there is enough agreement, consider this random set of inliers
// as a candidate model.
if (inliers.size() > min_inliers) {
if (inliers.size() + random_sample_size >= min_inliers) {
const auto inlier_inds = concatenate_subset_of_groups(inliers, indexer);
ref_inds.insert(ref_inds.end(), inlier_inds.begin(), inlier_inds.end());
std::sort(ref_inds.begin(), ref_inds.end());
Expand Down
30 changes: 8 additions & 22 deletions tests/test_evaluate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ TEST(test_evaluate, test_negative_log_likelihood) {
cov << 1., 0.9, 0.8, 0.9, 1., 0.9, 0.8, 0.9, 1.;

const auto nll = albatross::negative_log_likelihood(x, cov);
EXPECT_NEAR(nll, -6.0946974293510134, 1e-6);
EXPECT_NEAR(nll, 6.0946974293510134, 1e-6);

const auto ldlt_nll = albatross::negative_log_likelihood(x, cov.ldlt());
EXPECT_NEAR(nll, ldlt_nll, 1e-6);

const DiagonalMatrixXd diagonal_matrix = cov.diagonal().asDiagonal();
const Eigen::MatrixXd dense_diagonal = diagonal_matrix.toDenseMatrix();
const auto diag_nll = albatross::negative_log_likelihood(x, diagonal_matrix);
const auto dense_diag_nll =
albatross::negative_log_likelihood(x, dense_diagonal);
EXPECT_NEAR(diag_nll, dense_diag_nll, 1e-6);
}

TEST_F(LinearRegressionTest, test_leave_one_out) {
Expand All @@ -67,27 +74,6 @@ TEST_F(LinearRegressionTest, test_leave_one_out) {
EXPECT_LT(in_sample_rmse, out_of_sample_rmse);
}

// Group values by interval, but return keys that once sorted won't be
// in order
std::string group_by_interval(const double &x) {
if (x <= 3) {
return "2";
} else if (x <= 6) {
return "3";
} else {
return "1";
}
}

bool is_monotonic_increasing(const Eigen::VectorXd &x) {
for (Eigen::Index i = 0; i < x.size() - 1; i++) {
if (x[i + 1] - x[i] <= 0.) {
return false;
}
}
return true;
}

TEST_F(LinearRegressionTest, test_cross_validated_predict) {
const auto folds = leave_one_group_out<double>(dataset_, group_by_interval);

Expand Down
24 changes: 24 additions & 0 deletions tests/test_outlier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ TEST(test_outlier, test_ransac) {
modified.features.end());
}

// Group values by interval, but return keys that once sorted won't be
// in order
std::string group_by_modulo(const double &x) {
const int x_int = static_cast<int>(x);
return std::to_string(x_int % 4);
}

TEST(test_outlier, test_ransac_groups) {
auto dataset = make_toy_linear_data();
const auto model_ptr = toy_gaussian_process();

EvaluationMetric<JointDistribution> nll =
albatross::evaluation_metrics::negative_log_likelihood;

dataset.targets.mean[5] = -300.;

const auto fold_indexer =
leave_one_group_out_indexer<double>(dataset, group_by_modulo);
const auto modified =
ransac(dataset, fold_indexer, model_ptr.get(), nll, 0., 1, 1, 20);

EXPECT_LE(modified.features.size(), dataset.features.size());
}

TEST(test_outlier, test_ransac_gp) {
auto dataset = make_toy_linear_data();

Expand Down
21 changes: 21 additions & 0 deletions tests/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,27 @@ inline auto random_spherical_points(std::size_t n = 10, double radius = 1.,
return points;
}

// Group values by interval, but return keys that once sorted won't be
// in order
inline std::string group_by_interval(const double &x) {
if (x <= 3) {
return "2";
} else if (x <= 6) {
return "3";
} else {
return "1";
}
}

inline bool is_monotonic_increasing(const Eigen::VectorXd &x) {
for (Eigen::Index i = 0; i < x.size() - 1; i++) {
if (x[i + 1] - x[i] <= 0.) {
return false;
}
}
return true;
}

} // namespace albatross

#endif