diff --git a/albatross/evaluate.h b/albatross/evaluate.h index 5280a1e1..5fd29d53 100644 --- a/albatross/evaluate.h +++ b/albatross/evaluate.h @@ -56,7 +56,7 @@ negative_log_likelihood(const Eigen::VectorXd &deviation, const double rank = static_cast(diag.size()); const double mahalanobis = deviation.dot(ldlt.solve(deviation)); const double log_det = log_sum(diag); - return -0.5 * (log_det + mahalanobis + rank * log(2 * M_PI)); + return 0.5 * (log_det + mahalanobis + rank * log(2 * M_PI)); } /* diff --git a/albatross/models/ransac.h b/albatross/models/ransac.h index d3073ecb..86db5b3a 100644 --- a/albatross/models/ransac.h +++ b/albatross/models/ransac.h @@ -101,10 +101,9 @@ ransac(const typename RansacFunctions::Fitter &fitter, inliers.push_back(test_ind); } } - // If there is enough agreement, consider this random set of inliers // as a candidate model. - if (inliers.size() > min_inliers) { + if (inliers.size() + random_sample_size >= min_inliers) { const auto inlier_inds = concatenate_subset_of_groups(inliers, indexer); ref_inds.insert(ref_inds.end(), inlier_inds.begin(), inlier_inds.end()); std::sort(ref_inds.begin(), ref_inds.end()); diff --git a/tests/test_evaluate.cc b/tests/test_evaluate.cc index d51ad17e..969de90d 100644 --- a/tests/test_evaluate.cc +++ b/tests/test_evaluate.cc @@ -44,10 +44,17 @@ TEST(test_evaluate, test_negative_log_likelihood) { cov << 1., 0.9, 0.8, 0.9, 1., 0.9, 0.8, 0.9, 1.; const auto nll = albatross::negative_log_likelihood(x, cov); - EXPECT_NEAR(nll, -6.0946974293510134, 1e-6); + EXPECT_NEAR(nll, 6.0946974293510134, 1e-6); const auto ldlt_nll = albatross::negative_log_likelihood(x, cov.ldlt()); EXPECT_NEAR(nll, ldlt_nll, 1e-6); + + const DiagonalMatrixXd diagonal_matrix = cov.diagonal().asDiagonal(); + const Eigen::MatrixXd dense_diagonal = diagonal_matrix.toDenseMatrix(); + const auto diag_nll = albatross::negative_log_likelihood(x, diagonal_matrix); + const auto dense_diag_nll = + albatross::negative_log_likelihood(x, 
dense_diagonal); + EXPECT_NEAR(diag_nll, dense_diag_nll, 1e-6); } TEST_F(LinearRegressionTest, test_leave_one_out) { @@ -67,27 +74,6 @@ TEST_F(LinearRegressionTest, test_leave_one_out) { EXPECT_LT(in_sample_rmse, out_of_sample_rmse); } -// Group values by interval, but return keys that once sorted won't be -// in order -std::string group_by_interval(const double &x) { - if (x <= 3) { - return "2"; - } else if (x <= 6) { - return "3"; - } else { - return "1"; - } -} - -bool is_monotonic_increasing(const Eigen::VectorXd &x) { - for (Eigen::Index i = 0; i < x.size() - 1; i++) { - if (x[i + 1] - x[i] <= 0.) { - return false; - } - } - return true; -} - TEST_F(LinearRegressionTest, test_cross_validated_predict) { const auto folds = leave_one_group_out(dataset_, group_by_interval); diff --git a/tests/test_outlier.cc b/tests/test_outlier.cc index 3554d2e6..f45dd3aa 100644 --- a/tests/test_outlier.cc +++ b/tests/test_outlier.cc @@ -40,6 +40,30 @@ TEST(test_outlier, test_ransac) { modified.features.end()); } +// Group values by their integer part modulo 4, which interleaves the groups +// rather than splitting the data into contiguous intervals +std::string group_by_modulo(const double &x) { + const int x_int = static_cast(x); + return std::to_string(x_int % 4); +} + +TEST(test_outlier, test_ransac_groups) { + auto dataset = make_toy_linear_data(); + const auto model_ptr = toy_gaussian_process(); + + EvaluationMetric nll = + albatross::evaluation_metrics::negative_log_likelihood; + + dataset.targets.mean[5] = -300.; + + const auto fold_indexer = + leave_one_group_out_indexer(dataset, group_by_modulo); + const auto modified = + ransac(dataset, fold_indexer, model_ptr.get(), nll, 0., 1, 1, 20); + + EXPECT_LE(modified.features.size(), dataset.features.size()); +} + TEST(test_outlier, test_ransac_gp) { auto dataset = make_toy_linear_data(); diff --git a/tests/test_utils.h b/tests/test_utils.h index 3364ff36..2df48716 100644 --- a/tests/test_utils.h +++ b/tests/test_utils.h @@ -282,6 +282,27 @@ inline auto 
random_spherical_points(std::size_t n = 10, double radius = 1., return points; } +// Group values by interval, but return keys that once sorted won't be +// in order +inline std::string group_by_interval(const double &x) { + if (x <= 3) { + return "2"; + } else if (x <= 6) { + return "3"; + } else { + return "1"; + } +} + +inline bool is_monotonic_increasing(const Eigen::VectorXd &x) { + for (Eigen::Index i = 0; i < x.size() - 1; i++) { + if (x[i + 1] - x[i] <= 0.) { + return false; + } + } + return true; +} + } // namespace albatross #endif