Skip to content

Commit

Permalink
Conditional serialization routines, with delayed static asserts
Browse files Browse the repository at this point in the history
to help users identify the source of serialization issues.
  • Loading branch information
akleeman committed Jul 3, 2018
1 parent 9e449d1 commit 947f0bd
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 26 deletions.
34 changes: 33 additions & 1 deletion albatross/core/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#ifndef ALBATROSS_CORE_DISTRIBUTION_H
#define ALBATROSS_CORE_DISTRIBUTION_H

#include "cereal/cereal.hpp"
#include "core/traits.h"
#include "eigen/serializable_diagonal_matrix.h"
#include "indexing.h"
#include <Eigen/Core>
#include <iostream>
Expand Down Expand Up @@ -52,9 +55,38 @@ template <typename CovarianceType> struct Distribution {
Distribution(const Eigen::VectorXd &mean_) : mean(mean_), covariance(){};
Distribution(const Eigen::VectorXd &mean_, const CovarianceType &covariance_)
: mean(mean_), covariance(covariance_){};

/*
* If the CovarianceType is serializable, add a serialize method.
*/
template <class Archive>
typename std::enable_if<
valid_in_out_serializer<CovarianceType, Archive>::value, void>::type
serialize(Archive &archive) {
archive(cereal::make_nvp("mean", mean));
archive(cereal::make_nvp("covariance", covariance));
}

/*
* If you try to serialize a Distribution for which the covariance
* type is not serializable you'll get an error.
*/
template <class Archive>
typename std::enable_if<
!valid_in_out_serializer<CovarianceType, Archive>::value, void>::type
save(Archive &archive) {
static_assert(delay_static_assert<Archive>::value,
"In order to serialize a Distribution the corresponding "
"CovarianceType must be serializable.");
}

bool operator==(const Distribution &other) const {
return (mean == other.mean && covariance == other.covariance);
}
};

using DiagonalMatrixXd = Eigen::DiagonalMatrix<double, Eigen::Dynamic>;
using DiagonalMatrixXd =
Eigen::SerializableDiagonalMatrix<double, Eigen::Dynamic>;
using DenseDistribution = Distribution<Eigen::MatrixXd>;
using DiagonalDistribution = Distribution<DiagonalMatrixXd>;

Expand Down
31 changes: 13 additions & 18 deletions albatross/core/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,6 @@ namespace albatross {
using TargetDistribution = DiagonalDistribution;
using PredictDistribution = DenseDistribution;

template <class Archive>
void save(Archive &archive, const TargetDistribution &distribution) {
archive(cereal::make_nvp("mean", distribution.mean));
archive(cereal::make_nvp("diagonal", distribution.covariance.diagonal()));
}

template <class Archive>
void load(Archive &archive, TargetDistribution &distribution) {
Eigen::VectorXd mean;
archive(cereal::make_nvp("mean", mean));
distribution.mean = mean;

Eigen::VectorXd diagonal;
archive(cereal::make_nvp("diagonal", diagonal));
distribution.covariance = diagonal.asDiagonal();
}

/*
* A RegressionDataset holds two vectors of data, the features
* where a single feature can be any class that contains the information used
Expand All @@ -70,10 +53,22 @@ template <typename FeatureType> struct RegressionDataset {
const Eigen::VectorXd &targets_)
: RegressionDataset(features_, TargetDistribution(targets_)) {}

template <class Archive> void serialize(Archive &archive) {
template <class Archive>
typename std::enable_if<valid_in_out_serializer<FeatureType, Archive>::value,
void>::type
serialize(Archive &archive) {
archive(cereal::make_nvp("features", features));
archive(cereal::make_nvp("targets", targets));
}

template <class Archive>
typename std::enable_if<!valid_in_out_serializer<FeatureType, Archive>::value,
void>::type
serialize(Archive &archive) {
static_assert(delay_static_assert<Archive>::value,
"In order to serialize a RegressionDataset the corresponding "
"FeatureType must be serializable.");
}
};

typedef int32_t s32;
Expand Down
38 changes: 35 additions & 3 deletions albatross/core/serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef ALBATROSS_CORE_SERIALIZE_H
#define ALBATROSS_CORE_SERIALIZE_H

#include "core/traits.h"
#include <cereal/archives/json.hpp>
#include <cereal/types/polymorphic.hpp>
#include <iostream>
Expand All @@ -38,21 +39,52 @@ class SerializableRegressionModel : public RegressionModel<FeatureType> {
model_fit_ == other.get_fit());
}

// todo: enable if ModelFit is serializable.
template <class Archive> void save(Archive &archive) const {
/*
* Include save/load methods conditional on the ability to serialize
* ModelFit.
*/
template <class Archive>
typename std::enable_if<valid_output_serializer<ModelFit, Archive>::value,
void>::type
save(Archive &archive) const {
archive(cereal::make_nvp(
"model_definition",
cereal::base_class<RegressionModel<FeatureType>>(this)));
archive(cereal::make_nvp("model_fit", this->model_fit_));
}

template <class Archive> void load(Archive &archive) {
template <class Archive>
typename std::enable_if<valid_input_serializer<ModelFit, Archive>::value,
void>::type
load(Archive &archive) {
archive(cereal::make_nvp(
"model_definition",
cereal::base_class<RegressionModel<FeatureType>>(this)));
archive(cereal::make_nvp("model_fit", this->model_fit_));
}

/*
* If ModelFit does not have valid serialization methods and you attempt to
* (de)serialize a SerializableRegressionModel you'll get an error.
*/
template <class Archive>
typename std::enable_if<!valid_output_serializer<ModelFit, Archive>::value,
void>::type
save(Archive &archive) const {
static_assert(delay_static_assert<Archive>::value,
"SerializableRegressionModel requires a ModelFit type which "
"is serializable.");
}

template <class Archive>
typename std::enable_if<!valid_input_serializer<ModelFit, Archive>::value,
void>::type
load(Archive &archive) const {
static_assert(delay_static_assert<Archive>::value,
"SerializableRegressionModel requires a ModelFit type which "
"is serializable.");
}

virtual ModelFit get_fit() const { return model_fit_; }

protected:
Expand Down
51 changes: 51 additions & 0 deletions albatross/core/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,20 @@
#ifndef ALBATROSS_CORE_MAGIC_H
#define ALBATROSS_CORE_MAGIC_H

#include "cereal/details/traits.hpp"
#include <utility>

namespace albatross {

/*
* This little trick was borrowed from cereal, you an think of it as
* a function that will always return false ... but that doesn't
* get resolved until template instantiation, which when combined
* with a static assert let's you include a static assert that
* only triggers with a particular template parameter is used.
*/
template <class T> struct delay_static_assert : std::false_type {};

/*
* This determines whether or not a class has a method defined for,
* `operator() (X x, Y y, Z z, ...)`
Expand Down Expand Up @@ -82,6 +92,47 @@ using fit_type_if_serializable =
typename enable_if_serializable<X,
typename fit_type_or_void<X>::type>::type;

/*
* The following helper functions let you inspect a type and cereal Archive
* and determine if the type has a valid serialization method for that Archive
* type.
*/
template <typename X, typename Archive> class valid_output_serializer {
template <typename T>
static typename std::enable_if<
1 == cereal::traits::detail::count_output_serializers<T, Archive>::value,
std::true_type>::type
test(int);
template <typename T> static std::false_type test(...);

public:
static constexpr bool value = decltype(test<X>(0))::value;
};

template <typename X, typename Archive> class valid_input_serializer {
template <typename T>
static typename std::enable_if<
1 == cereal::traits::detail::count_input_serializers<T, Archive>::value,
std::true_type>::type
test(int);
template <typename T> static std::false_type test(...);

public:
static constexpr bool value = decltype(test<X>(0))::value;
};

template <typename X, typename Archive> class valid_in_out_serializer {
template <typename T>
static typename std::enable_if<valid_input_serializer<T, Archive>::value &&
valid_output_serializer<T, Archive>::value,
std::true_type>::type
test(int);
template <typename T> static std::false_type test(...);

public:
static constexpr bool value = decltype(test<X>(0))::value;
};

} // namespace albatross

#endif
58 changes: 58 additions & 0 deletions albatross/eigen/serializable_diagonal_matrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (C) 2018 Swift Navigation Inc.
* Contact: Swift Navigation <dev@swiftnav.com>
*
* This source is subject to the license found in the file 'LICENSE' which must
* be distributed together with this source. All other rights reserved.
*
* THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND,
* EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A PARTICULAR PURPOSE.
*/

#ifndef ALBATROSS_EIGEN_SERIALIZABLE_DIAGONAL_MATRIX_H
#define ALBATROSS_EIGEN_SERIALIZABLE_DIAGONAL_MATRIX_H

/*
* The Eigen::DiagonalMatrix doesn't provide the public methods
* required to reliably serialize the `m_diagonal` private
* member. In order to make the DiagonalMatrix serializable
* we instead inherit from it, giving private access to the
* diagonal elements which in turn allows us to serialize it.
*/

#include "Eigen/Cholesky"
#include "Eigen/Dense"
#include "cereal/cereal.hpp"
#include <math.h>

namespace Eigen {

template <typename _Scalar, int SizeAtCompileTime>
class SerializableDiagonalMatrix
: public Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> {
using BaseClass = Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime>;

public:
SerializableDiagonalMatrix() : BaseClass(){};

SerializableDiagonalMatrix(const BaseClass &other)
// Can we get around copying here?
: BaseClass(other){};

template <typename OtherDerived>
inline SerializableDiagonalMatrix(const DiagonalBase<OtherDerived> &other)
: BaseClass(other){};

template <typename Archive> void serialize(Archive &archive) {
archive(cereal::make_nvp("diagonal", this->m_diagonal));
}

bool operator==(const BaseClass &other) const {
return (this->m_diagonal == other.diagonal());
}
};

} // namesapce Eigen

#endif
43 changes: 39 additions & 4 deletions tests/test_serialize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,39 @@ struct EigenMatrixXd : public SerializableType<Eigen::MatrixXd> {
}
};

struct FullDenseDistribution : public SerializableType<DenseDistribution> {
DenseDistribution create() const override {
Eigen::MatrixXd cov(3, 3);
cov << 1., 2., 3., 4., 5., 6., 7, 8, 9;
Eigen::VectorXd mean = Eigen::VectorXd::Ones(3);
return DenseDistribution(mean, cov);
}
};

struct MeanOnlyDenseDistribution : public SerializableType<DenseDistribution> {
DenseDistribution create() const override {
Eigen::MatrixXd mean = Eigen::VectorXd::Ones(3);
return DenseDistribution(mean);
}
};

struct FullDiagonalDistribution
: public SerializableType<DiagonalDistribution> {
DiagonalDistribution create() const override {
Eigen::VectorXd diag = Eigen::VectorXd::Ones(3);
Eigen::VectorXd mean = Eigen::VectorXd::Ones(3);
return DiagonalDistribution(mean, diag.asDiagonal());
}
};

struct MeanOnlyDiagonalDistribution
: public SerializableType<DiagonalDistribution> {
DiagonalDistribution create() const override {
Eigen::MatrixXd mean = Eigen::VectorXd::Ones(3);
return DiagonalDistribution(mean);
}
};

struct LDLT : public SerializableType<Eigen::SerializableLDLT> {
Eigen::Index n = 3;

Expand Down Expand Up @@ -281,10 +314,12 @@ struct PolymorphicSerializeTest : public ::testing::Test {
typedef ::testing::Types<
LDLT, SerializableType<Eigen::Matrix3d>, SerializableType<Eigen::Matrix2i>,
EmptyEigenVectorXd, EigenVectorXd, EmptyEigenMatrixXd, EigenMatrixXd,
ParameterStoreType, SerializableType<MockModel>, UnfitSerializableModel,
FitSerializableModel, FitDirectModel, UnfitDirectModel,
UnfitRegressionModel, FitLinearRegressionModel,
FitLinearSerializablePointer, UnfitGaussianProcess, FitGaussianProcess>
FullDenseDistribution, MeanOnlyDenseDistribution, FullDiagonalDistribution,
MeanOnlyDiagonalDistribution, ParameterStoreType,
SerializableType<MockModel>, UnfitSerializableModel, FitSerializableModel,
FitDirectModel, UnfitDirectModel, UnfitRegressionModel,
FitLinearRegressionModel, FitLinearSerializablePointer,
UnfitGaussianProcess, FitGaussianProcess>
ToTest;

TYPED_TEST_CASE(PolymorphicSerializeTest, ToTest);
Expand Down

0 comments on commit 947f0bd

Please sign in to comment.