Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Improvements #1343

Merged
merged 13 commits into from
Dec 30, 2022
55 changes: 42 additions & 13 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,66 @@ DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
/**
* @brief Helper function to get the pruner functional.
*
* @param decisionTree The probability decision tree of only discrete keys.
* @return std::function<GaussianConditional::shared_ptr(
* const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
* @param prunedDecisionTree The prob. decision tree of only discrete keys.
* @param conditional Conditional to prune. Used to get full assignment.
* @return std::function<double(const Assignment<Key> &, double)>
*/
std::function<double(const Assignment<Key> &, double)> prunerFunc(
const DecisionTreeFactor &decisionTree,
const DecisionTreeFactor &prunedDecisionTree,
const HybridConditional &conditional) {
// Get the discrete keys as sets for the decision tree
// and the Gaussian mixture.
auto decisionTreeKeySet = DiscreteKeysAsSet(decisionTree.discreteKeys());
auto conditionalKeySet = DiscreteKeysAsSet(conditional.discreteKeys());
std::set<DiscreteKey> decisionTreeKeySet =
DiscreteKeysAsSet(prunedDecisionTree.discreteKeys());
std::set<DiscreteKey> conditionalKeySet =
DiscreteKeysAsSet(conditional.discreteKeys());

auto pruner = [decisionTree, decisionTreeKeySet, conditionalKeySet](
auto pruner = [prunedDecisionTree, decisionTreeKeySet, conditionalKeySet](
const Assignment<Key> &choices,
double probability) -> double {
// Typecast so we can use this to get the probability value
DiscreteValues values(choices);
// Case where the Gaussian mixture has the same
// discrete keys as the decision tree.
if (conditionalKeySet == decisionTreeKeySet) {
if (decisionTree(values) == 0) {
if (prunedDecisionTree(values) == 0) {
return 0.0;
} else {
return probability;
}
} else {
// Due to branch merging (aka pruning) in DecisionTree, it is possible we
// get a `values` which doesn't have the full set of keys.
std::set<Key> valuesKeys;
for (auto kvp : values) {
valuesKeys.insert(kvp.first);
}
std::set<Key> conditionalKeys;
for (auto kvp : conditionalKeySet) {
conditionalKeys.insert(kvp.first);
}
// If the key sets differ, then values is missing some keys
if (conditionalKeys != valuesKeys) {
// Get the keys present in conditionalKeys but not in valuesKeys
std::vector<Key> missing_keys;
std::set_difference(conditionalKeys.begin(), conditionalKeys.end(),
valuesKeys.begin(), valuesKeys.end(),
std::back_inserter(missing_keys));
// Insert missing keys with a default assignment.
for (auto missing_key : missing_keys) {
values[missing_key] = 0;
}
}

// Now we generate the full assignment by enumerating
// over all keys in the prunedDecisionTree.
// First we find the differing keys
std::vector<DiscreteKey> set_diff;
std::set_difference(decisionTreeKeySet.begin(), decisionTreeKeySet.end(),
conditionalKeySet.begin(), conditionalKeySet.end(),
std::back_inserter(set_diff));

// Now enumerate over all assignments of the differing keys
const std::vector<DiscreteValues> assignments =
DiscreteValues::CartesianProduct(set_diff);
for (const DiscreteValues &assignment : assignments) {
Expand All @@ -86,7 +115,7 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(

// If any one of the sub-branches are non-zero,
// we need this probability.
if (decisionTree(augmented_values) > 0.0) {
if (prunedDecisionTree(augmented_values) > 0.0) {
return probability;
}
}
Expand All @@ -99,7 +128,6 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}

/* ************************************************************************* */
// TODO(dellaert): what is this non-const method used for? Abolish it?
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor::shared_ptr &prunedDecisionTree) {
KeyVector prunedTreeKeys = prunedDecisionTree->keys();
Expand All @@ -109,8 +137,6 @@ void HybridBayesNet::updateDiscreteConditionals(
HybridConditional::shared_ptr conditional = this->at(i);
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());

// Apply prunerFunc to the underlying AlgebraicDecisionTree
auto discreteTree =
Expand All @@ -119,6 +145,8 @@ void HybridBayesNet::updateDiscreteConditionals(
discreteTree->apply(prunerFunc(*prunedDecisionTree, *conditional));

// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = boost::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = boost::make_shared<HybridConditional>(prunedDiscrete);
Expand Down Expand Up @@ -206,14 +234,15 @@ GaussianBayesNet HybridBayesNet::choose(

/* ************************************************************************* */
HybridValues HybridBayesNet::optimize() const {
// Solve for the MPE
// Collect all the discrete factors to compute MPE
DiscreteBayesNet discrete_bn;
for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
discrete_bn.push_back(conditional->asDiscrete());
}
}

// Solve for the MPE
DiscreteValues mpe = DiscreteFactorGraph(discrete_bn).optimize();

// Given the MPE, compute the optimal continuous values.
Expand Down
14 changes: 13 additions & 1 deletion gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ struct HybridAssignmentData {

/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree HybridBayesTree::choose(
const DiscreteValues& assignment) const {
GaussianBayesTree gbt;
HybridAssignmentData rootData(assignment, 0, &gbt);
{
Expand All @@ -151,6 +152,17 @@ VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
}

if (!rootData.isValid()) {
return GaussianBayesTree();
}
return gbt;
}

/* *************************************************************************
*/
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesTree gbt = this->choose(assignment);
// If empty GaussianBayesTree, means a clique is pruned hence invalid
if (gbt.size() == 0) {
return VectorValues();
}
VectorValues result = gbt.optimize();
Expand Down
10 changes: 10 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <gtsam/inference/BayesTree.h>
#include <gtsam/inference/BayesTreeCliqueBase.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianBayesTree.h>

#include <string>

Expand Down Expand Up @@ -76,6 +77,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;

/**
* @brief Get the Gaussian Bayes Tree which corresponds to a specific discrete
* value assignment.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesTree
*/
GaussianBayesTree choose(const DiscreteValues& assignment) const;

/**
* @brief Optimize the hybrid Bayes tree by computing the MPE for the current
* set of discrete variables and using it to compute the best continuous
Expand Down
17 changes: 8 additions & 9 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
if (!factor) {
return 0.0; // If nullptr, return 0.0 probability
} else {
// This is the probability q(μ) at the MLE point.
double error =
0.5 * std::abs(factor->augmentedInformation().determinant());
return std::exp(-error);
Expand Down Expand Up @@ -396,18 +397,16 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
if (discrete_only) {
// Case 1: we are only dealing with discrete
return discreteElimination(factors, frontalKeys);
} else {
} else if (mapFromKeyToDiscreteKey.empty()) {
// Case 2: we are only dealing with continuous
if (mapFromKeyToDiscreteKey.empty()) {
return continuousElimination(factors, frontalKeys);
} else {
// Case 3: We are now in the hybrid land!
return continuousElimination(factors, frontalKeys);
} else {
// Case 3: We are now in the hybrid land!
#ifdef HYBRID_TIMING
tictoc_reset_();
tictoc_reset_();
#endif
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
}
return hybridElimination(factors, frontalKeys, continuousSeparator,
discreteSeparatorSet);
}
}

Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* @file HybridGaussianFactorGraph.h
* @brief Linearized Hybrid factor graph that uses type erasure
* @author Fan Jiang
* @author Fan Jiang, Varun Agrawal
* @date Mar 11, 2022
*/

Expand Down
3 changes: 1 addition & 2 deletions gtsam/hybrid/HybridSmoother.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
/* ************************************************************************* */
GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
size_t index) const {
return boost::dynamic_pointer_cast<GaussianMixture>(
hybridBayesNet_.at(index));
return hybridBayesNet_.atMixture(index);
}

/* ************************************************************************* */
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class GTSAM_EXPORT HybridValues {
* @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous_.insert(j, value); }

// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h
// TODO(Shangjie)- insert_or_assign() , similar to Values.h

/**
* Read/write access to the discrete value with key \c j, throws
Expand Down
2 changes: 2 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,14 @@ TEST(HybridBayesNet, Optimize) {

HybridValues delta = hybridBayesNet->optimize();

// TODO(Varun): the expectedAssignment should be 111, not 101
DiscreteValues expectedAssignment;
expectedAssignment[M(0)] = 1;
expectedAssignment[M(1)] = 0;
expectedAssignment[M(2)] = 1;
EXPECT(assert_equal(expectedAssignment, delta.discrete()));

// TODO(Varun): these should all be -Vector1::Ones()
VectorValues expectedValues;
expectedValues.insert(X(0), -0.999904 * Vector1::Ones());
expectedValues.insert(X(1), -0.99029 * Vector1::Ones());
Expand Down
51 changes: 51 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,57 @@ TEST(HybridBayesTree, Optimize) {
EXPECT(assert_equal(expectedValues, delta.continuous()));
}

/* ****************************************************************************/
// Test for choosing a GaussianBayesTree from a HybridBayesTree.
TEST(HybridBayesTree, Choose) {
Switching s(4);

HybridGaussianISAM isam;
HybridGaussianFactorGraph graph1;

// Add the 3 hybrid factors, x1-x2, x2-x3, x3-x4
for (size_t i = 1; i < 4; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}

// Add the Gaussian factors, 1 prior on X(0),
// 3 measurements on X(2), X(3), X(4)
graph1.push_back(s.linearizedFactorGraph.at(0));
for (size_t i = 4; i <= 6; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}

// Add the discrete factors
for (size_t i = 7; i <= 9; i++) {
graph1.push_back(s.linearizedFactorGraph.at(i));
}

isam.update(graph1);

DiscreteValues assignment;
assignment[M(0)] = 1;
assignment[M(1)] = 1;
assignment[M(2)] = 1;

GaussianBayesTree gbt = isam.choose(assignment);

Ordering ordering;
ordering += X(0);
ordering += X(1);
ordering += X(2);
ordering += X(3);
ordering += M(0);
ordering += M(1);
ordering += M(2);

// TODO(Varun): we get a segfault if the ordering is not provided
auto bayesTree = s.linearizedFactorGraph.eliminateMultifrontal(ordering);

auto expected_gbt = bayesTree->choose(assignment);

EXPECT(assert_equal(expected_gbt, gbt));
}

/* ****************************************************************************/
// Test HybridBayesTree serialization.
TEST(HybridBayesTree, Serialization) {
Expand Down
Loading