Skip to content

Commit

Permalink
Merge pull request #1282 from borglab/hybrid/optimize-2
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Aug 31, 2022
2 parents a6b9554 + ca4293b commit f7e1d2a
Show file tree
Hide file tree
Showing 12 changed files with 225 additions and 529 deletions.
39 changes: 31 additions & 8 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
* @date January 2022
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>

namespace gtsam {
Expand Down Expand Up @@ -111,10 +112,15 @@ HybridBayesNet HybridBayesNet::prune(
}

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
GaussianMixture::shared_ptr HybridBayesNet::atMixture(size_t i) const {
return factors_.at(i)->asMixture();
}

/* ************************************************************************* */
GaussianConditional::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return factors_.at(i)->asGaussian();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return factors_.at(i)->asDiscreteConditional();
Expand All @@ -125,22 +131,39 @@ GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
try {
GaussianMixture gm = *this->atGaussian(idx);
if (factors_.at(idx)->isHybrid()) {
// If factor is hybrid, select based on assignment.
GaussianMixture gm = *this->atMixture(idx);
gbn.push_back(gm(assignment));

} catch (std::exception &exc) {
// if factor at `idx` is discrete-only, just continue.
} else if (factors_.at(idx)->isContinuous()) {
// If continuous only, add gaussian conditional.
gbn.push_back((this->atGaussian(idx)));

} else if (factors_.at(idx)->isDiscrete()) {
// If factor at `idx` is discrete-only, we simply continue.
continue;
}
}

return gbn;
}

/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
// Solve for the MPE
DiscreteBayesNet discrete_bn;
for (auto &conditional : factors_) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscreteConditional());
}
}

DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = this->choose(mpe);
return HybridValues(mpe, gbn.optimize());
}

/* *******************************************************************************/
Expand Down
5 changes: 4 additions & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
}

/// Get a specific Gaussian mixture by index `i`.
GaussianMixture::shared_ptr atGaussian(size_t i) const;
GaussianMixture::shared_ptr atMixture(size_t i) const;

/// Get a specific Gaussian conditional by index `i`.
GaussianConditional::shared_ptr atGaussian(size_t i) const;

/// Get a specific discrete conditional by index `i`.
DiscreteConditional::shared_ptr atDiscrete(size_t i) const;
Expand Down
59 changes: 55 additions & 4 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/

#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h>
#include <gtsam/inference/BayesTree-inst.h>
Expand All @@ -35,6 +37,52 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}

/* ************************************************************************* */
HybridValues HybridBayesTree::optimize() const {
HybridBayesNet hbn;
DiscreteBayesNet dbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

// Record the key being added
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

if (conditional->isDiscrete()) {
// If discrete, we use it to compute the MPE
dbn.push_back(conditional->asDiscreteConditional());

} else {
// Else conditional is hybrid or continuous-only,
// so we directly add it to the Hybrid Bayes net.
hbn.push_back(conditional);
}
}
}
// Get the MPE
DiscreteValues mpe = DiscreteFactorGraph(dbn).optimize();
// Given the MPE, compute the optimal continuous values.
GaussianBayesNet gbn = hbn.choose(mpe);

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif

return HybridValues(mpe, gbn.optimize());
}

/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;
Expand All @@ -50,11 +98,9 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());

// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());
added_keys.insert(added_keys.end(), conditional->frontals().begin(),
conditional->frontals().end());

// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
Expand All @@ -65,9 +111,14 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
(*gm)(assignment);

gbn.push_back(gaussian_conditional);

} else if (conditional->isContinuous()) {
// If conditional is Gaussian, we simply add it to the Bayes net.
gbn.push_back(conditional->asGaussian());
}
}
}

// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
Expand Down
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;

/**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous
* update delta.
*
* @return HybridValues
*/
HybridValues optimize() const;

/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
Expand Down
11 changes: 11 additions & 0 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ class GTSAM_EXPORT HybridConditional
return boost::static_pointer_cast<GaussianMixture>(inner_);
}

/**
* @brief Return HybridConditional as a GaussianConditional
*
* @return GaussianConditional::shared_ptr
*/
GaussianConditional::shared_ptr asGaussian() {
if (!isContinuous())
throw std::invalid_argument("Not a continuous conditional");
return boost::static_pointer_cast<GaussianConditional>(inner_);
}

/**
* @brief Return conditional as a DiscreteConditional
*
Expand Down
76 changes: 0 additions & 76 deletions gtsam/hybrid/HybridLookupDAG.cpp

This file was deleted.

Loading

0 comments on commit f7e1d2a

Please sign in to comment.