Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Serialization #1368

Merged
merged 9 commits into from
Jan 4, 2023
31 changes: 31 additions & 0 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ namespace gtsam {
*/
size_t nrAssignments_;

/// Default constructor for serialization.
Leaf() {}

/// Constructor from constant
Leaf(const Y& constant, size_t nrAssignments = 1)
: constant_(constant), nrAssignments_(nrAssignments) {}
Expand Down Expand Up @@ -154,6 +157,18 @@ namespace gtsam {
}

bool isLeaf() const override { return true; }

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(constant_);
ar& BOOST_SERIALIZATION_NVP(nrAssignments_);
}
}; // Leaf

/****************************************************************************/
Expand All @@ -177,6 +192,9 @@ namespace gtsam {
using ChoicePtr = boost::shared_ptr<const Choice>;

public:
/// Default constructor for serialization.
Choice() {}

~Choice() override {
#ifdef DT_DEBUG_MEMORY
std::std::cout << Node::nrNodes << " destructing (Choice) " << this->id()
Expand Down Expand Up @@ -428,6 +446,19 @@ namespace gtsam {
r->push_back(branch->choose(label, index));
return Unique(r);
}

private:
using Base = DecisionTree<L, Y>::Node;

/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar & BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_NVP(label_);
ar& BOOST_SERIALIZATION_NVP(branches_);
ar& BOOST_SERIALIZATION_NVP(allSame_);
}
}; // Choice

/****************************************************************************/
Expand Down
19 changes: 19 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

#pragma once

#include <gtsam/base/Testable.h>
#include <gtsam/base/types.h>
#include <gtsam/discrete/Assignment.h>

#include <boost/serialization/nvp.hpp>
#include <boost/shared_ptr.hpp>
#include <functional>
#include <iostream>
Expand Down Expand Up @@ -113,6 +115,12 @@ namespace gtsam {
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
virtual Ptr choose(const L& label, size_t index) const = 0;
virtual bool isLeaf() const = 0;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {}
};
/** ------------------------ Node base class --------------------------- */

Expand Down Expand Up @@ -364,8 +372,19 @@ namespace gtsam {
compose(Iterator begin, Iterator end, const L& label) const;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_NVP(root_);
}
}; // DecisionTree

template <class L, class Y>
struct traits<DecisionTree<L, Y>> : public Testable<DecisionTree<L, Y>> {};

/** free versions of apply */

/// Apply unary operator `op` to DecisionTree `f`.
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,16 @@ namespace gtsam {
const Names& names = {}) const override;

/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(ADT);
ar& BOOST_SERIALIZATION_NVP(cardinalities_);
}
};

// traits
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DiscreteConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,15 @@ class GTSAM_EXPORT DiscreteConditional
/// Internal version of choose
DiscreteConditional::ADT choose(const DiscreteValues& given,
bool forceComplete) const;

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
}
};
// DiscreteConditional

Expand Down
36 changes: 32 additions & 4 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
// #define DT_DEBUG_MEMORY
// #define GTSAM_DT_NO_PRUNING
#define DISABLE_DOT
#include <gtsam/discrete/DecisionTree-inl.h>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>

#include <CppUnitLite/TestHarness.h>

using namespace std;
using namespace gtsam;

Expand Down Expand Up @@ -529,6 +528,35 @@ TEST(DecisionTree, ApplyWithAssignment) {
EXPECT_LONGS_EQUAL(5, count);
}

/* ****************************************************************************/
using Tree = gtsam::DecisionTree<string, int>;

BOOST_CLASS_EXPORT_GUID(Tree, "gtsam_DecisionTreeStringInt")
BOOST_CLASS_EXPORT_GUID(Tree::Leaf, "gtsam_DecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(Tree::Choice, "gtsam_DecisionTree_Choice")

// Test HybridBayesNet serialization.
TEST(DecisionTree, Serialization) {
Tree tree({{"A", 2}}, std::vector<int>{1, 2});

using namespace serializationTestHelpers;

// Object roundtrip
Tree outputObj = create<Tree>();
roundtrip<Tree>(tree, outputObj);
EXPECT(tree.equals(outputObj));

// XML roundtrip
Tree outputXml = create<Tree>();
roundtripXML<Tree>(tree, outputXml);
EXPECT(tree.equals(outputXml));

// Binary roundtrip
Tree outputBinary = create<Tree>();
roundtripBinary<Tree>(tree, outputBinary);
EXPECT(tree.equals(outputBinary));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
25 changes: 25 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/Signature.h>
Expand Down Expand Up @@ -220,6 +221,30 @@ TEST(DecisionTreeFactor, htmlWithValueFormatter) {
EXPECT(actual == expected);
}

/* ************************************************************************* */
BOOST_CLASS_EXPORT_GUID(DecisionTreeFactor, "gtsam_DecisionTreeFactor");
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT, "gtsam_AlgebraicDecisionTree");
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_DecisionTree_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_DecisionTree_Choice")

// Check serialization for AlgebraicDecisionTree and the DecisionTreeFactor
TEST(DecisionTreeFactor, Serialization) {
using namespace serializationTestHelpers;

DiscreteKey A(1, 2), B(2, 2), C(3, 2);

DecisionTreeFactor::ADT tree(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<ADT>(tree));
EXPECT(equalsXML<DecisionTreeFactor::ADT>(tree));
EXPECT(equalsBinary<DecisionTreeFactor::ADT>(tree));

DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8");
EXPECT(equalsObj<DecisionTreeFactor>(f));
EXPECT(equalsXML<DecisionTreeFactor>(f));
EXPECT(equalsBinary<DecisionTreeFactor>(f));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
23 changes: 20 additions & 3 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* @date Feb 14, 2011
*/

#include <boost/make_shared.hpp>

#include <CppUnitLite/TestHarness.h>
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/inference/Symbol.h>

#include <boost/make_shared.hpp>

using namespace std;
using namespace gtsam;

Expand Down Expand Up @@ -209,7 +210,6 @@ TEST(DiscreteConditional, marginals2) {
DiscreteConditional conditional(A | B = "2/2 3/1");
DiscreteConditional prior(B % "1/2");
DiscreteConditional pAB = prior * conditional;
GTSAM_PRINT(pAB);
// P(A=0) = P(A=0|B=0)P(B=0) + P(A=0|B=1)P(B=1) = 2*1 + 3*2 = 8
// P(A=1) = P(A=1|B=0)P(B=0) + P(A=1|B=1)P(B=1) = 2*1 + 1*2 = 4
DiscreteConditional actualA = pAB.marginal(A.first);
Expand Down Expand Up @@ -368,6 +368,23 @@ TEST(DiscreteConditional, html) {
EXPECT(actual == expected);
}

/* ************************************************************************* */
using ADT = AlgebraicDecisionTree<Key>;
BOOST_CLASS_EXPORT_GUID(ADT::Leaf, "gtsam_ADT_Leaf")
BOOST_CLASS_EXPORT_GUID(ADT::Choice, "gtsam_ADT_Choice")

// Check serialization for DiscreteConditional
TEST(DiscreteConditional, Serialization) {
using namespace serializationTestHelpers;

DiscreteKey A(Symbol('x', 1), 3);
DiscreteConditional conditional(A % "1/2/2");

EXPECT(equalsObj<DiscreteConditional>(conditional));
EXPECT(equalsXML<DiscreteConditional>(conditional));
EXPECT(equalsBinary<DiscreteConditional>(conditional));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
10 changes: 10 additions & 0 deletions gtsam/hybrid/GaussianMixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class Archive>
void serialize(Archive &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar &BOOST_SERIALIZATION_NVP(conditionals_);
}
};

/// Return the DiscreteKey vector as a set.
Expand Down
18 changes: 18 additions & 0 deletions gtsam/hybrid/GaussianMixtureFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
bool operator==(const FactorAndConstant &other) const {
return factor == other.factor && constant == other.constant;
}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_NVP(factor);
ar &BOOST_SERIALIZATION_NVP(constant);
}
};

/// typedef for Decision Tree of Gaussian factors and log-constant.
Expand Down Expand Up @@ -179,6 +188,15 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
return sum;
}
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(factors_);
}
};

// traits
Expand Down
1 change: 0 additions & 1 deletion gtsam/hybrid/HybridConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ bool HybridConditional::equals(const HybridFactor &other, double tol) const {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
}
return inner_->equals(*(e->inner_), tol);

return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
Expand Down
13 changes: 13 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ class GTSAM_EXPORT HybridConditional
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
ar& BOOST_SERIALIZATION_NVP(inner_);

// register the various casts based on the type of inner_
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
if (isDiscrete()) {
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
} else if (isContinuous()) {
boost::serialization::void_cast_register<GaussianConditional, Factor>(
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
} else {
boost::serialization::void_cast_register<GaussianMixture, Factor>(
static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
}
}

}; // HybridConditional
Expand Down
6 changes: 4 additions & 2 deletions gtsam/hybrid/HybridDiscreteFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf)
/* ************************************************************************ */
bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
// TODO(Varun) How to compare inner_ when they are abstract types?
return e != nullptr && Base::equals(*e, tol);
if (e == nullptr) return false;
if (!Base::equals(*e, tol)) return false;
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}

/* ************************************************************************ */
Expand Down
12 changes: 12 additions & 0 deletions gtsam/hybrid/HybridDiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// @name Constructors
/// @{

/// Default constructor - for serialization.
HybridDiscreteFactor() = default;

// Implicit conversion from a shared ptr of DF
HybridDiscreteFactor(DiscreteFactor::shared_ptr other);

Expand All @@ -70,6 +73,15 @@ class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor {
/// Return the error of the underlying Discrete Factor.
double error(const HybridValues &values) const override;
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
ar &BOOST_SERIALIZATION_NVP(inner_);
}
};

// traits
Expand Down
Loading