sklearn.ensemble.BaggingRegressor() #972
Could you give more details? Is it because you get the following error message?
[quoted error: skorch's ValueError stating that y should be 2-dimensional, e.g. y = y.reshape(-1, 1)]
The code is as follows.

```python
from sklearn.datasets import make_regression
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetRegressor
from sklearn.ensemble import BaggingRegressor

# This is a toy dataset for regression, 1000 data points with 20 features each
X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)
X_regr = X_regr.astype(np.float32)
y_regr = y_regr.astype(np.float32) / 100
y_regr = y_regr.reshape(-1, 1)

class RegressorModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
    ):
        super(RegressorModule, self).__init__()
        self.num_units = num_units
        self.nonlin = nonlin
        self.dense0 = nn.Linear(20, num_units)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = F.relu(self.dense1(X))
        X = self.output(X)
        return X

net_regr = NeuralNetRegressor(
    RegressorModule,
    max_epochs=20,
    lr=0.1,
    device='cuda',  # remove this line to train on the CPU
)

bagging = BaggingRegressor(estimator=net_regr, n_estimators=10, random_state=42)
bagging.fit(X_regr, y_regr)

# Making predictions for the first 5 data points of X
y_pred = bagging.predict(X_regr[:5])
```

The brief run error is as follows.
[traceback ending in skorch's ValueError: y should be 2-dimensional, e.g. y = y.reshape(-1, 1)]
Yes, I'm sorry about this annoyance, we should fix it in skorch. Until then, here is a small hack that should resolve the issue:

```python
class MyNeuralNetRegressor(NeuralNetRegressor):
    def fit(self, X, y, **fit_params):
        # BaggingRegressor passes y as 1d; reshape it back to 2d
        # before skorch's dimensionality check runs
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        return super().fit(X, y, **fit_params)
```

Replace `NeuralNetRegressor` with `MyNeuralNetRegressor` in your code.
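For instance, reusing the names from the snippet above, the subclass drops in unchanged (a sketch, assuming the same data and module):

```python
net_regr = MyNeuralNetRegressor(
    RegressorModule,
    max_epochs=20,
    lr=0.1,
)

# BaggingRegressor may squeeze y to 1d internally; the subclass
# restores the 2d shape that skorch expects.
bagging = BaggingRegressor(estimator=net_regr, n_estimators=10, random_state=42)
bagging.fit(X_regr, y_regr)
```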
Thank you very much. It worked.
This change makes it possible to pass a 1-dimensional y to `NeuralNetRegressor`.

**Problem description**

Right now, skorch requires the `y` passed to `NeuralNetRegressor.fit` to be 2-dimensional, even if there is only one target, as is the most common case. This problem has come up a few times in the past, but mostly it's just an annoyance - just do `y.reshape(-1, 1)` and you're good (the error message says as much). There are, however, also cases where it's not so easy to solve. For instance, in #972, a user reports that they cannot use skorch with sklearn's `BaggingRegressor`. The problem is that even if `y` is reshaped, once it is passed to the net from `BaggingRegressor`, it is 1d again. I assume that `BaggingRegressor` internally squeezes `y` at some point. This PR lifts the 2d restriction check.

**Initial motivation**

Why does skorch require `y` to be 2d? I couldn't remember the initial reasoning and did some archeology. I found this comment (2f00e25#diff-66ed08bca4d171889565d0285a36b9b47e0e91e3b33d85c51352d8eb00faefac):

> The problem with 1-dim float y is that the pytorch DataLoader will somehow upcast it to DoubleTensor

This strange behavior should not be an issue anymore, so if that was the only problem, we should be able to just remove the constraint, right?

**Problems with removing the constraint**

Unfortunately, it's not that easy. The issue comes down to the following: when we remove the constraint and allow the target `y` to be 1d while the prediction `y_pred` is still 2d, the criterion `nn.MSELoss` will probably do the wrong thing. What exactly is wrong? Instead of calculating the squared error for each sample pair, the criterion will broadcast the vectors and calculate _all squared errors_ between each pair of samples, then return the mean of that. To demonstrate, let's remove the reduction step and look at the shape:

```python
>>> import torch
>>> criterion = torch.nn.MSELoss(reduction='none')
>>> y = torch.rand(100)
>>> y_pred = torch.rand((100, 1))
>>> y.shape, y_pred.shape
(torch.Size([100]), torch.Size([100, 1]))
>>> se = criterion(y_pred, y)
/home/vinh/anaconda3/envs/skorch/lib/python3.10/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([100])) that is different to the input size (torch.Size([100, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
>>> se.shape
torch.Size([100, 100])
```

As can be seen, PyTorch broadcasts the two arrays, leading to 100x100 errors being calculated. Thankfully, PyTorch warns about potential issues with that. The current solution is to accept this behavior and hope that users will indeed see the warning. If they don't see it or ignore it, it could be a huge issue, because they still get a loss scalar and might even see a small improvement in the loss during training. But the model will not converge, and it's going to be a huge pain to debug, if the bug is even identified as such. Just to be clear, existing code, which uses 2d targets, will not be affected by the change introduced in this PR, and it remains the preferred way (IMO) to use regression in skorch.

**Rejected solutions**

I did consider the following solutions but rejected them.

*Raising an error when shapes mismatch.* This would remove the risk of users missing the warning. The problem with this is that mismatching shapes can be okay in certain circumstances. Some criteria don't expect target and prediction to have the same shape, so we would need to check based on the criterion. Moreover, theoretically, users may indeed want to broadcast. Raising an error would prevent that, and users might have to resort to subclassing to circumvent the error.

*Automatic reshaping.* We could automatically add/remove dimensions if we see that they mismatch. This has the same problems as the previous solution regarding the dependence on the type of criterion. Furthermore, automatic adjustment of the user's output is prone to run into issues in some edge cases (e.g. when the broadcasting is actually desired).
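As an aside (my own addition, not part of the quoted PR text): the fix the warning asks for is simply to align the two shapes, which restores one squared error per sample, continuing the session above:

```python
>>> # Matching shapes give one squared error per sample, as intended
>>> se = criterion(y_pred, y.reshape(-1, 1))
>>> se.shape
torch.Size([100, 1])
```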
* Allow regression with 1d targets (see the PR description above)
* Fix error when initializing BaggingRegressor: for Python 3.7, CI got `TypeError: __init__() got an unexpected keyword argument 'estimator'` for BaggingRegressor. Probably it installs an older version of sklearn, which uses a different argument name. Passing the estimator as a positional arg should fix it.
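For illustration (not from the commit message): older sklearn releases name this argument `base_estimator`, while newer ones (1.2+) name it `estimator`, so passing the estimator positionally works with both. A minimal sketch, reusing `net_regr` from above:

```python
from sklearn.ensemble import BaggingRegressor

# The first positional parameter is the wrapped estimator under both
# names ('base_estimator' in sklearn < 1.2, 'estimator' afterwards),
# so this avoids the TypeError on older versions.
bagging = BaggingRegressor(net_regr, n_estimators=10, random_state=42)
```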
Release text: This release offers a new interface for scikit-learn to do zero-shot and few-shot classification using open source large language models (jump right into the example notebook). skorch.llm.ZeroShotClassifier and skorch.llm.FewShotClassifier allow the user to do classification using open-source language models that are compatible with the huggingface generation interface. This allows you to do all sorts of interesting things in your pipelines, from simply plugging an LLM into your classification pipeline to get preliminary results quickly, to using these classifiers to generate training data candidates for downstream models. This is a first draft of the interface, therefore it is not unlikely that the interface will change a bit in the future, so please let us know about any potential issues you have.

Other items of this release are:

- the drop of Python 3.7 support - this version of Python has reached EOL and will not be supported anymore
- the NeptuneLogger now logs the skorch version, thanks to AleksanderWWW
- NeuralNetRegressor can now be fitted with 1-dimensional y, which is necessary in some specific circumstances (e.g. in conjunction with sklearn's BaggingRegressor, see sklearn.ensemble.BaggingRegressor() #972); for this to work correctly, the output of the PyTorch module should also be 1-dimensional; the existing default, i.e. having y and y_pred be 2-dimensional, remains the recommended way of using NeuralNetRegressor
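A minimal sketch of what the last item enables (my own example, not from the release notes; note the `squeeze(-1)` that makes the module's output 1-dimensional to match the 1d y):

```python
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetRegressor
from sklearn.datasets import make_regression
from sklearn.ensemble import BaggingRegressor

class Regressor1d(nn.Module):
    def __init__(self, num_units=10):
        super().__init__()
        self.dense = nn.Linear(20, num_units)
        self.output = nn.Linear(num_units, 1)

    def forward(self, X):
        # Squeeze the last dimension so y_pred is 1d, matching the 1d y
        return self.output(F.relu(self.dense(X))).squeeze(-1)

X, y = make_regression(1000, 20, n_informative=10, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.float32) / 100  # y stays 1-dimensional, no reshape(-1, 1)

net = NeuralNetRegressor(Regressor1d, max_epochs=20, lr=0.1)
bagging = BaggingRegressor(net, n_estimators=10, random_state=42)
bagging.fit(X, y)
```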
With the latest skorch release, this workaround should no longer be necessary, i.e. NeuralNetRegressor can be fitted with a 1-dimensional y directly (the module's output should then also be 1-dimensional).
The combination of NeuralNetRegressor() and sklearn.ensemble.BaggingRegressor() cannot run, but the combination of NeuralNetClassifier() and sklearn.ensemble.BaggingClassifier() can.