Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor NearestCentroid class #5053

Conversation

LiuYuHui
Copy link
Contributor

@LiuYuHui LiuYuHui commented Jun 3, 2020

  • Add fit/predict to Machine class
  • Refactor NearestCentroid to be stateless
  • Add NearestCentroid unittest

@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from 2c5a5b4 to 2e6f793 Compare June 3, 2020 02:15
Copy link
Member

@vigsterkr vigsterkr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some food for thought

}

bool NearestCentroid::train_machine(std::shared_ptr<Features> data)
{
error("the train_machine have been deprecated, please use fit instead");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uhoh :) the deprecation should mean that it's still functional :)
so either use a warning and still be functional (but i guess that's not possible coz of not having labels) so in this case i would just drop the whole thing :)
if Machine::train_machine is a pure virtual function, then maybe we should just create there an implementation with this error....

src/shogun/machine/Machine.h Outdated Show resolved Hide resolved
*
* @return The predicted Label for each sample in X .
*/
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X)
virtual std::shared_ptr<Labels> predict(const std::shared_ptr<Features>& X)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict should be const right?

src/shogun/classifier/NearestCentroid.cpp Outdated Show resolved Hide resolved
if(total>1)
scale[i] = 1.0/((float64_t)(total-1));
if (total > 0)
scale[i] = 1.0 / (float64_t)(total);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
scale[i] = 1.0 / (float64_t)(total);
scale[i] = 1.0 / static_cast<float64_t>(total);

Doesn't change anything but it is more explicit

auto iter_labels = multiclass_labels->get_labels().begin();
std::map<int64_t, int64_t> label_to_index;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not unordered_map and then reserve num_classes?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this labels to index, this is a problem that re-occurs in other places, so I'd not solve it here.
Actually, we have some mechanisms in place that convert labels internally (code is e.g. in multiclass labels).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@karlnapf didn't you start a PR that does this mapping?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, there is something merged even.
But we need to re-think/redesign parts of it as it had issues ...

for (const auto& current : DotIterator(dense_data))
{
const auto current_class = static_cast<int32_t>(*(iter_labels++));
auto target = centroids.get_column(current_class);
auto curr_index = index;
if (label_to_index.find(current_class) == label_to_index.end())
Copy link
Member

@gf712 gf712 Jun 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could rewrite this in order to reuse the iterator in the else branch. I am not sure if the compiler will do this for you and you might get some performance considering this is a inner loop.


int32_t num_classes = multiclass_labels->get_num_classes();
int32_t num_feats = dense_data->get_num_features();

SGMatrix<float64_t> centroids(num_feats, num_classes);
SGVector<int64_t> num_per_class(num_classes);

linalg::zero(centroids);
linalg::zero(num_per_class);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't the constructor memset to 0?

@gf712
Copy link
Member

gf712 commented Jun 3, 2020

This is great! :)

auto multiclass_labels = m_labels->as<MulticlassLabels>();
auto dense_data = data->as<DenseFeatures<float64_t>>();
auto multiclass_labels = lab->as<MulticlassLabels>();
auto dense_data = feat->as<DenseFeatures<float64_t>>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these could be const right?

auto multiclass_labels = m_labels->as<MulticlassLabels>();
auto dense_data = data->as<DenseFeatures<float64_t>>();
auto multiclass_labels = lab->as<MulticlassLabels>();
auto dense_data = feat->as<DenseFeatures<float64_t>>();

int32_t num_classes = multiclass_labels->get_num_classes();
int32_t num_feats = dense_data->get_num_features();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const I think?

* @return whether training was successful
* @return The predicted Label for each sample in X .
*/
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);
std::shared_ptr<Labels> predict(std::shared_ptr<Features> X) override;

*/

virtual std::shared_ptr<Machine>
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab);
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab) override;

data = distance->get_lhs();
}
// ASSERT(distance)
require(distance, "distance not set");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick minor: Capital D

}
// ASSERT(distance)
require(distance, "distance not set");
require(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vigsterkr @gf712 could we pls move this check into the base class? It is repeated 100 times in all sorts of different styles ...

Copy link
Member

@gf712 gf712 Jun 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed, the labels and features base classes implement get_num_labels and get_num_vectors, so it makes sense to have the check in Machine

std::shared_ptr<Labels>
NearestCentroid::predict(std::shared_ptr<Features> data)
{
require(data, "features not set");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vigsterkr @gf712 while we are touching these things. Wouldnt it be better to do all these sort of input checks in the base class and then call a (virtual) method that can assume that

  • the pointers are not null
  • the number of examples are the same

so we don't have to have all this duplicate checks?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup agreed, like train calls train_machine, predict does the checks and predict_labels is pure virtual

{
SGMatrix<float64_t> X{{-10, -1}, {-2, -1}, {-3, -2},
{1, 1}, {2, 1}, {3, 2}};
SGVector<float64_t> y{1, 1, 1, 2, 2, 2};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGVector<float64_t> y{0,0,0,1,1,1};

@LiuYuHui LiuYuHui changed the base branch from develop to feature/machine_refactor June 16, 2020 00:47
@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from 2e6f793 to 7cfb209 Compare June 16, 2020 03:07
@karlnapf
Copy link
Member

might to rebase/force push

@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from 7cfb209 to ea3da01 Compare June 17, 2020 02:51
@LiuYuHui LiuYuHui changed the title Refactor NearestCentroid to be stateless Refactor NearestCentroid class Jun 17, 2020
@@ -34,25 +32,16 @@ namespace shogun{

void NearestCentroid::init()
{
m_shrinking=0;
m_is_trained=false;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

Copy link
Member

@karlnapf karlnapf left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@gf712 gf712 force-pushed the feature/machine_refactor branch 5 times, most recently from 9fae5e2 to b9d5bfa Compare June 18, 2020 10:09
@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from 5c70702 to cec247b Compare June 18, 2020 11:37
@@ -16,7 +16,7 @@ Distance d = create_distance("EuclideanDistance")

#![create_instance]
int k = 3
Machine knn = create_machine("KNN", k=k, distance=d, labels=labels_train)
KNN knn(k, d)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not Machine knn = create_machine("KNN", k=k, distance=d)?

@@ -552,7 +552,7 @@
"plt.figure(figsize=(15,5))\n",
"plt.subplot(121)\n",
"plt.title(\"Nearest Neighbors - Linear Features\")\n",
"plot_model(plt,knn_linear,feats_linear,labels_linear,fading=False)\n",
"plot_model(plt,knn_linear,feats_linear,labels_linear,fading=True)\n",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why True?

@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from ba12148 to af2a0fc Compare June 22, 2020 01:42
@gf712 gf712 force-pushed the feature/machine_refactor branch from 5a866f9 to 6d77d19 Compare July 6, 2020 08:23
@gf712 gf712 force-pushed the feature/machine_refactor branch from e2673fc to 97f1d5d Compare July 28, 2020 16:45
@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from af2a0fc to fb4e506 Compare July 29, 2020 09:02
@gf712 gf712 force-pushed the feature/machine_refactor branch 2 times, most recently from 72046cb to aa741be Compare July 29, 2020 13:43
@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from fb4e506 to 3bb0ffe Compare July 30, 2020 02:19
@gf712 gf712 force-pushed the feature/machine_refactor branch 3 times, most recently from aa741be to 7a37af4 Compare July 30, 2020 08:56
@gf712
Copy link
Member

gf712 commented Jul 30, 2020

I think you changed the data commit?

@LiuYuHui LiuYuHui force-pushed the refactor-class-stateless branch from c61b683 to aa414d5 Compare July 30, 2020 11:20
@gf712
Copy link
Member

gf712 commented Jul 31, 2020

Can merge this as the CI error seems to be unrelated

@gf712 gf712 merged commit afbdeac into shogun-toolbox:feature/machine_refactor Jul 31, 2020
gf712 pushed a commit that referenced this pull request Dec 8, 2020
* Add NonParametricMachine class (#5055)
* add nonparametric machine
* fix notebooks
* Refactor NearestCentroid class
gf712 pushed a commit that referenced this pull request Dec 8, 2020
* Add NonParametricMachine class (#5055)
* add nonparametric machine
* fix notebooks
* Refactor NearestCentroid class
gf712 pushed a commit that referenced this pull request Dec 8, 2020
* Add NonParametricMachine class (#5055)
* add nonparametric machine
* fix notebooks
* Refactor NearestCentroid class
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants