diff --git a/gtsam/CMakeLists.txt b/gtsam/CMakeLists.txt index a293c6ec28..09f1ea8064 100644 --- a/gtsam/CMakeLists.txt +++ b/gtsam/CMakeLists.txt @@ -10,6 +10,7 @@ set (gtsam_subdirs inference symbolic discrete + hybrid linear nonlinear sam diff --git a/gtsam/hybrid/CMakeLists.txt b/gtsam/hybrid/CMakeLists.txt new file mode 100644 index 0000000000..f1cfcd5c4b --- /dev/null +++ b/gtsam/hybrid/CMakeLists.txt @@ -0,0 +1,8 @@ +# Install headers +set(subdir hybrid) +file(GLOB hybrid_headers "*.h") +# FIXME: exclude headers +install(FILES ${hybrid_headers} DESTINATION include/gtsam/hybrid) + +# Add all tests +add_subdirectory(tests) diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp new file mode 100644 index 0000000000..0000575182 --- /dev/null +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -0,0 +1,110 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GaussianMixture.cpp + * @brief A hybrid conditional in the Conditional Linear Gaussian scheme + * @author Fan Jiang + * @author Varun Agrawal + * @author Frank Dellaert + * @date Mar 12, 2022 + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + +GaussianMixture::GaussianMixture( + const KeyVector &continuousFrontals, const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const GaussianMixture::Conditionals &conditionals) + : BaseFactor(CollectKeys(continuousFrontals, continuousParents), + discreteParents), + BaseConditional(continuousFrontals.size()), + conditionals_(conditionals) {} + +/* *******************************************************************************/ +const GaussianMixture::Conditionals & +GaussianMixture::conditionals() { + return conditionals_; +} + +/* *******************************************************************************/ +GaussianMixture GaussianMixture::FromConditionals( + const KeyVector &continuousFrontals, const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const std::vector &conditionalsList) { + Conditionals dt(discreteParents, conditionalsList); + + return GaussianMixture(continuousFrontals, continuousParents, + discreteParents, dt); +} + +/* *******************************************************************************/ +GaussianMixture::Sum GaussianMixture::add( + const GaussianMixture::Sum &sum) const { + using Y = GaussianFactorGraph; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; + }; + const Sum tree = asGaussianFactorGraphTree(); + return sum.empty() ? tree : sum.apply(tree, add); +} + +/* *******************************************************************************/ +GaussianMixture::Sum +GaussianMixture::asGaussianFactorGraphTree() const { + auto lambda = [](const GaussianFactor::shared_ptr &factor) { + GaussianFactorGraph result; + result.push_back(factor); + return result; + }; + return {conditionals_, lambda}; +} + +/* *******************************************************************************/ +bool GaussianMixture::equals(const HybridFactor &lf, + double tol) const { + const This *e = dynamic_cast(&lf); + return e != nullptr && BaseFactor::equals(*e, tol); +} + +/* *******************************************************************************/ +void GaussianMixture::print(const std::string &s, + const KeyFormatter &formatter) const { + std::cout << s; + if (isContinuous()) std::cout << "Continuous "; + if (isDiscrete()) std::cout << "Discrete "; + if (isHybrid()) std::cout << "Hybrid "; + BaseConditional::print("", formatter); + std::cout << "\nDiscrete Keys = "; + for (auto &dk : discreteKeys()) { + std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; + } + std::cout << "\n"; + conditionals_.print( + "", [&](Key k) { return formatter(k); }, + [&](const GaussianConditional::shared_ptr &gf) -> std::string { + RedirectCout rd; + if (!gf->empty()) + gf->print("", formatter); + else + return {"nullptr"}; + return rd.str(); + }); +} +} // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h new file mode 100644 index 0000000000..e855067150 --- /dev/null +++ b/gtsam/hybrid/GaussianMixture.h @@ -0,0 +1,133 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GaussianMixture.h + * @brief A hybrid conditional in the Conditional Linear Gaussian scheme + * @author Fan Jiang + * @author Varun Agrawal + * @date Mar 12, 2022 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + +/** + * @brief A conditional of gaussian mixtures indexed by discrete variables, as + * part of a Bayes Network. + * + * Represents the conditional density P(X | M, Z) where X is a continuous random + * variable, M is the selection of discrete variables corresponding to a subset + * of the Gaussian variables and Z is parent of this node + * + * The probability P(x|y,z,...) is proportional to + * \f$ \sum_i k_i \exp - \frac{1}{2} |R_i x - (d_i - S_i y - T_i z - ...)|^2 \f$ + * where i indexes the components and k_i is a component-wise normalization + * constant. + * + */ +class GTSAM_EXPORT GaussianMixture + : public HybridFactor, + public Conditional { + public: + using This = GaussianMixture; + using shared_ptr = boost::shared_ptr; + using BaseFactor = HybridFactor; + using BaseConditional = Conditional; + + /// Alias for DecisionTree of GaussianFactorGraphs + using Sum = DecisionTree; + + /// typedef for Decision Tree of Gaussian Conditionals + using Conditionals = DecisionTree; + + private: + Conditionals conditionals_; + + /** + * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. + */ + Sum asGaussianFactorGraphTree() const; + + public: + /// @name Constructors + /// @{ + + /// Defaut constructor, mainly for serialization. + GaussianMixture() = default; + + /** + * @brief Construct a new GaussianMixture object. + * + * @param continuousFrontals the continuous frontals. + * @param continuousParents the continuous parents. + * @param discreteParents the discrete parents. Will be placed last. + * @param conditionals a decision tree of GaussianConditionals. The number of + * conditionals should be C^(number of discrete parents), where C is the + * cardinality of the DiscreteKeys in discreteParents, since the + * discreteParents will be used as the labels in the decision tree. + */ + GaussianMixture(const KeyVector &continuousFrontals, + const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const Conditionals &conditionals); + + /** + * @brief Make a Gaussian Mixture from a list of Gaussian conditionals + * + * @param continuousFrontals The continuous frontal variables + * @param continuousParents The continuous parent variables + * @param discreteParents Discrete parents variables + * @param conditionals List of conditionals + */ + static This FromConditionals( + const KeyVector &continuousFrontals, const KeyVector &continuousParents, + const DiscreteKeys &discreteParents, + const std::vector &conditionals); + + /// @} + /// @name Testable + /// @{ + + /// Test equality with base HybridFactor + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; + + /* print utility */ + void print( + const std::string &s = "GaussianMixture\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + + /// @} + + /// Getter for the underlying Conditionals DecisionTree + const Conditionals &conditionals(); + + /** + * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while + * maintaining the decision tree structure. + * + * @param sum Decision Tree of Gaussian Factor Graphs + * @return Sum + */ + Sum add(const Sum &sum) const; +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.cpp b/gtsam/hybrid/GaussianMixtureFactor.cpp new file mode 100644 index 0000000000..a81cf341d9 --- /dev/null +++ b/gtsam/hybrid/GaussianMixtureFactor.cpp @@ -0,0 +1,94 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GaussianMixtureFactor.cpp + * @brief A set of Gaussian factors indexed by a set of discrete keys. + * @author Fan Jiang + * @author Varun Agrawal + * @author Frank Dellaert + * @date Mar 12, 2022 + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + +/* *******************************************************************************/ +GaussianMixtureFactor::GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors) + : Base(continuousKeys, discreteKeys), factors_(factors) {} + +/* *******************************************************************************/ +bool GaussianMixtureFactor::equals(const HybridFactor &lf, double tol) const { + const This *e = dynamic_cast(&lf); + return e != nullptr && Base::equals(*e, tol); +} + +/* *******************************************************************************/ +GaussianMixtureFactor GaussianMixtureFactor::FromFactors( + const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, + const std::vector &factors) { + Factors dt(discreteKeys, factors); + + return GaussianMixtureFactor(continuousKeys, discreteKeys, dt); +} + +/* *******************************************************************************/ +void GaussianMixtureFactor::print(const std::string &s, + const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + factors_.print( + "mixture = ", [&](Key k) { return formatter(k); }, + [&](const GaussianFactor::shared_ptr &gf) -> std::string { + RedirectCout rd; + if (!gf->empty()) + gf->print("", formatter); + else + return {"nullptr"}; + return rd.str(); + }); +} + +/* *******************************************************************************/ +const GaussianMixtureFactor::Factors &GaussianMixtureFactor::factors() { + return factors_; +} + +/* *******************************************************************************/ +GaussianMixtureFactor::Sum GaussianMixtureFactor::add( + const GaussianMixtureFactor::Sum &sum) const { + using Y = GaussianFactorGraph; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + result.push_back(graph2); + return result; + }; + const Sum tree = asGaussianFactorGraphTree(); + return sum.empty() ? tree : sum.apply(tree, add); +} + +/* *******************************************************************************/ +GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() + const { + auto wrap = [](const GaussianFactor::shared_ptr &factor) { + GaussianFactorGraph result; + result.push_back(factor); + return result; + }; + return {factors_, wrap}; +} +} // namespace gtsam diff --git a/gtsam/hybrid/GaussianMixtureFactor.h b/gtsam/hybrid/GaussianMixtureFactor.h new file mode 100644 index 0000000000..21770f836e --- /dev/null +++ b/gtsam/hybrid/GaussianMixtureFactor.h @@ -0,0 +1,121 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file GaussianMixtureFactor.h + * @brief A set of GaussianFactors, indexed by a set of discrete keys. + * @author Fan Jiang + * @author Varun Agrawal + * @author Frank Dellaert + * @date Mar 12, 2022 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + +class GaussianFactorGraph; + +using GaussianFactorVector = std::vector; + +/** + * @brief Implementation of a discrete conditional mixture factor. + * Implements a joint discrete-continuous factor where the discrete variable + * serves to "select" a mixture component corresponding to a GaussianFactor type + * of measurement. + * + * Represents the underlying Gaussian Mixture as a Decision Tree, where the set + * of discrete variables indexes to the continuous gaussian distribution. + * + */ +class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor { + public: + using Base = HybridFactor; + using This = GaussianMixtureFactor; + using shared_ptr = boost::shared_ptr; + + using Sum = DecisionTree; + + /// typedef for Decision Tree of Gaussian Factors + using Factors = DecisionTree; + + private: + /// Decision tree of Gaussian factors indexed by discrete keys. + Factors factors_; + + /** + * @brief Helper function to return factors and functional to create a + * DecisionTree of Gaussian Factor Graphs. + * + * @return Sum (DecisionTree) + */ + Sum asGaussianFactorGraphTree() const; + + public: + /// @name Constructors + /// @{ + + /// Default constructor, mainly for serialization. + GaussianMixtureFactor() = default; + + /** + * @brief Construct a new Gaussian Mixture Factor object. + * + * @param continuousKeys A vector of keys representing continuous variables. + * @param discreteKeys A vector of keys representing discrete variables and + * their cardinalities. + * @param factors The decision tree of Gaussian Factors stored as the mixture + * density. + */ + GaussianMixtureFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys, + const Factors &factors); + + static This FromFactors( + const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys, + const std::vector &factors); + + /// @} + /// @name Testable + /// @{ + + bool equals(const HybridFactor &lf, double tol = 1e-9) const override; + + void print( + const std::string &s = "HybridFactor\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + /// @} + + /// Getter for the underlying Gaussian Factor Decision Tree. + const Factors &factors(); + + /** + * @brief Combine the Gaussian Factor Graphs in `sum` and `this` while + * maintaining the original tree structure. + * + * @param sum Decision Tree of Gaussian Factor Graphs indexed by the + * variables. + * @return Sum + */ + Sum add(const Sum &sum) const; +}; + +// traits +template <> +struct traits : public Testable { +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp new file mode 100644 index 0000000000..1292711d89 --- /dev/null +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -0,0 +1,16 @@ +/* ---------------------------------------------------------------------------- + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + * See LICENSE for the license information + * -------------------------------------------------------------------------- */ + +/** + * @file HybridBayesNet.cpp + * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. + * @author Fan Jiang + * @date January 2022 + */ + +#include diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h new file mode 100644 index 0000000000..43eead2801 --- /dev/null +++ b/gtsam/hybrid/HybridBayesNet.h @@ -0,0 +1,41 @@ +/* ---------------------------------------------------------------------------- + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + * See LICENSE for the license information + * -------------------------------------------------------------------------- */ + +/** + * @file HybridBayesNet.h + * @brief A bayes net of Gaussian Conditionals indexed by discrete keys. + * @author Varun Agrawal + * @author Fan Jiang + * @author Frank Dellaert + * @date December 2021 + */ + +#pragma once + +#include +#include + +namespace gtsam { + +/** + * A hybrid Bayes net is a collection of HybridConditionals, which can have + * discrete conditionals, Gaussian mixtures, or pure Gaussian conditionals. + */ +class GTSAM_EXPORT HybridBayesNet : public BayesNet { + public: + using Base = BayesNet; + using This = HybridBayesNet; + using ConditionalType = HybridConditional; + using shared_ptr = boost::shared_ptr; + using sharedConditional = boost::shared_ptr; + + /** Construct empty bayes net */ + HybridBayesNet() = default; +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.cpp b/gtsam/hybrid/HybridBayesTree.cpp new file mode 100644 index 0000000000..d65270f91d --- /dev/null +++ b/gtsam/hybrid/HybridBayesTree.cpp @@ -0,0 +1,38 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridBayesTree.cpp + * @brief Hybrid Bayes Tree, the result of eliminating a + * HybridJunctionTree + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include +#include +#include +#include +#include + +namespace gtsam { + +// Instantiate base class +template class BayesTreeCliqueBase; +template class BayesTree; + +/* ************************************************************************* */ +bool HybridBayesTree::equals(const This& other, double tol) const { + return Base::equals(other, tol); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridBayesTree.h b/gtsam/hybrid/HybridBayesTree.h new file mode 100644 index 0000000000..0b89ca8c4b --- /dev/null +++ b/gtsam/hybrid/HybridBayesTree.h @@ -0,0 +1,117 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridBayesTree.h + * @brief Hybrid Bayes Tree, the result of eliminating a + * HybridJunctionTree + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +// Forward declarations +class HybridConditional; +class VectorValues; + +/* ************************************************************************* */ +/** A clique in a HybridBayesTree + * which is a HybridConditional internally. + */ +class GTSAM_EXPORT HybridBayesTreeClique + : public BayesTreeCliqueBase { + public: + typedef HybridBayesTreeClique This; + typedef BayesTreeCliqueBase + Base; + typedef boost::shared_ptr shared_ptr; + typedef boost::weak_ptr weak_ptr; + HybridBayesTreeClique() {} + virtual ~HybridBayesTreeClique() {} + HybridBayesTreeClique(const boost::shared_ptr& conditional) + : Base(conditional) {} +}; + +/* ************************************************************************* */ +/** A Bayes tree representing a Hybrid density */ +class GTSAM_EXPORT HybridBayesTree : public BayesTree { + private: + typedef BayesTree Base; + + public: + typedef HybridBayesTree This; + typedef boost::shared_ptr shared_ptr; + + /// @name Standard interface + /// @{ + /** Default constructor, creates an empty Bayes tree */ + HybridBayesTree() = default; + + /** Check equality */ + bool equals(const This& other, double tol = 1e-9) const; + + /// @} +}; + +/** + * @brief Class for Hybrid Bayes tree orphan subtrees. + * + * This does special stuff for the hybrid case + * + * @tparam CLIQUE + */ +template +class BayesTreeOrphanWrapper< + CLIQUE, typename std::enable_if< + boost::is_same::value> > + : public CLIQUE::ConditionalType { + public: + typedef CLIQUE CliqueType; + typedef typename CLIQUE::ConditionalType Base; + + boost::shared_ptr clique; + + /** + * @brief Construct a new Bayes Tree Orphan Wrapper object. + * + * @param clique Bayes tree clique. + */ + BayesTreeOrphanWrapper(const boost::shared_ptr& clique) + : clique(clique) { + // Store parent keys in our base type factor so that eliminating those + // parent keys will pull this subtree into the elimination. + this->keys_.assign(clique->conditional()->beginParents(), + clique->conditional()->endParents()); + this->discreteKeys_.assign(clique->conditional()->discreteKeys().begin(), + clique->conditional()->discreteKeys().end()); + } + + /// print utility + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { + clique->print(s + "stored clique", formatter); + } +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.cpp b/gtsam/hybrid/HybridConditional.cpp new file mode 100644 index 0000000000..8e071532dd --- /dev/null +++ b/gtsam/hybrid/HybridConditional.cpp @@ -0,0 +1,108 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridConditional.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include +#include +#include +#include + +namespace gtsam { + +/* ************************************************************************ */ +HybridConditional::HybridConditional(const KeyVector &continuousFrontals, + const DiscreteKeys &discreteFrontals, + const KeyVector &continuousParents, + const DiscreteKeys &discreteParents) + : HybridConditional( + CollectKeys( + {continuousFrontals.begin(), continuousFrontals.end()}, + KeyVector{continuousParents.begin(), continuousParents.end()}), + CollectDiscreteKeys( + {discreteFrontals.begin(), discreteFrontals.end()}, + {discreteParents.begin(), discreteParents.end()}), + continuousFrontals.size() + discreteFrontals.size()) {} + +/* ************************************************************************ */ +HybridConditional::HybridConditional( + boost::shared_ptr continuousConditional) + : HybridConditional(continuousConditional->keys(), {}, + continuousConditional->nrFrontals()) { + inner_ = continuousConditional; +} + +/* ************************************************************************ */ +HybridConditional::HybridConditional( + boost::shared_ptr discreteConditional) + : HybridConditional({}, discreteConditional->discreteKeys(), + discreteConditional->nrFrontals()) { + inner_ = discreteConditional; +} + +/* ************************************************************************ */ +HybridConditional::HybridConditional( + boost::shared_ptr gaussianMixture) + : BaseFactor(KeyVector(gaussianMixture->keys().begin(), + gaussianMixture->keys().begin() + + gaussianMixture->nrContinuous()), + gaussianMixture->discreteKeys()), + BaseConditional(gaussianMixture->nrFrontals()) { + inner_ = gaussianMixture; +} + +/* ************************************************************************ */ +void HybridConditional::print(const std::string &s, + const KeyFormatter &formatter) const { + std::cout << s; + + if (inner_) { + inner_->print("", formatter); + + } else { + if (isContinuous()) std::cout << "Continuous "; + if (isDiscrete()) std::cout << "Discrete "; + if (isHybrid()) std::cout << "Hybrid "; + BaseConditional::print("", formatter); + + std::cout << "P("; + size_t index = 0; + const size_t N = keys().size(); + const size_t contN = N - discreteKeys_.size(); + while (index < N) { + if (index > 0) { + if (index == nrFrontals_) + std::cout << " | "; + else + std::cout << ", "; + } + if (index < contN) { + std::cout << formatter(keys()[index]); + } else { + auto &dk = discreteKeys_[index - contN]; + std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")"; + } + index++; + } + } +} + +/* ************************************************************************ */ +bool HybridConditional::equals(const HybridFactor &other, double tol) const { + const This *e = dynamic_cast(&other); + return e != nullptr && BaseFactor::equals(*e, tol); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridConditional.h b/gtsam/hybrid/HybridConditional.h new file mode 100644 index 0000000000..3ba5da393b --- /dev/null +++ b/gtsam/hybrid/HybridConditional.h @@ -0,0 +1,177 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridConditional.h + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace gtsam { + +class HybridGaussianFactorGraph; + +/** + * Hybrid Conditional Density + * + * As a type-erased variant of: + * - DiscreteConditional + * - GaussianConditional + * - GaussianMixture + * + * The reason why this is important is that `Conditional` is a CRTP class. + * CRTP is static polymorphism such that all CRTP classes, while bearing the + * same name, are different classes not sharing a vtable. This prevents them + * from being contained in any container, and thus it is impossible to + * dynamically cast between them. A better option, as illustrated here, is + * treating them as an implementation detail - such that the hybrid mechanism + * does not know what is inside the HybridConditional. This prevents us from + * having diamond inheritances, and neutralized the need to change other + * components of GTSAM to make hybrid elimination work. + * + * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon + * talk (https://www.youtube.com/watch?v=s082Qmd_nHs). + */ +class GTSAM_EXPORT HybridConditional + : public HybridFactor, + public Conditional { + public: + // typedefs needed to play nice with gtsam + typedef HybridConditional This; ///< Typedef to this class + typedef boost::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef HybridFactor BaseFactor; ///< Typedef to our factor base class + typedef Conditional + BaseConditional; ///< Typedef to our conditional base class + + protected: + // Type-erased pointer to the inner type + boost::shared_ptr inner_; + + public: + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + HybridConditional() = default; + + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousKeys Vector of keys for continuous variables. + * @param discreteKeys Keys and cardinalities for discrete variables. + * @param nFrontals The number of frontal variables in the conditional. + */ + HybridConditional(const KeyVector& continuousKeys, + const DiscreteKeys& discreteKeys, size_t nFrontals) + : BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {} + + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousFrontals Vector of keys for continuous variables. + * @param discreteFrontals Keys and cardinalities for discrete variables. + * @param continuousParents Vector of keys for parent continuous variables. + * @param discreteParents Keys and cardinalities for parent discrete + * variables. + */ + HybridConditional(const KeyVector& continuousFrontals, + const DiscreteKeys& discreteFrontals, + const KeyVector& continuousParents, + const DiscreteKeys& discreteParents); + + /** + * @brief Construct a new Hybrid Conditional object + * + * @param continuousConditional Conditional used to create the + * HybridConditional. + */ + HybridConditional( + boost::shared_ptr continuousConditional); + + /** + * @brief Construct a new Hybrid Conditional object + * + * @param discreteConditional Conditional used to create the + * HybridConditional. + */ + HybridConditional(boost::shared_ptr discreteConditional); + + /** + * @brief Construct a new Hybrid Conditional object + * + * @param gaussianMixture Gaussian Mixture Conditional used to create the + * HybridConditional. + */ + HybridConditional( + boost::shared_ptr gaussianMixture); + + /** + * @brief Return HybridConditional as a GaussianMixture + * + * @return GaussianMixture::shared_ptr + */ + GaussianMixture::shared_ptr asMixture() { + if (!isHybrid()) throw std::invalid_argument("Not a mixture"); + return boost::static_pointer_cast(inner_); + } + + /** + * @brief Return conditional as a DiscreteConditional + * + * @return DiscreteConditional::shared_ptr + */ + DiscreteConditional::shared_ptr asDiscreteConditional() { + if (!isDiscrete()) + throw std::invalid_argument("Not a discrete conditional"); + return boost::static_pointer_cast(inner_); + } + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Hybrid Conditional: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// GTSAM-style equals + bool equals(const HybridFactor& other, double tol = 1e-9) const override; + + /// @} + + /// Get the type-erased pointer to the inner type + boost::shared_ptr inner() { return inner_; } + +}; // DiscreteConditional + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.cpp b/gtsam/hybrid/HybridDiscreteFactor.cpp new file mode 100644 index 0000000000..2bdcdee8cb --- /dev/null +++ b/gtsam/hybrid/HybridDiscreteFactor.cpp @@ -0,0 +1,53 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridDiscreteFactor.cpp + * @brief Wrapper for a discrete factor + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include + +#include + +#include "gtsam/discrete/DecisionTreeFactor.h" + +namespace gtsam { + +/* ************************************************************************ */ +// TODO(fan): THIS IS VERY VERY DIRTY! We need to get DiscreteFactor right! +HybridDiscreteFactor::HybridDiscreteFactor(DiscreteFactor::shared_ptr other) + : Base(boost::dynamic_pointer_cast(other) + ->discreteKeys()), + inner_(other) {} + +/* ************************************************************************ */ +HybridDiscreteFactor::HybridDiscreteFactor(DecisionTreeFactor &&dtf) + : Base(dtf.discreteKeys()), + inner_(boost::make_shared(std::move(dtf))) {} + +/* ************************************************************************ */ +bool HybridDiscreteFactor::equals(const HybridFactor &lf, double tol) const { + const This *e = dynamic_cast(&lf); + // TODO(Varun) How to compare inner_ when they are abstract types? + return e != nullptr && Base::equals(*e, tol); +} + +/* ************************************************************************ */ +void HybridDiscreteFactor::print(const std::string &s, + const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + inner_->print("inner: ", formatter); +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridDiscreteFactor.h b/gtsam/hybrid/HybridDiscreteFactor.h new file mode 100644 index 0000000000..9cbea8170f --- /dev/null +++ b/gtsam/hybrid/HybridDiscreteFactor.h @@ -0,0 +1,69 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridDiscreteFactor.h + * @date Mar 11, 2022 + * @author Fan Jiang + * @author Varun Agrawal + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * A HybridDiscreteFactor is a thin container for DiscreteFactor, which allows + * us to hide the implementation of DiscreteFactor and thus avoid diamond + * inheritance. + */ +class GTSAM_EXPORT HybridDiscreteFactor : public HybridFactor { + private: + DiscreteFactor::shared_ptr inner_; + + public: + using Base = HybridFactor; + using This = HybridDiscreteFactor; + using shared_ptr = boost::shared_ptr; + + /// @name Constructors + /// @{ + + // Implicit conversion from a shared ptr of DF + HybridDiscreteFactor(DiscreteFactor::shared_ptr other); + + // Forwarding constructor from concrete DecisionTreeFactor + HybridDiscreteFactor(DecisionTreeFactor &&dtf); + + /// @} + /// @name Testable + /// @{ + virtual bool equals(const HybridFactor &lf, double tol) const override; + + void print( + const std::string &s = "HybridFactor\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + + /// @} + + /// Return pointer to the internal discrete factor + DiscreteFactor::shared_ptr inner() const { return inner_; } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridEliminationTree.cpp b/gtsam/hybrid/HybridEliminationTree.cpp new file mode 100644 index 0000000000..c2df2dd600 --- /dev/null +++ b/gtsam/hybrid/HybridEliminationTree.cpp @@ -0,0 +1,42 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridEliminationTree.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include +#include + +namespace gtsam { + +// Instantiate base class +template class EliminationTree; + +/* ************************************************************************* */ +HybridEliminationTree::HybridEliminationTree( + const HybridGaussianFactorGraph& factorGraph, + const VariableIndex& structure, const Ordering& order) + : Base(factorGraph, structure, order) {} + +/* ************************************************************************* */ +HybridEliminationTree::HybridEliminationTree( + const HybridGaussianFactorGraph& factorGraph, const Ordering& order) + : Base(factorGraph, order) {} + +/* ************************************************************************* */ +bool HybridEliminationTree::equals(const This& other, double tol) const { + return Base::equals(other, tol); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridEliminationTree.h b/gtsam/hybrid/HybridEliminationTree.h new file mode 100644 index 0000000000..77a84fea85 --- /dev/null +++ b/gtsam/hybrid/HybridEliminationTree.h @@ -0,0 +1,69 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridEliminationTree.h + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * Elimination Tree type for Hybrid + */ +class GTSAM_EXPORT HybridEliminationTree + : public EliminationTree { + private: + friend class ::EliminationTreeTester; + + public: + typedef EliminationTree + Base; ///< Base class + typedef HybridEliminationTree This; ///< This class + typedef boost::shared_ptr shared_ptr; ///< Shared pointer to this class + + /// @name Constructors + /// @{ + + /** + * Build the elimination tree of a factor graph using pre-computed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree + */ + HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph, + const VariableIndex& structure, const Ordering& order); + + /** Build the elimination tree of a factor graph. Note that this has to + * compute the column structure as a VariableIndex, so if you already have + * this precomputed, use the other constructor instead. + * @param factorGraph The factor graph for which to build the elimination tree + */ + HybridEliminationTree(const HybridGaussianFactorGraph& factorGraph, + const Ordering& order); + + /// @} + + /** Test whether the tree is equal to another */ + bool equals(const This& other, double tol = 1e-9) const; +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.cpp b/gtsam/hybrid/HybridFactor.cpp new file mode 100644 index 0000000000..127c9761c0 --- /dev/null +++ b/gtsam/hybrid/HybridFactor.cpp @@ -0,0 +1,89 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridFactor.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include + +namespace gtsam { + +/* ************************************************************************ */ +KeyVector CollectKeys(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) { + KeyVector allKeys; + std::copy(continuousKeys.begin(), continuousKeys.end(), + std::back_inserter(allKeys)); + std::transform(discreteKeys.begin(), discreteKeys.end(), + std::back_inserter(allKeys), + [](const DiscreteKey &k) { return k.first; }); + return allKeys; +} + +/* ************************************************************************ */ +KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2) { + KeyVector allKeys; + std::copy(keys1.begin(), keys1.end(), std::back_inserter(allKeys)); + std::copy(keys2.begin(), keys2.end(), std::back_inserter(allKeys)); + return allKeys; +} + +/* ************************************************************************ */ +DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, + const DiscreteKeys &key2) { + DiscreteKeys allKeys; + std::copy(key1.begin(), key1.end(), std::back_inserter(allKeys)); + std::copy(key2.begin(), key2.end(), std::back_inserter(allKeys)); + return allKeys; +} + +/* ************************************************************************ */ +HybridFactor::HybridFactor(const KeyVector &keys) + : Base(keys), isContinuous_(true), nrContinuous_(keys.size()) {} + +/* ************************************************************************ */ +HybridFactor::HybridFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys) + : Base(CollectKeys(continuousKeys, discreteKeys)), + isDiscrete_((continuousKeys.size() == 0) && (discreteKeys.size() != 0)), + isContinuous_((continuousKeys.size() != 0) && (discreteKeys.size() == 0)), + isHybrid_((continuousKeys.size() != 0) && (discreteKeys.size() != 0)), + nrContinuous_(continuousKeys.size()), + discreteKeys_(discreteKeys) {} + +/* ************************************************************************ */ +HybridFactor::HybridFactor(const DiscreteKeys &discreteKeys) + : Base(CollectKeys({}, discreteKeys)), + isDiscrete_(true), + discreteKeys_(discreteKeys) {} + +/* ************************************************************************ */ +bool HybridFactor::equals(const HybridFactor &lf, double tol) const { + const This *e = dynamic_cast(&lf); + return e != nullptr && Base::equals(*e, tol) && + isDiscrete_ == e->isDiscrete_ && isContinuous_ == e->isContinuous_ && + isHybrid_ == e->isHybrid_ && nrContinuous_ == e->nrContinuous_; +} + +/* ************************************************************************ */ +void HybridFactor::print(const std::string &s, + const KeyFormatter &formatter) const { + std::cout << s; + if (isContinuous_) std::cout << "Continuous "; + if (isDiscrete_) std::cout << "Discrete "; + if (isHybrid_) std::cout << "Hybrid "; + this->printKeys("", formatter); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h new file mode 100644 index 0000000000..244fba4ccb --- /dev/null +++ b/gtsam/hybrid/HybridFactor.h @@ -0,0 +1,133 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridFactor.h + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +namespace gtsam { + +KeyVector CollectKeys(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys); +KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); +DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1, + const DiscreteKeys &key2); + +/** + * Base class for hybrid probabilistic factors + * + * Examples: + * - HybridGaussianFactor + * - HybridDiscreteFactor + * - GaussianMixtureFactor + * - GaussianMixture + */ +class GTSAM_EXPORT HybridFactor : public Factor { + private: + bool isDiscrete_ = false; + bool isContinuous_ = false; + bool isHybrid_ = false; + + size_t nrContinuous_ = 0; + + protected: + DiscreteKeys discreteKeys_; + + public: + // typedefs needed to play nice with gtsam + typedef HybridFactor This; ///< This class + typedef boost::shared_ptr + shared_ptr; ///< shared_ptr to this class + typedef Factor Base; ///< Our base class + + /// @name Standard Constructors + /// @{ + + /** Default constructor creates empty factor */ + HybridFactor() = default; + + /** + * @brief Construct hybrid factor from continuous keys. + * + * @param keys Vector of continuous keys. + */ + explicit HybridFactor(const KeyVector &keys); + + /** + * @brief Construct hybrid factor from discrete keys. + * + * @param keys Vector of discrete keys. + */ + explicit HybridFactor(const DiscreteKeys &discreteKeys); + + /** + * @brief Construct a new Hybrid Factor object. + * + * @param continuousKeys Vector of keys for continuous variables. + * @param discreteKeys Vector of keys for discrete variables. + */ + HybridFactor(const KeyVector &continuousKeys, + const DiscreteKeys &discreteKeys); + + /// Virtual destructor + virtual ~HybridFactor() = default; + + /// @} + /// @name Testable + /// @{ + + /// equals + virtual bool equals(const HybridFactor &lf, double tol = 1e-9) const; + + /// print + void print( + const std::string &s = "HybridFactor\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + + /// @} + /// @name Standard Interface + /// @{ + + /// True if this is a factor of discrete variables only. + bool isDiscrete() const { return isDiscrete_; } + + /// True if this is a factor of continuous variables only. + bool isContinuous() const { return isContinuous_; } + + /// True is this is a Discrete-Continuous factor. + bool isHybrid() const { return isHybrid_; } + + /// Return the number of continuous variables in this factor. + size_t nrContinuous() const { return nrContinuous_; } + + /// Return vector of discrete keys. + DiscreteKeys discreteKeys() const { return discreteKeys_; } + + /// @} +}; +// HybridFactor + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.cpp b/gtsam/hybrid/HybridGaussianFactor.cpp new file mode 100644 index 0000000000..59d20fb794 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianFactor.cpp @@ -0,0 +1,47 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianFactor.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include + +#include + +namespace gtsam { + +/* ************************************************************************* */ +HybridGaussianFactor::HybridGaussianFactor(GaussianFactor::shared_ptr other) + : Base(other->keys()), inner_(other) {} + +/* ************************************************************************* */ +HybridGaussianFactor::HybridGaussianFactor(JacobianFactor &&jf) + : Base(jf.keys()), + inner_(boost::make_shared(std::move(jf))) {} + +/* ************************************************************************* */ +bool HybridGaussianFactor::equals(const HybridFactor &other, double tol) const { + const This *e = dynamic_cast(&other); + // TODO(Varun) How to compare inner_ when they are abstract types? + return e != nullptr && Base::equals(*e, tol); +} + +/* ************************************************************************* */ +void HybridGaussianFactor::print(const std::string &s, + const KeyFormatter &formatter) const { + HybridFactor::print(s, formatter); + inner_->print("inner: ", formatter); +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactor.h b/gtsam/hybrid/HybridGaussianFactor.h new file mode 100644 index 0000000000..2a92c717c2 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianFactor.h @@ -0,0 +1,67 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianFactor.h + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + +/** + * A HybridGaussianFactor is a layer over GaussianFactor so that we do not have + * a diamond inheritance i.e. an extra factor type that inherits from both + * HybridFactor and GaussianFactor. + */ +class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor { + private: + GaussianFactor::shared_ptr inner_; + + public: + using Base = HybridFactor; + using This = HybridGaussianFactor; + using shared_ptr = boost::shared_ptr; + + // Explicit conversion from a shared ptr of GF + explicit HybridGaussianFactor(GaussianFactor::shared_ptr other); + + // Forwarding constructor from concrete JacobianFactor + explicit HybridGaussianFactor(JacobianFactor &&jf); + + public: + /// @name Testable + /// @{ + + /// Check equality. + virtual bool equals(const HybridFactor &lf, double tol) const override; + + /// GTSAM print utility. + void print( + const std::string &s = "HybridFactor\n", + const KeyFormatter &formatter = DefaultKeyFormatter) const override; + + /// @} + + GaussianFactor::shared_ptr inner() const { return inner_; } +}; + +// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp new file mode 100644 index 0000000000..88730cae95 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -0,0 +1,369 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianFactorGraph.cpp + * @brief Hybrid factor graph that uses type erasure + * @author Fan Jiang + * @author Varun Agrawal + * @author Frank Dellaert + * @date Mar 11, 2022 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace gtsam { + +template class EliminateableFactorGraph; + +/* ************************************************************************ */ +static GaussianMixtureFactor::Sum &addGaussian( + GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { + using Y = GaussianFactorGraph; + // If the decision tree is not intiialized, then intialize it. + if (sum.empty()) { + GaussianFactorGraph result; + result.push_back(factor); + sum = GaussianMixtureFactor::Sum(result); + + } else { + auto add = [&factor](const Y &graph) { + auto result = graph; + result.push_back(factor); + return result; + }; + sum = sum.apply(add); + } + return sum; +} + +/* ************************************************************************ */ +GaussianMixtureFactor::Sum sumFrontals( + const HybridGaussianFactorGraph &factors) { + // sum out frontals, this is the factor on the separator + gttic(sum); + + GaussianMixtureFactor::Sum sum; + std::vector deferredFactors; + + for (auto &f : factors) { + if (f->isHybrid()) { + if (auto cgmf = boost::dynamic_pointer_cast(f)) { + sum = cgmf->add(sum); + } + + if (auto gm = boost::dynamic_pointer_cast(f)) { + sum = gm->asMixture()->add(sum); + } + + } else if (f->isContinuous()) { + deferredFactors.push_back( + boost::dynamic_pointer_cast(f)->inner()); + } else { + // We need to handle the case where the object is actually an + // BayesTreeOrphanWrapper! + auto orphan = boost::dynamic_pointer_cast< + BayesTreeOrphanWrapper>(f); + if (!orphan) { + auto &fr = *f; + throw std::invalid_argument( + std::string("factor is discrete in continuous elimination") + + typeid(fr).name()); + } + } + } + + for (auto &f : deferredFactors) { + sum = addGaussian(sum, f); + } + + gttoc(sum); + + return sum; +} + +/* ************************************************************************ */ +std::pair +continuousElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + GaussianFactorGraph gfg; + for (auto &fp : factors) { + if (auto ptr = boost::dynamic_pointer_cast(fp)) { + gfg.push_back(ptr->inner()); + } else if (auto p = + boost::static_pointer_cast(fp)->inner()) { + gfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapped conditional + } + } + + auto result = EliminatePreferCholesky(gfg, frontalKeys); + return {boost::make_shared(result.first), + boost::make_shared(result.second)}; +} + +/* ************************************************************************ */ +std::pair +discreteElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + DiscreteFactorGraph dfg; + for (auto &fp : factors) { + if (auto ptr = boost::dynamic_pointer_cast(fp)) { + dfg.push_back(ptr->inner()); + } else if (auto p = + boost::static_pointer_cast(fp)->inner()) { + dfg.push_back(boost::static_pointer_cast(p)); + } else { + // It is an orphan wrapper + } + } + + auto result = EliminateDiscrete(dfg, frontalKeys); + + return {boost::make_shared(result.first), + boost::make_shared(result.second)}; +} + +/* ************************************************************************ */ +std::pair +hybridElimination(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys, + const KeySet &continuousSeparator, + const std::set &discreteSeparatorSet) { + // NOTE: since we use the special JunctionTree, + // only possiblity is continuous conditioned on discrete. + DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), + discreteSeparatorSet.end()); + + // sum out frontals, this is the factor on the separator + GaussianMixtureFactor::Sum sum = sumFrontals(factors); + + using EliminationPair = GaussianFactorGraph::EliminationResult; + + KeyVector keysOfEliminated; // Not the ordering + KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? + + // This is the elimination method on the leaf nodes + auto eliminate = [&](const GaussianFactorGraph &graph) + -> GaussianFactorGraph::EliminationResult { + if (graph.empty()) { + return {nullptr, nullptr}; + } + auto result = EliminatePreferCholesky(graph, frontalKeys); + if (keysOfEliminated.empty()) { + keysOfEliminated = + result.first->keys(); // Initialize the keysOfEliminated to be the + } + // keysOfEliminated of the GaussianConditional + if (keysOfSeparator.empty()) { + keysOfSeparator = result.second->keys(); + } + return result; + }; + + // Perform elimination! + DecisionTree eliminationResults(sum, eliminate); + + // Separate out decision tree into conditionals and remaining factors. + auto pair = unzip(eliminationResults); + + const GaussianMixtureFactor::Factors &separatorFactors = pair.second; + + // Create the GaussianMixture from the conditionals + auto conditional = boost::make_shared( + frontalKeys, keysOfSeparator, discreteSeparator, pair.first); + + // If there are no more continuous parents, then we should create here a + // DiscreteFactor, with the error for each discrete choice. + if (keysOfSeparator.empty()) { + VectorValues empty_values; + auto factorError = [&](const GaussianFactor::shared_ptr &factor) { + if (!factor) return 0.0; // TODO(fan): does this make sense? + return exp(-factor->error(empty_values)); + }; + DecisionTree fdt(separatorFactors, factorError); + auto discreteFactor = + boost::make_shared(discreteSeparator, fdt); + + return {boost::make_shared(conditional), + boost::make_shared(discreteFactor)}; + + } else { + // Create a resulting DCGaussianMixture on the separator. + auto factor = boost::make_shared( + KeyVector(continuousSeparator.begin(), continuousSeparator.end()), + discreteSeparator, separatorFactors); + return {boost::make_shared(conditional), factor}; + } +} +/* ************************************************************************ */ +std::pair // +EliminateHybrid(const HybridGaussianFactorGraph &factors, + const Ordering &frontalKeys) { + // NOTE: Because we are in the Conditional Gaussian regime there are only + // a few cases: + // 1. continuous variable, make a Gaussian Mixture if there are hybrid + // factors; + // 2. continuous variable, we make a Gaussian Factor if there are no hybrid + // factors; + // 3. discrete variable, no continuous factor is allowed + // (escapes Conditional Gaussian regime), if discrete only we do the discrete + // elimination. + + // However it is not that simple. During elimination it is possible that the + // multifrontal needs to eliminate an ordering that contains both Gaussian and + // hybrid variables, for example x1, c1. + // In this scenario, we will have a density P(x1, c1) that is a Conditional + // Linear Gaussian P(x1|c1)P(c1) (see Murphy02). + + // The issue here is that, how can we know which variable is discrete if we + // unify Values? Obviously we can tell using the factors, but is that fast? + + // In the case of multifrontal, we will need to use a constrained ordering + // so that the discrete parts will be guaranteed to be eliminated last! + // Because of all these reasons, we carefully consider how to + // implement the hybrid factors so that we do not get poor performance. + + // The first thing is how to represent the GaussianMixture. + // A very possible scenario is that the incoming factors will have different + // levels of discrete keys. For example, imagine we are going to eliminate the + // fragment: $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid. + // Now we will need to know how to retrieve the corresponding continuous + // densities for the assignment (c1,c2,c3) (OR (c2,c3,c1), note there is NO + // defined order!). We also need to consider when there is pruning. Two + // mixture factors could have different pruning patterns - one could have + // (c1=0,c2=1) pruned, and another could have (c2=0,c3=1) pruned, and this + // creates a big problem in how to identify the intersection of non-pruned + // branches. + + // Our approach is first building the collection of all discrete keys. After + // that we enumerate the space of all key combinations *lazily* so that the + // exploration branch terminates whenever an assignment yields NULL in any of + // the hybrid factors. + + // When the number of assignments is large we may encounter stack overflows. + // However this is also the case with iSAM2, so no pressure :) + + // PREPROCESS: Identify the nature of the current elimination + std::unordered_map mapFromKeyToDiscreteKey; + std::set discreteSeparatorSet; + std::set discreteFrontals; + + KeySet separatorKeys; + KeySet allContinuousKeys; + KeySet continuousFrontals; + KeySet continuousSeparator; + + // This initializes separatorKeys and mapFromKeyToDiscreteKey + for (auto &&factor : factors) { + separatorKeys.insert(factor->begin(), factor->end()); + if (!factor->isContinuous()) { + for (auto &k : factor->discreteKeys()) { + mapFromKeyToDiscreteKey[k.first] = k; + } + } + } + + // remove frontals from separator + for (auto &k : frontalKeys) { + separatorKeys.erase(k); + } + + // Fill in discrete frontals and continuous frontals for the end result + for (auto &k : frontalKeys) { + if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { + discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); + } else { + continuousFrontals.insert(k); + allContinuousKeys.insert(k); + } + } + + // Fill in discrete frontals and continuous frontals for the end result + for (auto &k : separatorKeys) { + if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { + discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); + } else { + continuousSeparator.insert(k); + allContinuousKeys.insert(k); + } + } + + // NOTE: We should really defer the product here because of pruning + + // Case 1: we are only dealing with continuous + if (mapFromKeyToDiscreteKey.empty() && !allContinuousKeys.empty()) { + return continuousElimination(factors, frontalKeys); + } + + // Case 2: we are only dealing with discrete + if (allContinuousKeys.empty()) { + return discreteElimination(factors, frontalKeys); + } + + // Case 3: We are now in the hybrid land! + return hybridElimination(factors, frontalKeys, continuousSeparator, + discreteSeparatorSet); +} + +/* ************************************************************************ */ +void HybridGaussianFactorGraph::add(JacobianFactor &&factor) { + FactorGraph::add(boost::make_shared(std::move(factor))); +} + +/* ************************************************************************ */ +void HybridGaussianFactorGraph::add(JacobianFactor::shared_ptr factor) { + FactorGraph::add(boost::make_shared(factor)); +} + +/* ************************************************************************ */ +void HybridGaussianFactorGraph::add(DecisionTreeFactor &&factor) { + FactorGraph::add(boost::make_shared(std::move(factor))); +} + +/* ************************************************************************ */ +void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) { + FactorGraph::add(boost::make_shared(factor)); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h new file mode 100644 index 0000000000..0188aa652c --- /dev/null +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -0,0 +1,118 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianFactorGraph.h + * @brief Linearized Hybrid factor graph that uses type erasure + * @author Fan Jiang + * @date Mar 11, 2022 + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + +// Forward declarations +class HybridGaussianFactorGraph; +class HybridConditional; +class HybridBayesNet; +class HybridEliminationTree; +class HybridBayesTree; +class HybridJunctionTree; +class DecisionTreeFactor; + +class JacobianFactor; + +/** Main elimination function for HybridGaussianFactorGraph */ +GTSAM_EXPORT +std::pair, HybridFactor::shared_ptr> +EliminateHybrid(const HybridGaussianFactorGraph& factors, const Ordering& keys); + +/* ************************************************************************* */ +template <> +struct EliminationTraits { + typedef HybridFactor FactorType; ///< Type of factors in factor graph + typedef HybridGaussianFactorGraph + FactorGraphType; ///< Type of the factor graph (e.g. + ///< HybridGaussianFactorGraph) + typedef HybridConditional + ConditionalType; ///< Type of conditionals from elimination + typedef HybridBayesNet + BayesNetType; ///< Type of Bayes net from sequential elimination + typedef HybridEliminationTree + EliminationTreeType; ///< Type of elimination tree + typedef HybridBayesTree BayesTreeType; ///< Type of Bayes tree + typedef HybridJunctionTree + JunctionTreeType; ///< Type of Junction tree + /// The default dense elimination function + static std::pair, + boost::shared_ptr > + DefaultEliminate(const FactorGraphType& factors, const Ordering& keys) { + return EliminateHybrid(factors, keys); + } +}; + +/** + * Gaussian Hybrid Factor Graph + * ----------------------- + * This is the linearized version of a hybrid factor graph. + * Everything inside needs to be hybrid factor or hybrid conditional. + */ +class GTSAM_EXPORT HybridGaussianFactorGraph + : public FactorGraph, + public EliminateableFactorGraph { + public: + using Base = FactorGraph; + using This = HybridGaussianFactorGraph; ///< this class + using BaseEliminateable = + EliminateableFactorGraph; ///< for elimination + using shared_ptr = boost::shared_ptr; ///< shared_ptr to This + + using Values = gtsam::Values; ///< backwards compatibility + using Indices = KeyVector; ///> map from keys to values + + /// @name Constructors + /// @{ + + HybridGaussianFactorGraph() = default; + + /** + * Implicit copy/downcast constructor to override explicit template container + * constructor. In BayesTree this is used for: + * `cachedSeparatorMarginal_.reset(*separatorMarginal)` + * */ + template + HybridGaussianFactorGraph(const FactorGraph& graph) + : Base(graph) {} + + /// @} + + using FactorGraph::add; + + /// Add a Jacobian factor to the factor graph. + void add(JacobianFactor&& factor); + + /// Add a Jacobian factor as a shared ptr. + void add(boost::shared_ptr factor); + + /// Add a DecisionTreeFactor to the factor graph. + void add(DecisionTreeFactor&& factor); + + /// Add a DecisionTreeFactor as a shared ptr. + void add(boost::shared_ptr factor); +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.cpp b/gtsam/hybrid/HybridGaussianISAM.cpp new file mode 100644 index 0000000000..7783a88ddc --- /dev/null +++ b/gtsam/hybrid/HybridGaussianISAM.cpp @@ -0,0 +1,101 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianISAM.h + * @date March 31, 2022 + * @author Fan Jiang + * @author Frank Dellaert + * @author Richard Roberts + */ + +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +// Instantiate base class +// template class ISAM; + +/* ************************************************************************* */ +HybridGaussianISAM::HybridGaussianISAM() {} + +/* ************************************************************************* */ +HybridGaussianISAM::HybridGaussianISAM(const HybridBayesTree& bayesTree) + : Base(bayesTree) {} + +/* ************************************************************************* */ +void HybridGaussianISAM::updateInternal( + const HybridGaussianFactorGraph& newFactors, + HybridBayesTree::Cliques* orphans, + const HybridBayesTree::Eliminate& function) { + // Remove the contaminated part of the Bayes tree + BayesNetType bn; + const KeySet newFactorKeys = newFactors.keys(); + if (!this->empty()) { + KeyVector keyVector(newFactorKeys.begin(), newFactorKeys.end()); + this->removeTop(keyVector, &bn, orphans); + } + + // Add the removed top and the new factors + FactorGraphType factors; + factors += bn; + factors += newFactors; + + // Add the orphaned subtrees + for (const sharedClique& orphan : *orphans) + factors += boost::make_shared >(orphan); + + KeySet allDiscrete; + for (auto& factor : factors) { + for (auto& k : factor->discreteKeys()) { + allDiscrete.insert(k.first); + } + } + KeyVector newKeysDiscreteLast; + for (auto& k : newFactorKeys) { + if (!allDiscrete.exists(k)) { + newKeysDiscreteLast.push_back(k); + } + } + std::copy(allDiscrete.begin(), allDiscrete.end(), + std::back_inserter(newKeysDiscreteLast)); + + // KeyVector new + + // Get an ordering where the new keys are eliminated last + const VariableIndex index(factors); + const Ordering ordering = Ordering::ColamdConstrainedLast( + index, KeyVector(newKeysDiscreteLast.begin(), newKeysDiscreteLast.end()), + true); + + // eliminate all factors (top, added, orphans) into a new Bayes tree + auto bayesTree = factors.eliminateMultifrontal(ordering, function, index); + + // Re-add into Bayes tree data structures + this->roots_.insert(this->roots_.end(), bayesTree->roots().begin(), + bayesTree->roots().end()); + this->nodes_.insert(bayesTree->nodes().begin(), bayesTree->nodes().end()); +} + +/* ************************************************************************* */ +void HybridGaussianISAM::update(const HybridGaussianFactorGraph& newFactors, + const HybridBayesTree::Eliminate& function) { + Cliques orphans; + this->updateInternal(newFactors, &orphans, function); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridGaussianISAM.h b/gtsam/hybrid/HybridGaussianISAM.h new file mode 100644 index 0000000000..d5b6271da7 --- /dev/null +++ b/gtsam/hybrid/HybridGaussianISAM.h @@ -0,0 +1,70 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridGaussianISAM.h + * @date March 31, 2022 + * @author Fan Jiang + * @author Frank Dellaert + * @author Richard Roberts + */ + +#pragma once + +#include +#include +#include +#include + +namespace gtsam { + +class GTSAM_EXPORT HybridGaussianISAM : public ISAM { + public: + typedef ISAM Base; + typedef HybridGaussianISAM This; + typedef boost::shared_ptr shared_ptr; + + /// @name Standard Constructors + /// @{ + + /** Create an empty Bayes Tree */ + HybridGaussianISAM(); + + /** Copy constructor */ + HybridGaussianISAM(const HybridBayesTree& bayesTree); + + /// @} + + private: + /// Internal method that performs the ISAM update. + void updateInternal( + const HybridGaussianFactorGraph& newFactors, + HybridBayesTree::Cliques* orphans, + const HybridBayesTree::Eliminate& function = + HybridBayesTree::EliminationTraitsType::DefaultEliminate); + + public: + /** + * @brief Perform update step with new factors. + * + * @param newFactors Factor graph of new factors to add and eliminate. + * @param function Elimination function. + */ + void update(const HybridGaussianFactorGraph& newFactors, + const HybridBayesTree::Eliminate& function = + HybridBayesTree::EliminationTraitsType::DefaultEliminate); +}; + +/// traits +template <> +struct traits : public Testable {}; + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp new file mode 100644 index 0000000000..7725742cf6 --- /dev/null +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -0,0 +1,173 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridJunctionTree.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include +#include +#include +#include +#include + +#include + +namespace gtsam { + +// Instantiate base classes +template class EliminatableClusterTree; +template class JunctionTree; + +struct HybridConstructorTraversalData { + typedef + typename JunctionTree::Node + Node; + typedef + typename JunctionTree::sharedNode sharedNode; + + HybridConstructorTraversalData* const parentData; + sharedNode junctionTreeNode; + FastVector childSymbolicConditionals; + FastVector childSymbolicFactors; + KeySet discreteKeys; + + // Small inner class to store symbolic factors + class SymbolicFactors : public FactorGraph {}; + + HybridConstructorTraversalData(HybridConstructorTraversalData* _parentData) + : parentData(_parentData) {} + + // Pre-order visitor function + static HybridConstructorTraversalData ConstructorTraversalVisitorPre( + const boost::shared_ptr& node, + HybridConstructorTraversalData& parentData) { + // On the pre-order pass, before children have been visited, we just set up + // a traversal data structure with its own JT node, and create a child + // pointer in its parent. + HybridConstructorTraversalData data = + HybridConstructorTraversalData(&parentData); + data.junctionTreeNode = boost::make_shared(node->key, node->factors); + parentData.junctionTreeNode->addChild(data.junctionTreeNode); + + for (HybridFactor::shared_ptr& f : node->factors) { + for (auto& k : f->discreteKeys()) { + data.discreteKeys.insert(k.first); + } + } + + return data; + } + + // Post-order visitor function + static void ConstructorTraversalVisitorPostAlg2( + const boost::shared_ptr& ETreeNode, + const HybridConstructorTraversalData& data) { + // In this post-order visitor, we combine the symbolic elimination results + // from the elimination tree children and symbolically eliminate the current + // elimination tree node. We then check whether each of our elimination + // tree child nodes should be merged with us. The check for this is that + // our number of symbolic elimination parents is exactly 1 less than + // our child's symbolic elimination parents - this condition indicates that + // eliminating the current node did not introduce any parents beyond those + // already in the child-> + + // Do symbolic elimination for this node + SymbolicFactors symbolicFactors; + symbolicFactors.reserve(ETreeNode->factors.size() + + data.childSymbolicFactors.size()); + // Add ETree node factors + symbolicFactors += ETreeNode->factors; + // Add symbolic factors passed up from children + symbolicFactors += data.childSymbolicFactors; + + Ordering keyAsOrdering; + keyAsOrdering.push_back(ETreeNode->key); + SymbolicConditional::shared_ptr conditional; + SymbolicFactor::shared_ptr separatorFactor; + boost::tie(conditional, separatorFactor) = + internal::EliminateSymbolic(symbolicFactors, keyAsOrdering); + + // Store symbolic elimination results in the parent + data.parentData->childSymbolicConditionals.push_back(conditional); + data.parentData->childSymbolicFactors.push_back(separatorFactor); + data.parentData->discreteKeys.merge(data.discreteKeys); + + sharedNode node = data.junctionTreeNode; + const FastVector& childConditionals = + data.childSymbolicConditionals; + node->problemSize_ = (int)(conditional->size() * symbolicFactors.size()); + + // Merge our children if they are in our clique - if our conditional has + // exactly one fewer parent than our child's conditional. + const size_t nrParents = conditional->nrParents(); + const size_t nrChildren = node->nrChildren(); + assert(childConditionals.size() == nrChildren); + + // decide which children to merge, as index into children + std::vector nrChildrenFrontals = node->nrFrontalsOfChildren(); + std::vector merge(nrChildren, false); + size_t nrFrontals = 1; + for (size_t i = 0; i < nrChildren; i++) { + // Check if we should merge the i^th child + if (nrParents + nrFrontals == childConditionals[i]->nrParents()) { + const bool myType = + data.discreteKeys.exists(conditional->frontals()[0]); + const bool theirType = + data.discreteKeys.exists(childConditionals[i]->frontals()[0]); + + if (myType == theirType) { + // Increment number of frontal variables + nrFrontals += nrChildrenFrontals[i]; + merge[i] = true; + } + } + } + + // now really merge + node->mergeChildren(merge); + } +}; + +/* ************************************************************************* */ +HybridJunctionTree::HybridJunctionTree( + const HybridEliminationTree& eliminationTree) { + gttic(JunctionTree_FromEliminationTree); + // Here we rely on the BayesNet having been produced by this elimination tree, + // such that the conditionals are arranged in DFS post-order. We traverse the + // elimination tree, and inspect the symbolic conditional corresponding to + // each node. The elimination tree node is added to the same clique with its + // parent if it has exactly one more Bayes net conditional parent than + // does its elimination tree parent. + + // Traverse the elimination tree, doing symbolic elimination and merging nodes + // as we go. Gather the created junction tree roots in a dummy Node. + typedef HybridConstructorTraversalData Data; + Data rootData(0); + rootData.junctionTreeNode = + boost::make_shared(); // Make a dummy node to gather + // the junction tree roots + treeTraversal::DepthFirstForest(eliminationTree, rootData, + Data::ConstructorTraversalVisitorPre, + Data::ConstructorTraversalVisitorPostAlg2); + + // Assign roots from the dummy node + this->addChildrenAsRoots(rootData.junctionTreeNode); + + // Transfer remaining factors from elimination tree + Base::remainingFactors_ = eliminationTree.remainingFactors(); +} + +} // namespace gtsam diff --git a/gtsam/hybrid/HybridJunctionTree.h b/gtsam/hybrid/HybridJunctionTree.h new file mode 100644 index 0000000000..cad1e15a1e --- /dev/null +++ b/gtsam/hybrid/HybridJunctionTree.h @@ -0,0 +1,71 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file HybridJunctionTree.h + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#pragma once + +#include +#include +#include + +namespace gtsam { + +// Forward declarations +class HybridEliminationTree; + +/** + * An EliminatableClusterTree, i.e., a set of variable clusters with factors, + * arranged in a tree, with the additional property that it represents the + * clique tree associated with a Bayes net. + * + * In GTSAM a junction tree is an intermediate data structure in multifrontal + * variable elimination. Each node is a cluster of factors, along with a + * clique of variables that are eliminated all at once. In detail, every node k + * represents a clique (maximal fully connected subset) of an associated chordal + * graph, such as a chordal Bayes net resulting from elimination. + * + * The difference with the BayesTree is that a JunctionTree stores factors, + * whereas a BayesTree stores conditionals, that are the product of eliminating + * the factors in the corresponding JunctionTree cliques. + * + * The tree structure and elimination method are exactly analogous to the + * EliminationTree, except that in the JunctionTree, at each node multiple + * variables are eliminated at a time. + * + * \addtogroup Multifrontal + * \nosubgrouping + */ +class GTSAM_EXPORT HybridJunctionTree + : public JunctionTree { + public: + typedef JunctionTree + Base; ///< Base class + typedef HybridJunctionTree This; ///< This class + typedef boost::shared_ptr shared_ptr; ///< Shared pointer to this class + + /** + * Build the elimination tree of a factor graph using precomputed column + * structure. + * @param factorGraph The factor graph for which to build the elimination tree + * @param structure The set of factors involving each variable. If this is + * not precomputed, you can call the Create(const FactorGraph&) + * named constructor instead. + * @return The elimination tree + */ + HybridJunctionTree(const HybridEliminationTree& eliminationTree); +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/hybrid.i b/gtsam/hybrid/hybrid.i new file mode 100644 index 0000000000..bbe1e2400d --- /dev/null +++ b/gtsam/hybrid/hybrid.i @@ -0,0 +1,145 @@ +//************************************************************************* +// hybrid +//************************************************************************* + +namespace gtsam { + +#include +virtual class HybridFactor { + void print(string s = "HybridFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridFactor& other, double tol = 1e-9) const; + bool empty() const; + size_t size() const; + gtsam::KeyVector keys() const; +}; + +#include +virtual class HybridConditional { + void print(string s = "Hybrid Conditional\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; + Factor* inner(); +}; + +#include +class GaussianMixtureFactor : gtsam::HybridFactor { + static GaussianMixtureFactor FromFactors( + const gtsam::KeyVector& continuousKeys, + const gtsam::DiscreteKeys& discreteKeys, + const std::vector& factorsList); + + void print(string s = "GaussianMixtureFactor\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class GaussianMixture : gtsam::HybridFactor { + static GaussianMixture FromConditionals( + const gtsam::KeyVector& continuousFrontals, + const gtsam::KeyVector& continuousParents, + const gtsam::DiscreteKeys& discreteParents, + const std::vector& + conditionalsList); + + void print(string s = "GaussianMixture\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +#include +class HybridBayesTreeClique { + HybridBayesTreeClique(); + HybridBayesTreeClique(const gtsam::HybridConditional* conditional); + const gtsam::HybridConditional* conditional() const; + bool isRoot() const; + // double evaluate(const gtsam::HybridValues& values) const; +}; + +#include +class HybridBayesTree { + HybridBayesTree(); + void print(string s = "HybridBayesTree\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridBayesTree& other, double tol = 1e-9) const; + + size_t size() const; + bool empty() const; + const HybridBayesTreeClique* operator[](size_t j) const; + + string dot(const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; +}; + +class HybridBayesNet { + HybridBayesNet(); + void add(const gtsam::HybridConditional& s); + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::HybridConditional* at(size_t i) const; + void print(string s = "HybridBayesNet\n", + const gtsam::KeyFormatter& keyFormatter = + gtsam::DefaultKeyFormatter) const; + bool equals(const gtsam::HybridBayesNet& other, double tol = 1e-9) const; + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; + void saveGraph( + string s, + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; +}; + +#include +class HybridGaussianFactorGraph { + HybridGaussianFactorGraph(); + HybridGaussianFactorGraph(const gtsam::HybridBayesNet& bayesNet); + + // Building the graph + void push_back(const gtsam::HybridFactor* factor); + void push_back(const gtsam::HybridConditional* conditional); + void push_back(const gtsam::HybridGaussianFactorGraph& graph); + void push_back(const gtsam::HybridBayesNet& bayesNet); + void push_back(const gtsam::HybridBayesTree& bayesTree); + void push_back(const gtsam::GaussianMixtureFactor* gmm); + + void add(gtsam::DecisionTreeFactor* factor); + void add(gtsam::JacobianFactor* factor); + + bool empty() const; + size_t size() const; + gtsam::KeySet keys() const; + const gtsam::HybridFactor* at(size_t i) const; + + void print(string s = "") const; + bool equals(const gtsam::HybridGaussianFactorGraph& fg, double tol = 1e-9) const; + + gtsam::HybridBayesNet* eliminateSequential(); + gtsam::HybridBayesNet* eliminateSequential( + gtsam::Ordering::OrderingType type); + gtsam::HybridBayesNet* eliminateSequential(const gtsam::Ordering& ordering); + pair + eliminatePartialSequential(const gtsam::Ordering& ordering); + + gtsam::HybridBayesTree* eliminateMultifrontal(); + gtsam::HybridBayesTree* eliminateMultifrontal( + gtsam::Ordering::OrderingType type); + gtsam::HybridBayesTree* eliminateMultifrontal( + const gtsam::Ordering& ordering); + pair + eliminatePartialMultifrontal(const gtsam::Ordering& ordering); + + string dot( + const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, + const gtsam::DotWriter& writer = gtsam::DotWriter()) const; +}; + +} // namespace gtsam diff --git a/gtsam/hybrid/tests/CMakeLists.txt b/gtsam/hybrid/tests/CMakeLists.txt new file mode 100644 index 0000000000..06ad2c5051 --- /dev/null +++ b/gtsam/hybrid/tests/CMakeLists.txt @@ -0,0 +1 @@ +gtsamAddTestsGlob(hybrid "test*.cpp" "" "gtsam") diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h new file mode 100644 index 0000000000..c081b8e87e --- /dev/null +++ b/gtsam/hybrid/tests/Switching.h @@ -0,0 +1,86 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file Switching.h + * @date Mar 11, 2022 + * @author Varun Agrawal + * @author Fan Jiang + */ + +#include +#include +#include +#include +#include +#include + +#pragma once + +using gtsam::symbol_shorthand::C; +using gtsam::symbol_shorthand::X; + +namespace gtsam { +inline HybridGaussianFactorGraph::shared_ptr makeSwitchingChain( + size_t n, std::function keyFunc = X, + std::function dKeyFunc = C) { + HybridGaussianFactorGraph hfg; + + hfg.add(JacobianFactor(keyFunc(1), I_3x3, Z_3x1)); + + // keyFunc(1) to keyFunc(n+1) + for (size_t t = 1; t < n; t++) { + hfg.add(GaussianMixtureFactor::FromFactors( + {keyFunc(t), keyFunc(t + 1)}, {{dKeyFunc(t), 2}}, + {boost::make_shared(keyFunc(t), I_3x3, keyFunc(t + 1), + I_3x3, Z_3x1), + boost::make_shared(keyFunc(t), I_3x3, keyFunc(t + 1), + I_3x3, Vector3::Ones())})); + + if (t > 1) { + hfg.add(DecisionTreeFactor({{dKeyFunc(t - 1), 2}, {dKeyFunc(t), 2}}, + "0 1 1 3")); + } + } + + return boost::make_shared(std::move(hfg)); +} + +inline std::pair> makeBinaryOrdering( + std::vector &input) { + KeyVector new_order; + std::vector levels(input.size()); + std::function::iterator, std::vector::iterator, + int)> + bsg = [&bsg, &new_order, &levels, &input]( + std::vector::iterator begin, + std::vector::iterator end, int lvl) { + if (std::distance(begin, end) > 1) { + std::vector::iterator pivot = + begin + std::distance(begin, end) / 2; + + new_order.push_back(*pivot); + levels[std::distance(input.begin(), pivot)] = lvl; + bsg(begin, pivot, lvl + 1); + bsg(pivot + 1, end, lvl + 1); + } else if (std::distance(begin, end) == 1) { + new_order.push_back(*begin); + levels[std::distance(input.begin(), begin)] = lvl; + } + }; + + bsg(input.begin(), input.end(), 0); + std::reverse(new_order.begin(), new_order.end()); + // std::reverse(levels.begin(), levels.end()); + return {new_order, levels}; +} + +} // namespace gtsam diff --git a/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp new file mode 100644 index 0000000000..552bb18f59 --- /dev/null +++ b/gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp @@ -0,0 +1,593 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/* + * @file testHybridGaussianFactorGraph.cpp + * @date Mar 11, 2022 + * @author Fan Jiang + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "Switching.h" + +using namespace boost::assign; + +using namespace std; +using namespace gtsam; + +using gtsam::symbol_shorthand::C; +using gtsam::symbol_shorthand::D; +using gtsam::symbol_shorthand::X; +using gtsam::symbol_shorthand::Y; + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, creation) { + HybridConditional test; + + HybridGaussianFactorGraph hfg; + + hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); + + GaussianMixture clgc( + {X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{C(0), 2}), + GaussianMixture::Conditionals( + C(0), + boost::make_shared(X(0), Z_3x1, I_3x3, X(1), + I_3x3), + boost::make_shared(X(0), Vector3::Ones(), I_3x3, + X(1), I_3x3))); + GTSAM_PRINT(clgc); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminate) { + HybridGaussianFactorGraph hfg; + + hfg.add(HybridGaussianFactor(JacobianFactor(0, I_3x3, Z_3x1))); + + auto result = hfg.eliminatePartialSequential(KeyVector{0}); + + EXPECT_LONGS_EQUAL(result.first->size(), 1); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminateMultifrontal) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); + + Ordering ordering; + ordering.push_back(X(0)); + auto result = hfg.eliminatePartialMultifrontal(ordering); + + EXPECT_LONGS_EQUAL(result.first->size(), 1); + EXPECT_LONGS_EQUAL(result.second->size(), 1); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminateFullSequentialEqualChance) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + + auto result = + hfg.eliminateSequential(Ordering::ColamdConstrainedLast(hfg, {C(1)})); + + auto dc = result->at(2)->asDiscreteConditional(); + DiscreteValues dv; + dv[C(1)] = 0; + EXPECT_DOUBLES_EQUAL(0.6225, dc->operator()(dv), 1e-3); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminateFullSequentialSimple) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt)); + hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c1, {2, 8}))); + hfg.add(HybridDiscreteFactor( + DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); + // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(2), 2}, {C(3), 2}}, "1 + // 2 3 4"))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(3), 2}, + // {C(1), 2}}, "1 2 2 1"))); + + auto result = hfg.eliminateSequential( + Ordering::ColamdConstrainedLast(hfg, {C(1), C(2)})); + + GTSAM_PRINT(*result); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c1(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + // DecisionTree dt( + // C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + // boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + // hfg.add(GaussianMixtureFactor({X(1)}, {c1}, dt)); + hfg.add(GaussianMixtureFactor::FromFactors( + {X(1)}, {{C(1), 2}}, + {boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())})); + + // hfg.add(GaussianMixtureFactor({X(0)}, {c1}, dt)); + hfg.add(DecisionTreeFactor(c1, {2, 8})); + hfg.add(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4")); + // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(2), 2}, {C(3), 2}}, "1 + // 2 3 4"))); hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(3), 2}, + // {C(1), 2}}, "1 2 2 1"))); + + auto result = hfg.eliminateMultifrontal( + Ordering::ColamdConstrainedLast(hfg, {C(1), C(2)})); + + GTSAM_PRINT(*result); + GTSAM_PRINT(*result->marginalFactor(C(2))); +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) { + HybridGaussianFactorGraph hfg; + + DiscreteKey c(C(1), 2); + + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + DecisionTree dt( + C(1), boost::make_shared(X(1), I_3x3, Z_3x1), + boost::make_shared(X(1), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(1)}, {c}, dt)); + hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); + // hfg.add(HybridDiscreteFactor(DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 + // 2 3 4"))); + + auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {C(1)}); + + HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full); + + GTSAM_PRINT(*hbt); + /* + Explanation: the Junction tree will need to reeliminate to get to the marginal + on X(1), which is not possible because it involves eliminating discrete before + continuous. The solution to this, however, is in Murphy02. TLDR is that this + is 1. expensive and 2. inexact. neverless it is doable. And I believe that we + should do this. + */ +} + +/* ************************************************************************* */ +/* + * This test is about how to assemble the Bayes Tree roots after we do partial + * elimination + */ +TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalTwoClique) { + HybridGaussianFactorGraph hfg; + + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(1), I_3x3, X(2), -I_3x3, Z_3x1)); + + { + // DecisionTree dt( + // C(0), boost::make_shared(X(0), I_3x3, Z_3x1), + // boost::make_shared(X(0), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor::FromFactors( + {X(0)}, {{C(0), 2}}, + {boost::make_shared(X(0), I_3x3, Z_3x1), + boost::make_shared(X(0), I_3x3, Vector3::Ones())})); + + DecisionTree dt1( + C(1), boost::make_shared(X(2), I_3x3, Z_3x1), + boost::make_shared(X(2), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(2)}, {{C(1), 2}}, dt1)); + } + + // hfg.add(HybridDiscreteFactor(DecisionTreeFactor(c, {2, 8}))); + hfg.add(HybridDiscreteFactor( + DecisionTreeFactor({{C(1), 2}, {C(2), 2}}, "1 2 3 4"))); + + hfg.add(JacobianFactor(X(3), I_3x3, X(4), -I_3x3, Z_3x1)); + hfg.add(JacobianFactor(X(4), I_3x3, X(5), -I_3x3, Z_3x1)); + + { + DecisionTree dt( + C(3), boost::make_shared(X(3), I_3x3, Z_3x1), + boost::make_shared(X(3), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(3)}, {{C(3), 2}}, dt)); + + DecisionTree dt1( + C(2), boost::make_shared(X(5), I_3x3, Z_3x1), + boost::make_shared(X(5), I_3x3, Vector3::Ones())); + + hfg.add(GaussianMixtureFactor({X(5)}, {{C(2), 2}}, dt1)); + } + + auto ordering_full = + Ordering::ColamdConstrainedLast(hfg, {C(0), C(1), C(2), C(3)}); + + GTSAM_PRINT(ordering_full); + + HybridBayesTree::shared_ptr hbt; + HybridGaussianFactorGraph::shared_ptr remaining; + std::tie(hbt, remaining) = hfg.eliminatePartialMultifrontal(ordering_full); + + GTSAM_PRINT(*hbt); + + GTSAM_PRINT(*remaining); + + hbt->dot(std::cout); + /* + Explanation: the Junction tree will need to reeliminate to get to the marginal + on X(1), which is not possible because it involves eliminating discrete before + continuous. The solution to this, however, is in Murphy02. TLDR is that this + is 1. expensive and 2. inexact. neverless it is doable. And I believe that we + should do this. + */ +} + +/* ************************************************************************* */ +// TODO(fan): make a graph like Varun's paper one +TEST(HybridGaussianFactorGraph, Switching) { + auto N = 12; + auto hfg = makeSwitchingChain(N); + + // X(5) will be the center, X(1-4), X(6-9) + // X(3), X(7) + // X(2), X(8) + // X(1), X(4), X(6), X(9) + // C(5) will be the center, C(1-4), C(6-8) + // C(3), C(7) + // C(1), C(4), C(2), C(6), C(8) + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(4), C(2), C(6), C(8), C(3), C(7), C(5)}); + KeyVector ordering; + + { + std::vector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + std::vector ordX; + std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX), + [](int x) { return X(x); }); + + KeyVector ndX; + std::vector lvls; + std::tie(ndX, lvls) = makeBinaryOrdering(ordX); + std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + for (auto &l : lvls) { + l = -l; + } + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + std::cout << "\n"; + } + { + std::vector naturalC(N - 1); + std::iota(naturalC.begin(), naturalC.end(), 1); + std::vector ordC; + std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC), + [](int x) { return C(x); }); + KeyVector ndC; + std::vector lvls; + + // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); + std::tie(ndC, lvls) = makeBinaryOrdering(ordC); + std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering)); + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + } + auto ordering_full = Ordering(ordering); + + // auto ordering_full = + // Ordering(); + + // for (int i = 1; i <= 9; i++) { + // ordering_full.push_back(X(i)); + // } + + // for (int i = 1; i < 9; i++) { + // ordering_full.push_back(C(i)); + // } + + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(2), C(3), C(4), C(5), C(6), C(7), C(8)}); + + // GTSAM_PRINT(*hfg); + GTSAM_PRINT(ordering_full); + + HybridBayesTree::shared_ptr hbt; + HybridGaussianFactorGraph::shared_ptr remaining; + std::tie(hbt, remaining) = hfg->eliminatePartialMultifrontal(ordering_full); + + // GTSAM_PRINT(*hbt); + + // GTSAM_PRINT(*remaining); + + { + DotWriter dw; + dw.positionHints['c'] = 2; + dw.positionHints['x'] = 1; + std::cout << hfg->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + hbt->dot(std::cout); + } + + { + DotWriter dw; + // dw.positionHints['c'] = 2; + // dw.positionHints['x'] = 1; + std::cout << "\n"; + std::cout << hfg->eliminateSequential(ordering_full) + ->dot(DefaultKeyFormatter, dw); + } + /* + Explanation: the Junction tree will need to reeliminate to get to the marginal + on X(1), which is not possible because it involves eliminating discrete before + continuous. The solution to this, however, is in Murphy02. TLDR is that this + is 1. expensive and 2. inexact. neverless it is doable. And I believe that we + should do this. + */ + hbt->marginalFactor(C(11))->print("HBT: "); +} + +/* ************************************************************************* */ +// TODO(fan): make a graph like Varun's paper one +TEST(HybridGaussianFactorGraph, SwitchingISAM) { + auto N = 11; + auto hfg = makeSwitchingChain(N); + + // X(5) will be the center, X(1-4), X(6-9) + // X(3), X(7) + // X(2), X(8) + // X(1), X(4), X(6), X(9) + // C(5) will be the center, C(1-4), C(6-8) + // C(3), C(7) + // C(1), C(4), C(2), C(6), C(8) + // auto ordering_full = + // Ordering(KeyVector{X(1), X(4), X(2), X(6), X(9), X(8), X(3), X(7), + // X(5), + // C(1), C(4), C(2), C(6), C(8), C(3), C(7), C(5)}); + KeyVector ordering; + + { + std::vector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + std::vector ordX; + std::transform(naturalX.begin(), naturalX.end(), std::back_inserter(ordX), + [](int x) { return X(x); }); + + KeyVector ndX; + std::vector lvls; + std::tie(ndX, lvls) = makeBinaryOrdering(ordX); + std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + for (auto &l : lvls) { + l = -l; + } + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + std::cout << "\n"; + } + { + std::vector naturalC(N - 1); + std::iota(naturalC.begin(), naturalC.end(), 1); + std::vector ordC; + std::transform(naturalC.begin(), naturalC.end(), std::back_inserter(ordC), + [](int x) { return C(x); }); + KeyVector ndC; + std::vector lvls; + + // std::copy(ordC.begin(), ordC.end(), std::back_inserter(ordering)); + std::tie(ndC, lvls) = makeBinaryOrdering(ordC); + std::copy(ndC.begin(), ndC.end(), std::back_inserter(ordering)); + std::copy(lvls.begin(), lvls.end(), + std::ostream_iterator(std::cout, ",")); + } + auto ordering_full = Ordering(ordering); + + // GTSAM_PRINT(*hfg); + GTSAM_PRINT(ordering_full); + + HybridBayesTree::shared_ptr hbt; + HybridGaussianFactorGraph::shared_ptr remaining; + std::tie(hbt, remaining) = hfg->eliminatePartialMultifrontal(ordering_full); + + // GTSAM_PRINT(*hbt); + + // GTSAM_PRINT(*remaining); + + { + DotWriter dw; + dw.positionHints['c'] = 2; + dw.positionHints['x'] = 1; + std::cout << hfg->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + hbt->dot(std::cout); + } + + { + DotWriter dw; + // dw.positionHints['c'] = 2; + // dw.positionHints['x'] = 1; + std::cout << "\n"; + std::cout << hfg->eliminateSequential(ordering_full) + ->dot(DefaultKeyFormatter, dw); + } + + auto new_fg = makeSwitchingChain(12); + auto isam = HybridGaussianISAM(*hbt); + + { + HybridGaussianFactorGraph factorGraph; + factorGraph.push_back(new_fg->at(new_fg->size() - 2)); + factorGraph.push_back(new_fg->at(new_fg->size() - 1)); + isam.update(factorGraph); + std::cout << isam.dot(); + isam.marginalFactor(C(11))->print(); + } +} + +/* ************************************************************************* */ +TEST(HybridGaussianFactorGraph, SwitchingTwoVar) { + const int N = 7; + auto hfg = makeSwitchingChain(N, X); + hfg->push_back(*makeSwitchingChain(N, Y, D)); + + for (int t = 1; t <= N; t++) { + hfg->add(JacobianFactor(X(t), I_3x3, Y(t), -I_3x3, Vector3(1.0, 0.0, 0.0))); + } + + KeyVector ordering; + + KeyVector naturalX(N); + std::iota(naturalX.begin(), naturalX.end(), 1); + KeyVector ordX; + for (size_t i = 1; i <= N; i++) { + ordX.emplace_back(X(i)); + ordX.emplace_back(Y(i)); + } + + // { + // KeyVector ndX; + // std::vector lvls; + // std::tie(ndX, lvls) = makeBinaryOrdering(naturalX); + // std::copy(ndX.begin(), ndX.end(), std::back_inserter(ordering)); + // std::copy(lvls.begin(), lvls.end(), + // std::ostream_iterator(std::cout, ",")); + // std::cout << "\n"; + + // for (size_t i = 0; i < N; i++) { + // ordX.emplace_back(X(ndX[i])); + // ordX.emplace_back(Y(ndX[i])); + // } + // } + + for (size_t i = 1; i <= N - 1; i++) { + ordX.emplace_back(C(i)); + } + for (size_t i = 1; i <= N - 1; i++) { + ordX.emplace_back(D(i)); + } + + { + DotWriter dw; + dw.positionHints['x'] = 1; + dw.positionHints['c'] = 0; + dw.positionHints['d'] = 3; + dw.positionHints['y'] = 2; + std::cout << hfg->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + } + + { + DotWriter dw; + dw.positionHints['y'] = 9; + // dw.positionHints['c'] = 0; + // dw.positionHints['d'] = 3; + dw.positionHints['x'] = 1; + std::cout << "\n"; + // std::cout << hfg->eliminateSequential(Ordering(ordX)) + // ->dot(DefaultKeyFormatter, dw); + hfg->eliminateMultifrontal(Ordering(ordX))->dot(std::cout); + } + + Ordering ordering_partial; + for (size_t i = 1; i <= N; i++) { + ordering_partial.emplace_back(X(i)); + ordering_partial.emplace_back(Y(i)); + } + { + HybridBayesNet::shared_ptr hbn; + HybridGaussianFactorGraph::shared_ptr remaining; + std::tie(hbn, remaining) = + hfg->eliminatePartialSequential(ordering_partial); + + // remaining->print(); + { + DotWriter dw; + dw.positionHints['x'] = 1; + dw.positionHints['c'] = 0; + dw.positionHints['d'] = 3; + dw.positionHints['y'] = 2; + std::cout << remaining->dot(DefaultKeyFormatter, dw); + std::cout << "\n"; + } + } +} + +/* ************************************************************************* */ +int main() { + TestResult tr; + return TestRegistry::runAllTests(tr); +} +/* ************************************************************************* */ diff --git a/gtsam/inference/BayesTree.h b/gtsam/inference/BayesTree.h index 5b053ebee5..924a505a2d 100644 --- a/gtsam/inference/BayesTree.h +++ b/gtsam/inference/BayesTree.h @@ -33,6 +33,7 @@ namespace gtsam { // Forward declarations template class FactorGraph; template class EliminatableClusterTree; + class HybridBayesTreeClique; /* ************************************************************************* */ /** clique statistics */ @@ -272,24 +273,33 @@ namespace gtsam { }; // BayesTree /* ************************************************************************* */ - template - class BayesTreeOrphanWrapper : public CLIQUE::ConditionalType - { - public: + template + class BayesTreeOrphanWrapper : public CLIQUE::ConditionalType { + public: typedef CLIQUE CliqueType; typedef typename CLIQUE::ConditionalType Base; boost::shared_ptr clique; - BayesTreeOrphanWrapper(const boost::shared_ptr& clique) : - clique(clique) - { - // Store parent keys in our base type factor so that eliminating those parent keys will pull - // this subtree into the elimination. - this->keys_.assign(clique->conditional()->beginParents(), clique->conditional()->endParents()); + /** + * @brief Construct a new Bayes Tree Orphan Wrapper object + * + * This object stores parent keys in our base type factor so that + * eliminating those parent keys will pull this subtree into the + * elimination. + * + * @param clique Orphan clique to add for further consideration in + * elimination. + */ + BayesTreeOrphanWrapper(const boost::shared_ptr& clique) + : clique(clique) { + this->keys_.assign(clique->conditional()->beginParents(), + clique->conditional()->endParents()); } - void print(const std::string& s="", const KeyFormatter& formatter = DefaultKeyFormatter) const override { + void print( + const std::string& s = "", + const KeyFormatter& formatter = DefaultKeyFormatter) const override { clique->print(s + "stored clique", formatter); } }; diff --git a/gtsam/inference/EliminateableFactorGraph.h b/gtsam/inference/EliminateableFactorGraph.h index c904d2f7ff..900346f7fb 100644 --- a/gtsam/inference/EliminateableFactorGraph.h +++ b/gtsam/inference/EliminateableFactorGraph.h @@ -204,7 +204,7 @@ namespace gtsam { OptionalVariableIndex variableIndex = boost::none) const; /** Do multifrontal elimination of the given \c variables in an ordering computed by COLAMD to - * produce a Bayes net and a remaining factor graph. This computes the factorization \f$ p(X) + * produce a Bayes tree and a remaining factor graph. This computes the factorization \f$ p(X) * = p(A|B) p(B) \f$, where \f$ A = \f$ \c variables, \f$ X \f$ is all the variables in the * factor graph, and \f$ B = X\backslash A \f$. */ std::pair, boost::shared_ptr > diff --git a/gtsam/inference/JunctionTree.h b/gtsam/inference/JunctionTree.h index e01f3721a4..e914c325ef 100644 --- a/gtsam/inference/JunctionTree.h +++ b/gtsam/inference/JunctionTree.h @@ -70,7 +70,7 @@ namespace gtsam { /// @} - private: + protected: // Private default constructor (used in static construction methods) JunctionTree() {} diff --git a/gtsam/inference/inference.i b/gtsam/inference/inference.i index e7b074ec49..fbdd70fdfb 100644 --- a/gtsam/inference/inference.i +++ b/gtsam/inference/inference.i @@ -9,6 +9,7 @@ namespace gtsam { #include #include #include +#include #include @@ -106,36 +107,36 @@ class Ordering { template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering Colamd(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering ColamdConstrainedLast( const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainLast, bool forceOrder = false); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering ColamdConstrainedFirst( const FACTOR_GRAPH& graph, const gtsam::KeyVector& constrainFirst, bool forceOrder = false); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering Natural(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering Metis(const FACTOR_GRAPH& graph); template < FACTOR_GRAPH = {gtsam::NonlinearFactorGraph, gtsam::DiscreteFactorGraph, - gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph}> + gtsam::SymbolicFactorGraph, gtsam::GaussianFactorGraph, gtsam::HybridGaussianFactorGraph}> static gtsam::Ordering Create(gtsam::Ordering::OrderingType orderingType, const FACTOR_GRAPH& graph); diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index c14f02ddab..cba206d111 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -67,6 +67,7 @@ set(interface_headers ${PROJECT_SOURCE_DIR}/gtsam/sfm/sfm.i ${PROJECT_SOURCE_DIR}/gtsam/navigation/navigation.i ${PROJECT_SOURCE_DIR}/gtsam/basis/basis.i + ${PROJECT_SOURCE_DIR}/gtsam/hybrid/hybrid.i ) set(GTSAM_PYTHON_TARGET gtsam_py) diff --git a/python/gtsam/preamble/base.h b/python/gtsam/preamble/base.h index 626b47ae4a..5cf633e653 100644 --- a/python/gtsam/preamble/base.h +++ b/python/gtsam/preamble/base.h @@ -11,6 +11,8 @@ * mutations on Python side will not be reflected on C++. */ -PYBIND11_MAKE_OPAQUE(std::vector); +PYBIND11_MAKE_OPAQUE(gtsam::IndexPairVector); + +PYBIND11_MAKE_OPAQUE(gtsam::IndexPairSet); PYBIND11_MAKE_OPAQUE(std::vector); // JacobianVector diff --git a/python/gtsam/preamble/discrete.h b/python/gtsam/preamble/discrete.h index 608508c32f..320e0ac718 100644 --- a/python/gtsam/preamble/discrete.h +++ b/python/gtsam/preamble/discrete.h @@ -13,4 +13,3 @@ #include -PYBIND11_MAKE_OPAQUE(gtsam::DiscreteKeys); diff --git a/python/gtsam/preamble/hybrid.h b/python/gtsam/preamble/hybrid.h new file mode 100644 index 0000000000..5e5a71e48d --- /dev/null +++ b/python/gtsam/preamble/hybrid.h @@ -0,0 +1,14 @@ +/* Please refer to: + * https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html + * These are required to save one copy operation on Python calls. + * + * NOTES + * ================= + * + * `PYBIND11_MAKE_OPAQUE` will mark the type as "opaque" for the pybind11 + * automatic STL binding, such that the raw objects can be accessed in Python. + * Without this they will be automatically converted to a Python object, and all + * mutations on Python side will not be reflected on C++. + */ + +PYBIND11_MAKE_OPAQUE(std::vector); diff --git a/python/gtsam/preamble/inference.h b/python/gtsam/preamble/inference.h index 4106c794ac..d07a75f6fb 100644 --- a/python/gtsam/preamble/inference.h +++ b/python/gtsam/preamble/inference.h @@ -10,5 +10,3 @@ * Without this they will be automatically converted to a Python object, and all * mutations on Python side will not be reflected on C++. */ - -#include \ No newline at end of file diff --git a/python/gtsam/specializations/hybrid.h b/python/gtsam/specializations/hybrid.h new file mode 100644 index 0000000000..bede6d86c4 --- /dev/null +++ b/python/gtsam/specializations/hybrid.h @@ -0,0 +1,4 @@ + +py::bind_vector >(m_, "GaussianFactorVector"); + +py::implicitly_convertible >(); diff --git a/python/gtsam/specializations/inference.h b/python/gtsam/specializations/inference.h index 22fe3beff6..9e23444ea2 100644 --- a/python/gtsam/specializations/inference.h +++ b/python/gtsam/specializations/inference.h @@ -11,3 +11,4 @@ * and saves one copy operation. */ +py::bind_map>(m_, "__MapCharDouble"); diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index 10c5db612a..ff2ba99d15 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -139,7 +139,7 @@ def test_dot(self): # Make sure we can *update* position hints writer = gtsam.DotWriter() ph: dict = writer.positionHints - ph.update({'a': 2}) # hint at symbol position + ph['a'] = 2 # hint at symbol position writer.positionHints = ph # Check the output of dot diff --git a/python/gtsam/tests/test_HybridFactorGraph.py b/python/gtsam/tests/test_HybridFactorGraph.py new file mode 100644 index 0000000000..781cfd9240 --- /dev/null +++ b/python/gtsam/tests/test_HybridFactorGraph.py @@ -0,0 +1,60 @@ +""" +GTSAM Copyright 2010-2019, Georgia Tech Research Corporation, +Atlanta, Georgia 30332-0415 +All Rights Reserved + +See LICENSE for the license information + +Unit tests for Hybrid Factor Graphs. +Author: Fan Jiang +""" +# pylint: disable=invalid-name, no-name-in-module, no-member + +from __future__ import print_function + +import unittest + +import gtsam +import numpy as np +from gtsam.symbol_shorthand import C, X +from gtsam.utils.test_case import GtsamTestCase + + +class TestHybridGaussianFactorGraph(GtsamTestCase): + """Unit tests for HybridGaussianFactorGraph.""" + + def test_create(self): + """Test contruction of hybrid factor graph.""" + noiseModel = gtsam.noiseModel.Unit.Create(3) + dk = gtsam.DiscreteKeys() + dk.push_back((C(0), 2)) + + jf1 = gtsam.JacobianFactor(X(0), np.eye(3), np.zeros((3, 1)), + noiseModel) + jf2 = gtsam.JacobianFactor(X(0), np.eye(3), np.ones((3, 1)), + noiseModel) + + gmf = gtsam.GaussianMixtureFactor.FromFactors([X(0)], dk, [jf1, jf2]) + + hfg = gtsam.HybridGaussianFactorGraph() + hfg.add(jf1) + hfg.add(jf2) + hfg.push_back(gmf) + + hbn = hfg.eliminateSequential( + gtsam.Ordering.ColamdConstrainedLastHybridGaussianFactorGraph( + hfg, [C(0)])) + + # print("hbn = ", hbn) + self.assertEqual(hbn.size(), 2) + + mixture = hbn.at(0).inner() + self.assertIsInstance(mixture, gtsam.GaussianMixture) + self.assertEqual(len(mixture.keys()), 2) + + discrete_conditional = hbn.at(hbn.size() - 1).inner() + self.assertIsInstance(discrete_conditional, gtsam.DiscreteConditional) + + +if __name__ == "__main__": + unittest.main()