Skip to content

Commit

Permalink
Add tests to ensure predict, predict_marginal and predict_mean
Browse files Browse the repository at this point in the history
are consistent for Gaussian Processes.
  • Loading branch information
akleeman committed Jul 3, 2018
1 parent f96fa21 commit 77d8b6a
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 12 deletions.
10 changes: 9 additions & 1 deletion albatross/core/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,17 @@ template <typename CovarianceType> struct Distribution {
}
};

// A JointDistribution has a dense covariance matrix, which
// contains the covariance between each variable and all others.
using JointDistribution = Distribution<Eigen::MatrixXd>;

// We use a wrapper around DiagonalMatrix in order to make
// the resulting distribution serializable
using DiagonalMatrixXd =
Eigen::SerializableDiagonalMatrix<double, Eigen::Dynamic>;
using JointDistribution = Distribution<Eigen::MatrixXd>;
// A MarginalDistribution has only a diagonal covariance
// matrix, so in turn it only describes the variance of each
// variable, independently of all others.
using MarginalDistribution = Distribution<DiagonalMatrixXd>;

template <typename CovarianceType, typename SizeType>
Expand Down
5 changes: 2 additions & 3 deletions albatross/evaluate.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ negative_log_likelihood(const Eigen::VectorXd &deviation,
*/
namespace evaluation_metrics {

static inline double
root_mean_square_error(const JointDistribution &prediction,
const MarginalDistribution &truth) {
static inline double root_mean_square_error(const JointDistribution &prediction,
const MarginalDistribution &truth) {
const Eigen::VectorXd error = prediction.mean - truth.mean;
double mse = error.dot(error) / static_cast<double>(error.size());
return sqrt(mse);
Expand Down
11 changes: 5 additions & 6 deletions albatross/models/gp.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ class GaussianProcessRegression
}

protected:
FitType serializable_fit_(const std::vector<FeatureType> &features,
const MarginalDistribution &targets) const override {
FitType
serializable_fit_(const std::vector<FeatureType> &features,
const MarginalDistribution &targets) const override {
Eigen::MatrixXd cov = symmetric_covariance(covariance_function_, features);
FitType model_fit;
model_fit.train_features = features;
Expand All @@ -145,7 +146,6 @@ class GaussianProcessRegression
predict_(const std::vector<FeatureType> &features) const override {
const auto cross_cov = asymmetric_covariance(
covariance_function_, features, this->model_fit_.train_features);
// Then we can use the information vector to determine the posterior
const Eigen::VectorXd pred = cross_cov * this->model_fit_.information;
Eigen::MatrixXd pred_cov =
symmetric_covariance(covariance_function_, features);
Expand All @@ -158,9 +158,9 @@ class GaussianProcessRegression
predict_marginal_(const std::vector<FeatureType> &features) const override {
const auto cross_cov = asymmetric_covariance(
covariance_function_, features, this->model_fit_.train_features);
// Then we can use the information vector to determine the posterior
const Eigen::VectorXd pred = cross_cov * this->model_fit_.information;

// Here we efficiently only compute the diagonal of the posterior
// covariance matrix.
auto ldlt = this->model_fit_.train_ldlt;
Eigen::MatrixXd explained = ldlt.solve(cross_cov.transpose());
Eigen::VectorXd marginal_variance =
Expand All @@ -176,7 +176,6 @@ class GaussianProcessRegression
predict_mean_(const std::vector<FeatureType> &features) const override {
const auto cross_cov = asymmetric_covariance(
covariance_function_, features, this->model_fit_.train_features);
// Then we can use the information vector to determine the posterior
const Eigen::VectorXd pred = cross_cov * this->model_fit_.information;
return pred;
}
Expand Down
18 changes: 18 additions & 0 deletions tests/test_models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,22 @@ TEST(test_models, test_with_target_distribution) {

EXPECT_LE(scores.mean(), scores_without_variance.mean());
}

// Ensures the three prediction variants of a Gaussian process agree:
// predict() (full joint posterior), predict_marginal() (diagonal-only
// posterior) and predict_mean() (mean only) must produce the same means,
// and the marginal variances must match the diagonal of the joint
// posterior covariance.
TEST(test_models, test_predict_variants) {
  auto dataset = make_heteroscedastic_toy_linear_data();

  auto model = MakeGaussianProcess().create();
  model->fit(dataset);
  const auto joint_predictions = model->predict(dataset.features);
  const auto marginal_predictions = model->predict_marginal(dataset.features);
  const auto mean_predictions = model->predict_mean(dataset.features);

  // Guard the element-wise comparisons below: if one of the variants
  // returned a different number of predictions, indexing it with i would
  // read out of bounds instead of failing with a clear message.
  ASSERT_EQ(joint_predictions.mean.size(), marginal_predictions.mean.size());
  ASSERT_EQ(joint_predictions.mean.size(), mean_predictions.size());

  for (Eigen::Index i = 0; i < joint_predictions.mean.size(); i++) {
    // All variants must agree on the posterior mean.
    EXPECT_NEAR(joint_predictions.mean[i], mean_predictions[i], 1e-6);
    EXPECT_NEAR(joint_predictions.mean[i], marginal_predictions.mean[i], 1e-6);
    // The marginal variance must equal the corresponding diagonal entry
    // of the full posterior covariance.
    EXPECT_NEAR(joint_predictions.covariance(i, i),
                marginal_predictions.covariance.diagonal()[i], 1e-6);
  }
}

} // namespace albatross
5 changes: 3 additions & 2 deletions tests/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ class MockModel : public SerializableRegressionModel<MockPredictor, MockFit> {

protected:
// builds the map from int to value
MockFit serializable_fit_(const std::vector<MockPredictor> &features,
const MarginalDistribution &targets) const override {
MockFit
serializable_fit_(const std::vector<MockPredictor> &features,
const MarginalDistribution &targets) const override {
int n = static_cast<int>(features.size());
Eigen::VectorXd predictions(n);

Expand Down

0 comments on commit 77d8b6a

Please sign in to comment.