Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various updates for Hybrid #1284

Merged
merged 4 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ namespace gtsam {
push_back(key);
return *this;
}

/// Print the keys and cardinalities.
void print(const std::string& s = "",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
for (auto&& dkey : *this) {
std::cout << DefaultKeyFormatter(dkey.first) << " " << dkey.second
<< std::endl;
}
}

}; // DiscreteKeys

/// Create a list from two keys
Expand Down
25 changes: 23 additions & 2 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,35 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
OptionalOrderingType orderingType) const {
const KeySet HybridGaussianFactorGraph::getDiscreteKeys() const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}
return discrete_keys;
}

/* ************************************************************************ */
const KeySet HybridGaussianFactorGraph::getContinuousKeys() const {
KeySet keys;
for (auto &factor : factors_) {
for (const Key &key : factor->continuousKeys()) {
keys.insert(key);
}
}
return keys;
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering() const {
KeySet discrete_keys = getDiscreteKeys();
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
Expand Down
17 changes: 11 additions & 6 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,19 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
}
}

/// Get all the discrete keys in the factor graph.
const KeySet getDiscreteKeys() const;

/// Get all the continuous keys in the factor graph.
const KeySet getContinuousKeys() const;

/**
* @brief
*
* @param orderingType
* @return const Ordering
* @brief Return a Colamd constrained ordering where the discrete keys are
* eliminated after the continuous keys.
*
* @return const Ordering
*/
const Ordering getHybridOrdering(
OptionalOrderingType orderingType = boost::none) const;
const Ordering getHybridOrdering() const;
};

} // namespace gtsam
18 changes: 9 additions & 9 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
namespace gtsam {

/**
* HybridValues represents a collection of DiscreteValues and VectorValues. It
* is typically used to store the variables of a HybridGaussianFactorGraph.
* HybridValues represents a collection of DiscreteValues and VectorValues.
* It is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
Expand All @@ -47,18 +47,18 @@ class GTSAM_EXPORT HybridValues {
/// @name Standard Constructors
/// @{

// Default constructor creates an empty HybridValues.
/// Default constructor creates an empty HybridValues.
HybridValues() = default;

// Construct from DiscreteValues and VectorValues.
/// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete_(dv), continuous_(cv){};

/// @}
/// @name Testable
/// @{

// print required by Testable for unit testing
/// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
Expand All @@ -67,7 +67,7 @@ class GTSAM_EXPORT HybridValues {
keyFormatter); // print continuous components
};

// equals required by Testable for unit testing
/// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete_.equals(other.discrete_, tol) &&
continuous_.equals(other.continuous_, tol);
Expand All @@ -83,13 +83,13 @@ class GTSAM_EXPORT HybridValues {
/// Return the delta update for the continuous vectors
VectorValues continuous() const { return continuous_; }

// Check whether a variable with key \c j exists in DiscreteValue.
/// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete_.find(j) != discrete_.end()); };

// Check whether a variable with key \c j exists in VectorValue.
/// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous_.exists(j); };

// Check whether a variable with key \c j exists.
/// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };

/** Insert a discrete \c value with key \c j. Replaces the existing value if
Expand Down
2 changes: 2 additions & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class HybridBayesTree {
bool empty() const;
const HybridBayesTreeClique* operator[](size_t j) const;

gtsam::HybridValues optimize() const;

string dot(const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter) const;
};
Expand Down
9 changes: 4 additions & 5 deletions gtsam/hybrid/tests/testGaussianHybridFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalSimple) {
hfg.add(DecisionTreeFactor(m1, {2, 8}));
hfg.add(DecisionTreeFactor({{M(1), 2}, {M(2), 2}}, "1 2 3 4"));

HybridBayesTree::shared_ptr result = hfg.eliminateMultifrontal(
Ordering::ColamdConstrainedLast(hfg, {M(1), M(2)}));
HybridBayesTree::shared_ptr result =
hfg.eliminateMultifrontal(hfg.getHybridOrdering());

// The bayes tree should have 3 cliques
EXPECT_LONGS_EQUAL(3, result->size());
Expand Down Expand Up @@ -215,7 +215,7 @@ TEST(HybridGaussianFactorGraph, eliminateFullMultifrontalCLG) {
hfg.add(HybridDiscreteFactor(DecisionTreeFactor(m, {2, 8})));

// Get a constrained ordering keeping c1 last
auto ordering_full = Ordering::ColamdConstrainedLast(hfg, {M(1)});
auto ordering_full = hfg.getHybridOrdering();

// Returns a Hybrid Bayes Tree with distribution P(x0|x1)P(x1|c1)P(c1)
HybridBayesTree::shared_ptr hbt = hfg.eliminateMultifrontal(ordering_full);
Expand Down Expand Up @@ -484,8 +484,7 @@ TEST(HybridGaussianFactorGraph, SwitchingTwoVar) {
}
HybridBayesNet::shared_ptr hbn;
HybridGaussianFactorGraph::shared_ptr remaining;
std::tie(hbn, remaining) =
hfg->eliminatePartialSequential(ordering_partial);
std::tie(hbn, remaining) = hfg->eliminatePartialSequential(ordering_partial);

EXPECT_LONGS_EQUAL(14, hbn->size());
EXPECT_LONGS_EQUAL(11, remaining->size());
Expand Down