Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discrete Elimination Refactor #1919

Draft
wants to merge 48 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
d1d440a
add nrValues method
varunagrawal Dec 7, 2024
a68da21
operator* version which accepts DiscreteFactor
varunagrawal Dec 7, 2024
a09b77e
return DiscreteFactor shared_ptr as leftover from elimination
varunagrawal Dec 7, 2024
27bbce1
generalize DiscreteFactorGraph::product to DiscreteFactor
varunagrawal Dec 7, 2024
84e4194
make normalization code common
varunagrawal Dec 7, 2024
4dac37c
make sum and max DiscreteFactor methods
varunagrawal Dec 7, 2024
6c45467
add timing info
varunagrawal Dec 7, 2024
b0ad350
add note about toDecisionTreeFactor
varunagrawal Dec 7, 2024
306a3ba
kill toDecisionTreeFactor to force rethink
varunagrawal Dec 7, 2024
2cd2ab0
DiscreteDistribution from TableFactor
varunagrawal Dec 7, 2024
9f88a36
make evaluate use the Assignment<Key> base class
varunagrawal Dec 7, 2024
2a3b5e6
use Assignment<Key> for evaluate since it is the base class
varunagrawal Dec 7, 2024
fff8458
remove TableFactor constructor in DiscreteDistribution
varunagrawal Dec 8, 2024
295b965
use Assignment<Key> since it is a base class
varunagrawal Dec 8, 2024
261038f
fix DiscreteConditional constructor
varunagrawal Dec 8, 2024
20d6d09
use DiscreteFactor everywhere in DiscreteFactorGraph.cpp
varunagrawal Dec 8, 2024
32b6bc0
update DiscreteConditional
varunagrawal Dec 8, 2024
38563da
Revert "kill toDecisionTreeFactor to force rethink"
varunagrawal Dec 8, 2024
9633ad1
make DiscreteConditional::likelihood match the declaration
varunagrawal Dec 8, 2024
0b3477f
get different classes to play nicely
varunagrawal Dec 8, 2024
1d79188
compiles
varunagrawal Dec 8, 2024
7757851
timing
varunagrawal Dec 8, 2024
9844a55
move evaluate and operator() next to each other
varunagrawal Dec 8, 2024
aa25ccf
implement evaluate in DiscreteFactor
varunagrawal Dec 8, 2024
90d7e21
change from DiscreteValues to Assignment<Key>
varunagrawal Dec 8, 2024
6665659
use BaseFactor instead of DecisionTreeFactor
varunagrawal Dec 8, 2024
f9a9801
Merge branch 'ring' into discrete-elimination-refactor
varunagrawal Dec 8, 2024
e6b6528
common definitions of Unary, UnaryAssignment and Binary
varunagrawal Dec 8, 2024
f85284a
some cleanup based on previous commit
varunagrawal Dec 8, 2024
5e86f7e
remove previously added code
varunagrawal Dec 8, 2024
1c14a56
revert changes to make code generic
varunagrawal Dec 8, 2024
b325150
revert DiscreteFactorGraph::product
varunagrawal Dec 8, 2024
0afc198
revert some DiscreteFactorGraph changes
varunagrawal Dec 8, 2024
975fe62
add methods in gtsam_unstable
varunagrawal Dec 8, 2024
fc2d33f
add division with DiscreteFactor::shared_ptr for convenience
varunagrawal Dec 8, 2024
2c02efc
fix tests
varunagrawal Dec 8, 2024
360598d
undo uncomment
varunagrawal Dec 8, 2024
853241c
add evaluate to DiscreteConditional
varunagrawal Dec 8, 2024
199c0a0
keep using DecisionTreeFactor for DiscreteConditional
varunagrawal Dec 8, 2024
214bf4e
more fixes
varunagrawal Dec 8, 2024
0de114f
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 9, 2024
e46cd54
TableFactor cleanup
varunagrawal Dec 9, 2024
52c8034
add division by DiscreteFactor in TableFactor
varunagrawal Dec 9, 2024
e0e833c
cleanup
varunagrawal Dec 9, 2024
84627c0
fix error
varunagrawal Dec 9, 2024
cc4e9cb
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 10, 2024
0b3f058
Merge branch 'develop' into discrete-elimination-refactor
varunagrawal Dec 11, 2024
22d11d7
don't print timing info by default
varunagrawal Dec 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,31 @@ namespace gtsam {
return apply(f, safe_div);
}

/// divide by DiscreteFactor::shared_ptr f (safely)
DecisionTreeFactor operator/(const DiscreteFactor::shared_ptr& f) const {
return apply(*std::dynamic_pointer_cast<DecisionTreeFactor>(f), safe_div);
}

/// Convert into a decision tree
DecisionTreeFactor toDecisionTreeFactor() const override { return *this; }

/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}

/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}

Expand Down Expand Up @@ -259,6 +264,12 @@ namespace gtsam {
*/
DecisionTreeFactor prune(size_t maxNrAssignments) const;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return nrLeaves(); }

/// @}
/// @name Wrapper support
/// @{
Expand Down
16 changes: 9 additions & 7 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ template class GTSAM_EXPORT
/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
const DecisionTreeFactor& f)
: BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {}
: BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
f.sum(nrFrontals)))),
BaseConditional(nrFrontals) {}

/* ************************************************************************** */
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
Expand Down Expand Up @@ -149,11 +151,11 @@ void DiscreteConditional::print(const string& s,
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
if (!dynamic_cast<const BaseFactor*>(&other)) {
return false;
} else {
const DecisionTreeFactor& f(static_cast<const DecisionTreeFactor&>(other));
return DecisionTreeFactor::equals(f, tol);
const BaseFactor& f(static_cast<const BaseFactor&>(other));
return BaseFactor::equals(f, tol);
}
}

Expand Down Expand Up @@ -374,7 +376,7 @@ std::string DiscreteConditional::markdown(const KeyFormatter& keyFormatter,
ss << "*\n" << std::endl;
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::markdown(keyFormatter, names);
ss << BaseFactor::markdown(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -426,7 +428,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,
ss << "</i></p>\n";
if (nrParents() == 0) {
// We have no parents, call factor method.
ss << DecisionTreeFactor::html(keyFormatter, names);
ss << BaseFactor::html(keyFormatter, names);
return ss.str();
}

Expand Down Expand Up @@ -474,7 +476,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter,

/* ************************************************************************* */
double DiscreteConditional::evaluate(const HybridValues& x) const {
return this->evaluate(x.discrete());
return this->operator()(x.discrete());
}

/* ************************************************************************* */
Expand Down
19 changes: 19 additions & 0 deletions gtsam/discrete/DiscreteFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/discrete/AlgebraicDecisionTree.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/inference/Ordering.h>

#include <string>
namespace gtsam {
Expand Down Expand Up @@ -131,6 +132,24 @@ class GTSAM_EXPORT DiscreteFactor : public Factor {

virtual DecisionTreeFactor toDecisionTreeFactor() const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(size_t nrFrontals) const = 0;

/// Create new factor by summing all values with the same separator values
virtual DiscreteFactor::shared_ptr sum(const Ordering& keys) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(size_t nrFrontals) const = 0;

/// Create new factor by maximizing over all values with the same separator.
virtual DiscreteFactor::shared_ptr max(const Ordering& keys) const = 0;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
virtual uint64_t nrValues() const = 0;

/// @}
/// @name Wrapper support
/// @{
Expand Down
12 changes: 8 additions & 4 deletions gtsam/discrete/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,11 @@ namespace gtsam {
auto normalization = product.max(product.size());

// Normalize the product factor to prevent underflow.
product = product / (*normalization);
auto normalized_product =
product /
(*std::dynamic_pointer_cast<DecisionTreeFactor>(normalization));

return product;
return normalized_product;
}

/* ************************************************************************ */
Expand All @@ -143,7 +145,8 @@ namespace gtsam {

// max out frontals, this is the factor on the separator
gttic(max);
DecisionTreeFactor::shared_ptr max = product.max(frontalKeys);
DecisionTreeFactor::shared_ptr max =
std::dynamic_pointer_cast<DecisionTreeFactor>(product.max(frontalKeys));
gttoc(max);

// Ordering keys for the conditional so that frontalKeys are really in front
Expand Down Expand Up @@ -222,7 +225,8 @@ namespace gtsam {

// sum out frontals, this is the factor on the separator
gttic(sum);
DecisionTreeFactor::shared_ptr sum = product.sum(frontalKeys);
DecisionTreeFactor::shared_ptr sum = std::dynamic_pointer_cast<DecisionTreeFactor>(
product.sum(frontalKeys));
gttoc(sum);

// Ordering keys for the conditional so that frontalKeys are really in front
Expand Down
1 change: 1 addition & 0 deletions gtsam/discrete/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class GTSAM_EXPORT DiscreteFactorGraph

/// @}

//TODO(Varun): Make compatible with TableFactor
/** Add a decision-tree factor */
template <typename... Args>
void add(Args&&... args) {
Expand Down
2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteLookupDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void DiscreteLookupTable::print(const std::string& s,
}
}
cout << "):\n";
ADT::print("", formatter);
BaseFactor::print("", formatter);
cout << endl;
}

Expand Down
13 changes: 13 additions & 0 deletions gtsam/discrete/DiscreteLookupDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

Expand Down Expand Up @@ -54,6 +55,18 @@ class GTSAM_EXPORT DiscreteLookupTable : public DiscreteConditional {
const ADT& potentials)
: DiscreteConditional(nFrontals, keys, potentials) {}

/**
* @brief Construct a new Discrete Lookup Table object
*
* @param nFrontals number of frontal variables
* @param keys a sorted list of gtsam::Keys
* @param potentials Discrete potentials as a TableFactor.
*/
DiscreteLookupTable(size_t nFrontals, const DiscreteKeys& keys,
const TableFactor& potentials)
: DiscreteConditional(nFrontals, keys,
potentials.toDecisionTreeFactor()) {}

/// GTSAM-style print
void print(
const std::string& s = "Discrete Lookup Table: ",
Expand Down
27 changes: 22 additions & 5 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h>
Expand Down Expand Up @@ -95,7 +96,6 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;

public:
/// @name Standard Constructors
/// @{

Expand Down Expand Up @@ -172,6 +172,17 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
return apply(f, safe_div);
}

/// divide by DiscreteFactor::shared_ptr f (safely)
TableFactor operator/(const DiscreteFactor::shared_ptr& f) const {
if (auto tf = std::dynamic_pointer_cast<TableFactor>(f)) {
return apply(*tf, safe_div);
} else if (auto dtf = std::dynamic_pointer_cast<DecisionTreeFactor>(f)) {
return apply(TableFactor(f->discreteKeys(), *dtf), safe_div);
} else {
throw std::runtime_error("Unknown derived type for DiscreteFactor");
}
}

/// Convert into a decisiontree
DecisionTreeFactor toDecisionTreeFactor() const override;

Expand All @@ -180,22 +191,22 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
DiscreteKeys parent_keys) const;

/// Create new factor by summing all values with the same separator values
shared_ptr sum(size_t nrFrontals) const {
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::add);
}

/// Create new factor by summing all values with the same separator values
shared_ptr sum(const Ordering& keys) const {
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
return combine(keys, Ring::add);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
return combine(nrFrontals, Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
return combine(keys, Ring::max);
}

Expand Down Expand Up @@ -300,6 +311,12 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
*/
TableFactor prune(size_t maxNrAssignments) const;

/**
* Get the number of non-zero values contained in this factor.
* It could be much smaller than `prod_{key}(cardinality(key))`.
*/
uint64_t nrValues() const override { return sparse_table_.nonZeros(); }

/// @}
/// @name Wrapper support
/// @{
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ TEST(DecisionTreeFactor, sum_max) {
DecisionTreeFactor f1(v0 & v1, "1 2 3 4 5 6");

DecisionTreeFactor expected(v1, "9 12");
DecisionTreeFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));

DecisionTreeFactor expected2(v1, "5 6");
DecisionTreeFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<DecisionTreeFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2));

DecisionTreeFactor f2(v1 & v0, "1 2 3 4 5 6");
DecisionTreeFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<DecisionTreeFactor>(f2.sum(1));
}

/* ************************************************************************* */
Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/tests/testDiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ TEST(DiscreteConditional, constructors) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));

std::vector<double> probs{0.2, 0.5, 0.3, 0.6, 0.4, 0.7, 0.25, 0.55, 0.35, 0.65, 0.45, 0.75};
Expand All @@ -70,7 +70,7 @@ TEST(DiscreteConditional, constructors_alt_interface) {
DecisionTreeFactor f2(
X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75");
DiscreteConditional actual2(1, f2);
DecisionTreeFactor expected2 = f2 / *f2.sum(1);
DecisionTreeFactor expected2 = f2 / f2.sum(1);
EXPECT(assert_equal(expected2, static_cast<DecisionTreeFactor>(actual2)));
}

Expand Down
4 changes: 2 additions & 2 deletions gtsam/discrete/tests/testDiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ TEST(DiscreteFactorGraph, test) {
// Normalize newFactor by max for comparison with expected
auto normalization = newFactor.max(newFactor.size());

newFactor = newFactor / *normalization;
newFactor = newFactor / normalization;

// Check Conditional
CHECK(conditional);
Expand All @@ -133,7 +133,7 @@ TEST(DiscreteFactorGraph, test) {
// Normalize by max.
normalization = expectedFactor.max(expectedFactor.size());
// Ensure normalization is correct.
expectedFactor = expectedFactor / *normalization;
expectedFactor = expectedFactor / normalization;
EXPECT(assert_equal(expectedFactor, newFactor));

// Test using elimination tree
Expand Down
6 changes: 3 additions & 3 deletions gtsam/discrete/tests/testTableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,15 @@ TEST(TableFactor, sum_max) {
TableFactor f1(v0 & v1, "1 2 3 4 5 6");

TableFactor expected(v1, "9 12");
TableFactor::shared_ptr actual = f1.sum(1);
auto actual = std::dynamic_pointer_cast<TableFactor>(f1.sum(1));
CHECK(assert_equal(expected, *actual, 1e-5));

TableFactor expected2(v1, "5 6");
TableFactor::shared_ptr actual2 = f1.max(1);
auto actual2 = std::dynamic_pointer_cast<TableFactor>(f1.max(1));
CHECK(assert_equal(expected2, *actual2));

TableFactor f2(v1 & v0, "1 2 3 4 5 6");
TableFactor::shared_ptr actual22 = f2.sum(1);
auto actual22 = std::dynamic_pointer_cast<TableFactor>(f2.sum(1));
}

/* ************************************************************************* */
Expand Down
19 changes: 19 additions & 0 deletions gtsam_unstable/discrete/AllDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,25 @@ class GTSAM_UNSTABLE_EXPORT AllDiff : public Constraint {
/// Partially apply known values, domain version
Constraint::shared_ptr partiallyApply(
const Domains&) const override;

/// Get the number of non-zero values contained in this factor.
uint64_t nrValues() const override { return 1; };

DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override {
throw std::runtime_error("Not implemented");
}

DiscreteFactor::shared_ptr sum(const Ordering& keys) const override {
throw std::runtime_error("Not implemented");
}

DiscreteFactor::shared_ptr max(size_t nrFrontals) const override {
throw std::runtime_error("Not implemented");
}

DiscreteFactor::shared_ptr max(const Ordering& keys) const override {
throw std::runtime_error("Not implemented");
}
};

} // namespace gtsam
Loading
Loading