[WIP] Make Random Forest observable. #4655

Closed
2 changes: 1 addition & 1 deletion src/shogun/base/SGObject.cpp
@@ -491,7 +491,7 @@ void SGObject::unsubscribe(const std::shared_ptr<ParameterObserver>& obs)
obs->put("subscription_id", static_cast<int64_t>(-1));
}

void SGObject::observe(std::shared_ptr<ObservedValue> value) const
void SGObject::observe(const std::shared_ptr<ObservedValue>& value) const
{
m_subscriber_params->on_next(value);
}
2 changes: 1 addition & 1 deletion src/shogun/base/SGObject.h
@@ -1103,7 +1103,7 @@ class SGObject: public std::enable_shared_from_this<SGObject>
* Observe a parameter value, given a pointer.
* @param value Observed parameter's value
*/
void observe(std::shared_ptr<ObservedValue> value) const;
void observe(const std::shared_ptr<ObservedValue>& value) const;

/**
* Observe a parameter value given custom properties for the Any.
71 changes: 25 additions & 46 deletions src/shogun/machine/BaggingMachine.cpp
@@ -9,10 +9,15 @@
#include <shogun/base/progress.h>
#include <shogun/ensemble/CombinationRule.h>
#include <shogun/ensemble/MeanRule.h>
#include <shogun/evaluation/BinaryClassEvaluation.h>
#include <shogun/evaluation/ContingencyTableEvaluation.h>
#include <shogun/evaluation/Evaluation.h>
#include <shogun/evaluation/MeanSquaredError.h>
#include <shogun/evaluation/MulticlassAccuracy.h>
#include <shogun/lib/observers/ObservedValueTemplated.h>
#include <shogun/machine/BaggingMachine.h>
#include <shogun/mathematics/UniformIntDistribution.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>
#include <shogun/evaluation/Evaluation.h>

#include <utility>

@@ -203,12 +208,26 @@ bool BaggingMachine::train_machine(std::shared_ptr<Features> data)

#pragma omp critical
{
// get out of bag indexes
auto oob = get_oob_indices(idx);
m_oob_indices.push_back(oob);

// add trained machine to bag array
m_bags.push_back(c);
// get out of bag indexes
auto oob = get_oob_indices(idx);
m_oob_indices.push_back(oob);

// add trained machine to bag array
m_bags.push_back(c);

// observe some variables. The oob error is computed only when
// we have all the variables set and there are observers attached.
this->observe<std::shared_ptr<Machine>>(
i, "trained_machine", "Trained machine for this bag", c);

if (this->m_combination_rule && this->m_oob_evaluation_metric &&
this->get_num_subscriptions() != 0)
{
auto oob_error = this->get_oob_error();
Review discussion:

Member: question: you now always compute the oob error whenever there is an observer. Is that good? I forgot how specifying the to-be-observed variables works, but can't the user somehow pick what she wants to observe? In that case, this should only be computed if the user requested it.

Contributor Author: So, the idea is that as long as the model has some observer attached, it will output all the available observations. This means that the OOB error will be computed regardless of the user's preferences. It will be the observer's job to filter the observations and keep just the ones the user wants.

Contributor Author: At the moment, there is no way to tell an observable model to emit only certain observations.

Member: I don't think that's good, and we should probably think about a way to avoid it. Passing (and ignoring) pointers to existing data is one thing, but expensive operations should only be executed if desired. I could imagine that this could be done with a lambda function that is only executed if the observer actually wants (as in: the user instructed it) to store the data. Maybe @gf712 has an idea? He is the lambda master.

Contributor Author: I see the problem. One small issue here is that the model has no idea which things the observers want to observe. I guess the solution could be emitting functions (lambdas) instead, which can then be executed by the observers if they want to store that kind of data. Is this what you meant, @karlnapf?

Member: Yes, the model doesn't know. In general, it should offer the observer something to observe (which the observer then can or cannot act on, depending on its settings). Passing a pointer is a good example of that: the observer can clone the underlying structure, for example. But for costly things, passing a lambda that would compute something expensive is the better choice, imo.

Contributor Author: Okay, that makes sense! I will do some experimenting to try adding this lambda feature. I'm leaving this one as it is (or maybe we could merge it into a feature branch), and then I will open a new PR with the lambda observers.

Member: I say let's just keep this PR open, and then you could rebase it once the lambda thing works. Alternatively, just remove the oob error stuff and we can merge everything else?
this->observe<float64_t>(
i, "oob_error", "Out-of-bag Error", oob_error);
}
}

pb.print_progress();
Expand Down Expand Up @@ -242,36 +261,6 @@ void BaggingMachine::register_parameters()
watch_method(kOobError, &BaggingMachine::get_oob_error);
}

void BaggingMachine::set_num_bags(int32_t num_bags)
{
m_num_bags = num_bags;
}

int32_t BaggingMachine::get_num_bags() const
{
return m_num_bags;
}

void BaggingMachine::set_bag_size(int32_t bag_size)
{
m_bag_size = bag_size;
}

int32_t BaggingMachine::get_bag_size() const
{
return m_bag_size;
}

std::shared_ptr<Machine> BaggingMachine::get_machine() const
{
return m_machine;
}

void BaggingMachine::set_machine(std::shared_ptr<Machine> machine)
{
m_machine = std::move(machine);
}

void BaggingMachine::init()
{
m_machine = nullptr;
@@ -284,15 +273,6 @@ void BaggingMachine::init()
m_oob_evaluation_metric = nullptr;
}

void BaggingMachine::set_combination_rule(std::shared_ptr<CombinationRule> rule)
{
m_combination_rule = std::move(rule);
}

std::shared_ptr<CombinationRule> BaggingMachine::get_combination_rule() const
{
return m_combination_rule;
}

float64_t BaggingMachine::get_oob_error() const
{
@@ -368,7 +348,6 @@ float64_t BaggingMachine::get_oob_error() const
error("Unsupported label type");
}


m_labels->add_subset(SGVector<index_t>(idx.data(), idx.size(), false));
float64_t res = m_oob_evaluation_metric->evaluate(predicted, m_labels);
m_labels->remove_subset();
61 changes: 1 addition & 60 deletions src/shogun/machine/BaggingMachine.h
@@ -44,65 +44,6 @@ namespace shogun
virtual std::shared_ptr<MulticlassLabels> apply_multiclass(std::shared_ptr<Features> data=NULL);
virtual std::shared_ptr<RegressionLabels> apply_regression(std::shared_ptr<Features> data=NULL);

/**
* Set number of bags/machine to create
*
* @param num_bags number of bags
*/
void set_num_bags(int32_t num_bags);

/**
* Get number of bags/machines
*
* @return number of bags
*/
int32_t get_num_bags() const;

/**
* Set number of feature vectors to use
* for each bag/machine
*
* @param bag_size number of vectors to use for a bag
*/
virtual void set_bag_size(int32_t bag_size);

/**
* Get number of feature vectors that are use
* for training each bag/machine
*
* @return number of vectors used for training for each bag.
*/
virtual int32_t get_bag_size() const;

/**
* Get machine for bagging
*
* @return machine that is being used in bagging
*/
std::shared_ptr<Machine> get_machine() const;

/**
* Set machine to use in bagging
*
* @param machine the machine to use for bagging
*/
virtual void set_machine(std::shared_ptr<Machine> machine);

/**
* Set the combination rule to use for aggregating the classification
* results
*
* @param rule combination rule
*/
void set_combination_rule(std::shared_ptr<CombinationRule> rule);

/**
* Get the combination rule that is used for aggregating the results
*
* @return CombinationRule
*/
std::shared_ptr<CombinationRule> get_combination_rule() const;

/** get classifier type
*
* @return classifier type CT_BAGGING
@@ -210,7 +151,7 @@ namespace shogun
static constexpr std::string_view kMachine = "machine";
static constexpr std::string_view kOobError = "oob_error";
static constexpr std::string_view kOobEvaluationMetric = "oob_evaluation_metric";
#endif
#endif
};
} // namespace shogun

6 changes: 3 additions & 3 deletions src/shogun/machine/RandomForest.cpp
@@ -47,7 +47,7 @@ RandomForest::RandomForest(int32_t rand_numfeats, int32_t num_bags)
{
init();

set_num_bags(num_bags);
m_num_bags = num_bags;

if (rand_numfeats>0)
m_machine->as<RandomCARTree>()->set_feature_subset_size(rand_numfeats);
@@ -60,7 +60,7 @@ RandomForest::RandomForest(std::shared_ptr<Features> features, std::shared_ptr<L
m_features=std::move(features);
set_labels(std::move(labels));

set_num_bags(num_bags);
m_num_bags = num_bags;

if (rand_numfeats>0)
m_machine->as<RandomCARTree>()->set_feature_subset_size(rand_numfeats);
@@ -75,7 +75,7 @@ RandomForest::RandomForest(std::shared_ptr<Features> features, std::shared_ptr<L
set_labels(std::move(labels));
m_weights=weights;

set_num_bags(num_bags);
m_num_bags = num_bags;

if (rand_numfeats>0)
m_machine->as<RandomCARTree>()->set_feature_subset_size(rand_numfeats);
6 changes: 5 additions & 1 deletion tests/unit/labels/MockLabels.h
@@ -10,8 +10,12 @@ namespace shogun {
MOCK_CONST_METHOD0(get_num_labels, int32_t());
MOCK_CONST_METHOD0(get_label_type, ELabelType());
MOCK_METHOD0(get_values, SGVector<float64_t>());
MOCK_CONST_METHOD0(get_labels, SGVector<float64_t>());

virtual const char* get_name() const { return "MockLabels"; }
virtual const char* get_name() const
{
return "MockLabels";
}
};

} // namespace shogun
63 changes: 47 additions & 16 deletions tests/unit/multiclass/BaggingMachine_unittest.cc
@@ -10,6 +10,7 @@
#include <shogun/features/DenseFeatures.h>
#include <shogun/lib/SGMatrix.h>
#include <shogun/lib/config.h>
#include <shogun/lib/observers/ParameterObserverLogger.h>
#include <shogun/machine/BaggingMachine.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>
#include <shogun/multiclass/tree/CARTree.h>
@@ -84,10 +85,10 @@ TEST_F(BaggingMachineTest, mock_train)
auto mv = std::make_shared<MajorityVote>();

env()->set_num_threads(1);
bm->set_machine(mm);
bm->set_bag_size(bag_size);
bm->set_num_bags(num_bags);
bm->set_combination_rule(mv);
bm->put<Machine>("machine", mm);
bm->put("bag_size", bag_size);
bm->put("num_bags", num_bags);
bm->put<CombinationRule>("combination_rule", mv);
bm->put("seed", seed);

ON_CALL(*mm, train_machine(_))
@@ -123,11 +124,12 @@ TEST_F(BaggingMachineTest, classify_CART)
auto c = std::make_shared<BaggingMachine>(features_train, labels_train);

env()->set_num_threads(1);
c->set_machine(cart);
c->set_bag_size(14);
c->set_num_bags(10);
c->set_combination_rule(cv);
c->put<Machine>("machine", cart);
c->put("bag_size", 14);
c->put("num_bags", 10);
c->put<CombinationRule>("combination_rule", cv);
c->put("seed", seed);

c->train(features_train);

auto result = c->apply_multiclass(features_test);
@@ -153,10 +155,10 @@ TEST_F(BaggingMachineTest, output_binary)
cart->set_feature_types(ft);
auto c = std::make_shared<BaggingMachine>(features_train, labels_train);
env()->set_num_threads(1);
c->set_machine(cart);
c->set_bag_size(14);
c->set_num_bags(10);
c->set_combination_rule(cv);
c->put<Machine>("machine", cart);
c->put("bag_size", 14);
c->put("num_bags", 10);
c->put<CombinationRule>("combination_rule", cv);
c->put("seed", seed);
c->train(features_train);

@@ -186,10 +188,11 @@ TEST_F(BaggingMachineTest, output_multiclass_probs_sum_to_one)

cart->set_feature_types(ft);
auto c = std::make_shared<BaggingMachine>(features_train, labels_train);
c->set_machine(cart);
c->set_bag_size(14);
c->set_num_bags(10);
c->set_combination_rule(cv);

c->put<Machine>("machine", cart);
c->put("bag_size", 14);
c->put("num_bags", 10);
c->put<CombinationRule>("combination_rule", cv);
c->put("seed", seed);
c->train(features_train);

@@ -213,3 +216,31 @@


}

TEST_F(BaggingMachineTest, observable_bagging_machine)
{
int32_t seed = 555;
auto cart = std::make_shared<CARTree>();
auto cv = std::make_shared<MajorityVote>();
cart->set_feature_types(ft);

auto c = std::make_shared<BaggingMachine>(features_train, labels_train);

auto oob_eval = std::make_shared<MulticlassAccuracy>();
c->put<Evaluation>("oob_evaluation_metric", oob_eval);

auto obs = std::make_shared<ParameterObserverLogger>();
c->subscribe(obs);

env()->set_num_threads(1);
c->put<Machine>("machine", cart);
c->put("bag_size", 14);
c->put("num_bags", 10);
c->put<CombinationRule>("combination_rule", cv);
c->put("seed", seed);

c->train(features_train);

EXPECT_EQ(obs->get<int32_t>("num_observations"), 20);
c->unsubscribe(obs);
}
8 changes: 4 additions & 4 deletions tests/unit/multiclass/tree/RandomForest_unittest.cc
@@ -96,7 +96,7 @@ TEST_F(RandomForestTest, classify_nominal_test)
std::make_shared<RandomForest>(weather_features_train, weather_labels_train, 100, 2);
c->set_feature_types(weather_ft);
auto mv = std::make_shared<MajorityVote>();
c->set_combination_rule(mv);
c->put<CombinationRule>("combination_rule", mv);
env()->set_num_threads(1);
c->put("seed", seed);
c->train(weather_features_train);
@@ -129,7 +129,7 @@ TEST_F(RandomForestTest, classify_non_nominal_test)
std::make_shared<RandomForest>(weather_features_train, weather_labels_train, 100, 2);
c->set_feature_types(weather_ft);
auto mv = std::make_shared<MajorityVote>();
c->set_combination_rule(mv);
c->put<CombinationRule>("combination_rule", mv);
env()->set_num_threads(1);
c->put("seed", seed);
c->train(weather_features_train);
@@ -173,7 +173,7 @@ TEST_F(RandomForestTest, score_compare_sklearn_toydata)
c->set_feature_types(ft);

auto mr = std::make_shared<MeanRule>();
c->set_combination_rule(mr);
c->put<CombinationRule>("combination_rule", mr);
c->put("seed", seed);
c->train(features_train);

@@ -232,7 +232,7 @@ TEST_F(RandomForestTest, score_consistent_with_binary_trivial_data)
c->set_feature_types(ft);

auto mr = std::make_shared<MeanRule>();
c->set_combination_rule(mr);
c->put<CombinationRule>("combination_rule", mr);
c->put("seed", seed);
c->train(features_train);
