Skip to content

Commit

Permalink
Refactor NearestCentroid to be stateless
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuYuHui committed Jun 3, 2020
1 parent a3f8d98 commit 2c5a5b4
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 53 deletions.
86 changes: 59 additions & 27 deletions src/shogun/classifier/NearestCentroid.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Philippe Tillet, Soeren Sonnenburg, Bjoern Esser, Sergey Lisitsyn
* Authors: Philippe Tillet, Soeren Sonnenburg, Bjoern Esser, Sergey Lisitsyn,
* Yuhui Liu
*/

#include <shogun/classifier/NearestCentroid.h>
Expand All @@ -19,13 +20,12 @@ namespace shogun{
init();
}

NearestCentroid::NearestCentroid(const std::shared_ptr<Distance>& d, const std::shared_ptr<Labels>& trainlab) : DistanceMachine()
NearestCentroid::NearestCentroid(const std::shared_ptr<Distance>& d)
: DistanceMachine()
{
init();
ASSERT(d)
ASSERT(trainlab)
set_distance(d);
set_labels(trainlab);
}

NearestCentroid::~NearestCentroid()
Expand All @@ -38,48 +38,59 @@ namespace shogun{
m_is_trained=false;
}


bool NearestCentroid::train_machine(std::shared_ptr<Features> data)
std::shared_ptr<Machine> NearestCentroid::fit(
std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab)
{
ASSERT(m_labels)
ASSERT(distance)
if (data)
{
if (m_labels->get_num_labels() != data->get_num_vectors())
error("Number of training vectors does not match number of labels");
distance->init(data, data);
}
else
{
data = distance->get_lhs();
}
// ASSERT(distance)
require(distance, "distance not set");
require(
lab->get_num_labels() == feat->get_num_vectors(),
"Number of training vectors does not match number of labels");
distance->init(feat, feat);

auto multiclass_labels = m_labels->as<MulticlassLabels>();
auto dense_data = data->as<DenseFeatures<float64_t>>();
auto multiclass_labels = lab->as<MulticlassLabels>();
auto dense_data = feat->as<DenseFeatures<float64_t>>();

int32_t num_classes = multiclass_labels->get_num_classes();
int32_t num_feats = dense_data->get_num_features();

SGMatrix<float64_t> centroids(num_feats, num_classes);
SGVector<int64_t> num_per_class(num_classes);

linalg::zero(centroids);
linalg::zero(num_per_class);
auto iter_labels = multiclass_labels->get_labels().begin();
std::map<int64_t, int64_t> label_to_index;
int64_t index = 0;

for (const auto& current : DotIterator(dense_data))
{
const auto current_class = static_cast<int32_t>(*(iter_labels++));
auto target = centroids.get_column(current_class);
auto curr_index = index;
if (label_to_index.find(current_class) == label_to_index.end())
{
label_to_index[current_class] = index;
index_to_labels[index] = current_class;
curr_index = index;
index = index + 1;
}
else
{
curr_index = label_to_index[current_class];
}
auto target = centroids.get_column(curr_index);
current.add(1, target);
num_per_class[current_class]++;
num_per_class[curr_index]++;
}

SGVector<float64_t> scale(num_classes);

for (int32_t i=0 ; i<num_classes ; i++)
{
int32_t total = num_per_class[i];
if(total>1)
scale[i] = 1.0/((float64_t)(total-1));
if (total > 0)
scale[i] = 1.0 / (float64_t)(total);
else
scale[i] = 1.0/(float64_t)total;
scale[i] = 0.0;
}
linalg::scale(centroids, centroids, scale);

Expand All @@ -88,7 +99,28 @@ namespace shogun{
m_is_trained=true;
distance->init(centroids_feats, distance->get_rhs());

return true;
return shared_from_this()->as<NearestCentroid>();
}

/** Legacy training hook, retained only so the class still provides it.
 *
 * Does no training: it reports an error that points the caller at fit().
 *
 * @param data unused
 * @return false, should control come back from error()
 */
bool NearestCentroid::train_machine(std::shared_ptr<Features> data)
{
	// Training now goes through fit(features, labels); nothing is done here.
	error("the train_machine have been deprecated, please use fit instead");

	const bool trained = false;
	return trained;
}

/** Classify every vector of `data` by its nearest centroid.
 *
 * Re-initialises the distance with its current left-hand side and `data` as
 * the right-hand side, then translates each index returned by apply_one()
 * back to a label through index_to_labels.
 *
 * @param data test features; must not be null
 * @return MulticlassLabels holding one label per vector in `data`
 */
std::shared_ptr<Labels>
NearestCentroid::predict(std::shared_ptr<Features> data)
{
	require(data, "features not set");
	// Same guard as fit(): report a message rather than dereference null.
	require(distance, "distance not set");

	auto lhs = distance->get_lhs();
	distance->init(lhs, data);

	// The vector count does not change inside the loop; query it once.
	const auto num_vectors = data->get_num_vectors();
	auto result = std::make_shared<MulticlassLabels>(num_vectors);
	for (index_t i = 0; i < num_vectors; ++i)
	{
		// apply_one() returns float64_t; make the conversion to the
		// integral map key explicit.
		const auto centroid_index = static_cast<int64_t>(apply_one(i));

		// find() instead of operator[]: the latter would insert a default
		// entry for an unknown index and silently report label 0.
		const auto entry = index_to_labels.find(centroid_index);
		require(
		    entry != index_to_labels.end(),
		    "no label stored for the predicted centroid index, please call "
		    "fit first");

		result->set_label(i, entry->second);
	}

	return result;
}
}
29 changes: 23 additions & 6 deletions src/shogun/classifier/NearestCentroid.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class NearestCentroid : public DistanceMachine{
* @param distance distance
* @param trainlab labels for training
*/
NearestCentroid(const std::shared_ptr<Distance>& distance, const std::shared_ptr<Labels>& trainlab);
NearestCentroid(const std::shared_ptr<Distance>& distance);

/** Destructor
*/
Expand Down Expand Up @@ -81,14 +81,29 @@ class NearestCentroid : public DistanceMachine{
*/
virtual const char* get_name() const { return "NearestCentroid"; }

protected:
/** train Nearest Centroid classifier
*
* @param data training data (parameter can be avoided if distance or
* kernel-based classifiers are used and distance/kernels are
* initialized with train data)
* @param feat training data
* @param lab training labels
*
* @return pointer to current NearestCentroid object
*/

virtual std::shared_ptr<Machine>
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab);

/** Perform classification on an array of test vectors X
*
* @param X test data
*
* @return whether training was successful
* @return The predicted Label for each sample in X .
*/
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);

protected:
/** train Nearest Centroid classifier
*
* note: train_machine has been deprecated; please use fit instead.
*/
virtual bool train_machine(std::shared_ptr<Features> data=NULL);

Expand All @@ -112,6 +127,8 @@ class NearestCentroid : public DistanceMachine{

/// Tells if the classifier has been trained or not
bool m_is_trained;

std::map<int64_t, int64_t> index_to_labels;
};

}
Expand Down
38 changes: 18 additions & 20 deletions src/shogun/machine/DistanceMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,25 @@ void DistanceMachine::distances_rhs(SGVector<float64_t>& result, index_t idx_b1,
}
}

std::shared_ptr<MulticlassLabels> DistanceMachine::apply_multiclass(std::shared_ptr<Features> data)
std::shared_ptr<Labels> DistanceMachine::predict(std::shared_ptr<Features> data)
{
if (data)
{
/* set distance features to given ones and apply to all */
auto lhs=distance->get_lhs();
distance->init(lhs, data);

/* build result labels and classify all elements of procedure */
auto result=std::make_shared<MulticlassLabels>(data->get_num_vectors());
for (index_t i=0; i<data->get_num_vectors(); ++i)
result->set_label(i, apply_one(i));
return result;
}
else
{
/* call apply on complete right hand side */
auto all=distance->get_rhs();
return apply_multiclass(all);
}
return NULL;
require(data, "features not set");
/* set distance features to given ones and apply to all */
auto lhs = distance->get_lhs();
distance->init(lhs, data);

/* build result labels and classify all elements of procedure */
auto result = std::make_shared<MulticlassLabels>(data->get_num_vectors());
for (index_t i = 0; i < data->get_num_vectors(); ++i)
result->set_label(i, apply_one(i));
return result;
}
std::shared_ptr<MulticlassLabels>
DistanceMachine::apply_multiclass(std::shared_ptr<Features> data)
{
error("the apply_multiclass have been deprecated, please use predict "
"instead");
return nullptr;
}

float64_t DistanceMachine::apply_one(int32_t num)
Expand Down
7 changes: 7 additions & 0 deletions src/shogun/machine/DistanceMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class DistanceMachine : public Machine
*/
virtual std::shared_ptr<MulticlassLabels> apply_multiclass(std::shared_ptr<Features> data=NULL);

/** Perform classification on an array of test vectors X
*
* @param X test data
*
* @return The predicted Label for each sample in X .
*/
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);
/** Apply machine to one example.
* Cluster index with smallest distance to the element to be classified is
* returned
Expand Down
26 changes: 26 additions & 0 deletions src/shogun/machine/Machine.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,32 @@ class Machine : public StoppableSGObject
return true;
}

/** Fit the machine to labelled training data.
 *
 * Default implementation: reports "not implemented" for the current source
 * location. Subclasses are expected to override it.
 *
 * @param feat training data
 * @param lab training labels
 *
 * @return pointer to the current Machine object
 */
virtual std::shared_ptr<Machine>
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab)
{
	not_implemented(SOURCE_LOCATION);

	auto self = shared_from_this();
	return self->as<Machine>();
}

/** Predict a label for every test vector in X.
 *
 * Default implementation: reports "not implemented" for the current source
 * location. Subclasses are expected to override it.
 *
 * @param X test data
 *
 * @return the predicted label for each sample in X; this default yields a
 * null pointer
 */
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X)
{
	not_implemented(SOURCE_LOCATION);

	return std::shared_ptr<Labels>{};
}

protected:
/** train machine
*
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/classifier/NearestCentroid_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Yuhui Liu
*/
#include <gtest/gtest.h>
#include <shogun/classifier/NearestCentroid.h>
#include <shogun/distance/EuclideanDistance.h>
#include <shogun/labels/MulticlassLabels.h>

using namespace shogun;
// Fits NearestCentroid on six labelled 2-D entries (three with label 1, three
// with label 2) and checks that two of the training points, used as queries,
// are assigned their original labels.
TEST(NearestCentroid, fit_and_predict)
{
	// Training set: six inner lists, matched one-to-one with the six labels.
	SGMatrix<float64_t> train_matrix{{-10, -1}, {-2, -1}, {-3, -2},
	                                 {1, 1},    {2, 1},   {3, 2}};
	SGVector<float64_t> train_targets{1, 1, 1, 2, 2, 2};

	auto features = std::make_shared<DenseFeatures<float64_t>>(train_matrix);
	auto labels = std::make_shared<MulticlassLabels>(train_targets);
	auto metric = std::make_shared<EuclideanDistance>();

	// Queries: one training point carrying label 2, one carrying label 1.
	SGMatrix<float64_t> query_matrix{{3, 2}, {-10, -1}};
	auto queries = std::make_shared<DenseFeatures<float64_t>>(query_matrix);

	auto classifier = std::make_shared<NearestCentroid>(metric);
	auto fitted = classifier->fit(features, labels);
	auto predicted = fitted->predict(queries);

	auto predicted_values = predicted->as<MulticlassLabels>()->get_labels();
	EXPECT_EQ(predicted_values[0], 2);
	EXPECT_EQ(predicted_values[1], 1);
}

0 comments on commit 2c5a5b4

Please sign in to comment.