Refactor NearestCentroid class #5053
LiuYuHui commented on Jun 3, 2020
- Add fit/predict to Machine class
- Refactor NearestCentroid to be stateless
- Add NearestCentroid unittest
2c5a5b4
to
2e6f793
Compare
some food for thought
}

bool NearestCentroid::train_machine(std::shared_ptr<Features> data)
{
	error("the train_machine have been deprecated, please use fit instead");
uhoh :) the deprecation should mean that it's still functional :)
so either use a warning and stay functional (but i guess that's not possible because there are no labels), or, as in this case, i would just drop the whole thing :)
if Machine::train_machine is a pure virtual function, then maybe we should just create there an implementation with this error....
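The last suggestion could look roughly like the sketch below. It is only an illustration under stated assumptions: `MachineSketch`, `FeaturesSketch` and `NearestCentroidSketch` are hypothetical stand-ins rather than Shogun's real classes, and `std::runtime_error` stands in for Shogun's `error(...)`.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>

// Hypothetical stand-in; not Shogun's real Features class.
struct FeaturesSketch
{
};

// Hypothetical stand-in for the Machine base class.
class MachineSketch
{
public:
	virtual ~MachineSketch() = default;

	// Default implementation in the base class: machines that only
	// support fit() inherit this error instead of re-implementing it.
	virtual bool train_machine(std::shared_ptr<FeaturesSketch> /*data*/)
	{
		throw std::runtime_error(
		    "train_machine is not supported, please use fit instead");
	}
};

// A stateless learner then does not need to override train_machine at all.
class NearestCentroidSketch : public MachineSketch
{
};
```

With this shape the per-class `train_machine` override that only raises an error can be dropped entirely.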
src/shogun/machine/Machine.h
Outdated
 *
 * @return The predicted Label for each sample in X .
 */
virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X)
Suggested change:
- virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X)
+ virtual std::shared_ptr<Labels> predict(const std::shared_ptr<Features>& X)
?
predict should be const, right?
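A minimal sketch of what these two remarks amount to, assuming simplified stand-in types (`FeaturesSketch`, `LabelsSketch`, `MachineSketch` and `SignMachineSketch` are hypothetical, not Shogun classes): the argument is taken by const reference, so no reference-count increment happens on the call, and the method is marked `const`, so it can be called on a const model.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-ins for Features and Labels.
struct FeaturesSketch
{
	std::vector<double> values;
};

struct LabelsSketch
{
	std::vector<int> values;
};

class MachineSketch
{
public:
	virtual ~MachineSketch() = default;

	// const& parameter: the shared_ptr is not copied on the call.
	// Trailing const: predicting must not modify the trained model.
	virtual std::shared_ptr<LabelsSketch>
	predict(const std::shared_ptr<FeaturesSketch>& X) const = 0;
};

// Toy model: label 1 for non-negative values, 0 otherwise.
class SignMachineSketch : public MachineSketch
{
public:
	std::shared_ptr<LabelsSketch>
	predict(const std::shared_ptr<FeaturesSketch>& X) const override
	{
		auto out = std::make_shared<LabelsSketch>();
		for (const double v : X->values)
			out->values.push_back(v >= 0.0 ? 1 : 0);
		return out;
	}
};
```

Whether `predict` can really be const in the PR depends on whether implementations need to cache state while predicting; the sketch assumes they do not.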
- if(total>1)
-     scale[i] = 1.0/((float64_t)(total-1));
+ if (total > 0)
+     scale[i] = 1.0 / (float64_t)(total);
Suggested change:
- scale[i] = 1.0 / (float64_t)(total);
+ scale[i] = 1.0 / static_cast<float64_t>(total);
Doesn't change anything but it is more explicit
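For illustration, the scaling step with the explicit cast could be written as below. This is a standalone sketch: `class_scales` is a hypothetical helper, `double` is used in place of `float64_t` (assumed to be an alias for `double`), and classes with no samples are given a scale of 0, which is an assumption rather than something the diff states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: per-class scaling factor 1 / count.
// Classes without samples keep a scale of 0 (assumption).
std::vector<double> class_scales(const std::vector<std::int64_t>& counts)
{
	std::vector<double> scale(counts.size(), 0.0);
	for (std::size_t i = 0; i < counts.size(); ++i)
	{
		const std::int64_t total = counts[i];
		if (total > 0)
			scale[i] = 1.0 / static_cast<double>(total);
	}
	return scale;
}
```

The `static_cast` produces the same value as the C-style cast; it only makes the integer-to-floating-point conversion easier to spot and to search for.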
auto iter_labels = multiclass_labels->get_labels().begin();
std::map<int64_t, int64_t> label_to_index;
why not unordered_map and then reserve num_classes?
This label-to-index mapping is a problem that recurs in other places, so I'd not solve it here.
Actually, we have some mechanisms in place that convert labels internally (code is e.g. in multiclass labels).
@karlnapf didn't you start a PR that does this mapping?
yes, there is something merged even.
But we need to re-think/redesign parts of it as it had issues ...
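For reference, the `unordered_map` plus `reserve` idea raised at the start of this thread might look like the sketch below. `build_label_index` is a hypothetical helper written for illustration; it is not the existing Shogun label-conversion mechanism mentioned above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <vector>

// Hypothetical helper: map each distinct label to a dense index,
// assigned in order of first appearance.
std::unordered_map<std::int64_t, std::int64_t> build_label_index(
    const std::vector<std::int64_t>& labels, std::size_t num_classes)
{
	std::unordered_map<std::int64_t, std::int64_t> label_to_index;
	// Reserve up front so inserting num_classes keys does not rehash.
	label_to_index.reserve(num_classes);
	for (const std::int64_t label : labels)
	{
		const auto next = static_cast<std::int64_t>(label_to_index.size());
		// emplace leaves the map unchanged if the label is already present.
		label_to_index.emplace(label, next);
	}
	return label_to_index;
}
```

Compared with `std::map`, lookups are average constant time instead of logarithmic, at the cost of losing key ordering, which this mapping does not need.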
for (const auto& current : DotIterator(dense_data))
{
	const auto current_class = static_cast<int32_t>(*(iter_labels++));
	auto target = centroids.get_column(current_class);
	auto curr_index = index;
	if (label_to_index.find(current_class) == label_to_index.end())
You could rewrite this in order to reuse the iterator in the else branch. I am not sure if the compiler will do this for you, and you might gain some performance considering this is an inner loop.
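One way to read this suggestion: keep the iterator returned by `find` and reuse it, so the key is looked up once per sample on the common path. The sketch below is standalone; `to_dense_indices` is a hypothetical helper and not the PR's code.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical helper: dense class index for every sample.
std::vector<std::int64_t>
to_dense_indices(const std::vector<std::int64_t>& labels)
{
	std::map<std::int64_t, std::int64_t> label_to_index;
	std::vector<std::int64_t> out;
	out.reserve(labels.size());
	for (const std::int64_t label : labels)
	{
		// Single lookup; the iterator is reused below.
		auto it = label_to_index.find(label);
		if (it == label_to_index.end())
		{
			const auto next =
			    static_cast<std::int64_t>(label_to_index.size());
			// emplace returns an iterator to the new element.
			it = label_to_index.emplace(label, next).first;
		}
		// Known and newly inserted labels both read through `it`.
		out.push_back(it->second);
	}
	return out;
}
```

Only a label seen for the first time pays for a second tree traversal (the insertion); every later occurrence costs one `find`.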
int32_t num_classes = multiclass_labels->get_num_classes();
int32_t num_feats = dense_data->get_num_features();

SGMatrix<float64_t> centroids(num_feats, num_classes);
SGVector<int64_t> num_per_class(num_classes);

linalg::zero(centroids);
linalg::zero(num_per_class);
Doesn't the constructor memset to 0?
This is great! :)
- auto multiclass_labels = m_labels->as<MulticlassLabels>();
- auto dense_data = data->as<DenseFeatures<float64_t>>();
+ auto multiclass_labels = lab->as<MulticlassLabels>();
+ auto dense_data = feat->as<DenseFeatures<float64_t>>();
these could be const, right?
- auto multiclass_labels = m_labels->as<MulticlassLabels>();
- auto dense_data = data->as<DenseFeatures<float64_t>>();
+ auto multiclass_labels = lab->as<MulticlassLabels>();
+ auto dense_data = feat->as<DenseFeatures<float64_t>>();

  int32_t num_classes = multiclass_labels->get_num_classes();
  int32_t num_feats = dense_data->get_num_features();
const, I think?
-  * @return whether training was successful
+  * @return The predicted Label for each sample in X .
   */
  virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);
Suggested change:
- virtual std::shared_ptr<Labels> predict(std::shared_ptr<Features> X);
+ std::shared_ptr<Labels> predict(std::shared_ptr<Features> X) override;
 */

virtual std::shared_ptr<Machine>
fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab);
Suggested change:
- fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab);
+ fit(std::shared_ptr<Features> feat, std::shared_ptr<Labels> lab) override;
	data = distance->get_lhs();
}
// ASSERT(distance)
require(distance, "distance not set");
nitpick minor: Capital D
}
// ASSERT(distance)
require(distance, "distance not set");
require(
@vigsterkr @gf712 could we pls move this check into the base class? It is repeated 100 times in all sorts of different styles ...
agreed, the labels and features base classes implement get_num_labels and get_num_vectors, so it makes sense to have the check in Machine
std::shared_ptr<Labels>
NearestCentroid::predict(std::shared_ptr<Features> data)
{
	require(data, "features not set");
@vigsterkr @gf712 while we are touching these things: wouldn't it be better to do all these sorts of input checks in the base class and then call a (virtual) method that can assume that
- the pointers are not null
- the number of examples is the same
so we don't have to have all these duplicate checks?
yup agreed, like train calls train_machine, predict does the checks and predict_labels is pure virtual
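The arrangement discussed in this thread (public non-virtual entry point that validates, then a virtual hook that can assume valid input) could be sketched as follows. All names are hypothetical stand-ins: `fit_impl` is an invented hook name, `std::invalid_argument` replaces Shogun's `require(...)`, and only `fit` is shown; `predict` would follow the same shape with a pure virtual `predict_labels`.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for Features and Labels.
struct FeaturesSketch
{
	std::vector<double> values;
	std::size_t get_num_vectors() const { return values.size(); }
};

struct LabelsSketch
{
	std::vector<int> values;
	std::size_t get_num_labels() const { return values.size(); }
};

class MachineSketch
{
public:
	virtual ~MachineSketch() = default;

	// Non-virtual entry point: all input checks live here, once.
	void fit(
	    const std::shared_ptr<FeaturesSketch>& feat,
	    const std::shared_ptr<LabelsSketch>& lab)
	{
		if (!feat)
			throw std::invalid_argument("Features not set");
		if (!lab)
			throw std::invalid_argument("Labels not set");
		if (feat->get_num_vectors() != lab->get_num_labels())
			throw std::invalid_argument(
			    "Number of vectors and labels differ");
		fit_impl(*feat, *lab);
	}

protected:
	// Hook for subclasses; may assume non-null, equally sized inputs.
	virtual void
	fit_impl(const FeaturesSketch& feat, const LabelsSketch& lab) = 0;
};

// Toy subclass that only records how many samples it saw.
class CountingMachineSketch : public MachineSketch
{
public:
	std::size_t num_seen = 0;

protected:
	void fit_impl(const FeaturesSketch& feat, const LabelsSketch&) override
	{
		num_seen = feat.get_num_vectors();
	}
};
```

Subclasses then contain only algorithm code, and the wording of the error messages is decided in one place.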
{
	SGMatrix<float64_t> X{{-10, -1}, {-2, -1}, {-3, -2},
	                      {1, 1}, {2, 1}, {3, 2}};
	SGVector<float64_t> y{1, 1, 1, 2, 2, 2};
SGVector<float64_t> y{0,0,0,1,1,1};
2e6f793
to
7cfb209
Compare
might need to rebase/force push
7cfb209
to
ea3da01
Compare
@@ -34,25 +32,16 @@ namespace shogun{

void NearestCentroid::init()
{
	m_shrinking=0;
	m_is_trained=false;
not needed
lgtm
9fae5e2
to
b9d5bfa
Compare
5c70702
to
cec247b
Compare
@@ -16,7 +16,7 @@ Distance d = create_distance("EuclideanDistance")

  #![create_instance]
  int k = 3
- Machine knn = create_machine("KNN", k=k, distance=d, labels=labels_train)
+ KNN knn(k, d)
why not Machine knn = create_machine("KNN", k=k, distance=d)?
@@ -552,7 +552,7 @@
  "plt.figure(figsize=(15,5))\n",
  "plt.subplot(121)\n",
  "plt.title(\"Nearest Neighbors - Linear Features\")\n",
- "plot_model(plt,knn_linear,feats_linear,labels_linear,fading=False)\n",
+ "plot_model(plt,knn_linear,feats_linear,labels_linear,fading=True)\n",
Why True?
ba12148
to
af2a0fc
Compare
5a866f9
to
6d77d19
Compare
e2673fc
to
97f1d5d
Compare
af2a0fc
to
fb4e506
Compare
72046cb
to
aa741be
Compare
fb4e506
to
3bb0ffe
Compare
aa741be
to
7a37af4
Compare
I think you changed the data commit?
* add nonparametric machine
* fix notebooks
c61b683
to
aa414d5
Compare
Can merge this as the CI error seems to be unrelated
* Add NonParametricMachine class (#5055)
* add nonparametric machine
* fix notebooks
* Refactor NearestCentroid class