Skip to content

Commit

Permalink
Merge pull request #1151 from borglab/feature/decision-tree-factor-prune
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal authored Mar 27, 2022
2 parents c2c54bc + 365473f commit 0850e89
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
34 changes: 34 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,5 +286,39 @@ namespace gtsam {
AlgebraicDecisionTree<Key>(keys, table),
cardinalities_(keys.cardinalities()) {}

/* ************************************************************************ */
/**
 * Return a copy of this factor in which only the `maxNrLeaves` largest leaf
 * values are kept; every other leaf is set to 0.
 *
 * - If the tree has at most `maxNrLeaves` leaves, the factor is returned as is.
 * - If `maxNrLeaves` is 0, every leaf is set to 0.
 * - If several leaves are equal to the cut-off value, all strictly larger
 *   leaves are kept and the tied leaves fill the remaining slots in the order
 *   in which the tree is traversed.
 *
 * @param maxNrLeaves The maximum number of leaves to keep.
 * @return The pruned factor, defined over the same discrete keys.
 */
DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const {
  const size_t N = maxNrLeaves;

  // Collect every leaf value so the cut-off value can be determined.
  std::vector<double> probabilities;
  this->visit([&](const double& prob) { probabilities.emplace_back(prob); });

  // The number of leaves can already be within the limit: nothing to prune.
  if (probabilities.size() <= N) {
    return *this;
  }

  // Keeping zero leaves means zeroing the whole tree. Handling this here also
  // avoids indexing position N - 1 with N == 0 below.
  if (N == 0) {
    auto zeroFunc = [](const double&) { return 0.0; };
    DecisionTree<Key, double> zeroed(*this, zeroFunc);
    return DecisionTreeFactor(this->discreteKeys(), zeroed);
  }

  // Only the N-th largest value is needed, so a partial ordering is enough:
  // afterwards position N - 1 holds the N-th largest value and everything in
  // front of it is greater than or equal to it.
  using Difference = std::vector<double>::difference_type;
  const auto nth = probabilities.begin() + static_cast<Difference>(N - 1);
  std::nth_element(probabilities.begin(), nth, probabilities.end(),
                   std::greater<double>{});
  const double threshold = *nth;

  // Leaves strictly above the threshold are always kept. There are at most
  // N - 1 of them, and they all sit in front of position N - 1.
  const size_t nrAbove = static_cast<size_t>(
      std::count_if(probabilities.begin(), nth,
                    [threshold](double p) { return p > threshold; }));

  // Leaves equal to the threshold may only fill the slots that remain, so a
  // tie can never push out a strictly larger leaf.
  size_t tiesLeft = N - nrAbove;

  // Now threshold the decision tree.
  auto thresholdFunc = [threshold, &tiesLeft](const double& value) {
    if (value > threshold) {
      return value;
    }
    if (value == threshold && tiesLeft > 0) {
      --tiesLeft;
      return value;
    }
    return 0.0;
  };
  DecisionTree<Key, double> thresholded(*this, thresholdFunc);

  // Create pruned decision tree factor and return.
  return DecisionTreeFactor(this->discreteKeys(), thresholded);
}

/* ************************************************************************ */
} // namespace gtsam
12 changes: 12 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,18 @@ namespace gtsam {
/// Return all the discrete keys associated with this factor.
DiscreteKeys discreteKeys() const;

/**
 * @brief Prune the decision tree of discrete variables.
 *
 * Pruning sets the leaves to be "pruned" to 0, indicating a 0 probability.
 * A leaf is pruned if its value is not among the top `maxNrLeaves` leaf
 * values of the tree.
 *
 * This factor is not modified; the pruned result is returned as a new
 * factor. If the tree has no more than `maxNrLeaves` leaves, nothing is
 * pruned and a copy of this factor is returned.
 *
 * @param maxNrLeaves The maximum number of leaves to keep.
 * @return DecisionTreeFactor The factor with all pruned leaves set to 0.
 */
DecisionTreeFactor prune(size_t maxNrLeaves) const;

/// @}
/// @name Wrapper support
/// @{
Expand Down
21 changes: 21 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}

/* ************************************************************************* */
// Check pruning of the decision tree works as expected.
TEST(DecisionTreeFactor, Prune) {
  // A factor over three binary variables with eight distinct leaf values.
  DiscreteKey A(1, 2), B(2, 2), C(3, 2);
  const DecisionTreeFactor factor(A & B & C, "1 5 3 7 2 6 4 8");

  // Moderate pruning: of the eight leaves only the five largest values
  // survive, all remaining leaves must come back as 0.
  {
    const size_t maxNrLeaves = 5;
    const DecisionTreeFactor actual = factor.prune(maxNrLeaves);
    const DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8");
    EXPECT(assert_equal(expected, actual));
  }

  // Extreme pruning: only the two largest values are allowed to survive.
  {
    const size_t maxNrLeaves = 2;
    const DecisionTreeFactor actual = factor.prune(maxNrLeaves);
    const DecisionTreeFactor expected(A & B & C, "0 0 0 7 0 0 0 8");
    EXPECT(assert_equal(expected, actual));
  }
}

/* ************************************************************************* */
TEST(DecisionTreeFactor, DotWithNames) {
DiscreteKey A(12, 3), B(5, 2);
Expand Down

0 comments on commit 0850e89

Please sign in to comment.