sklearn.ensemble.BaggingRegressor() #972

Closed
zhizhuaa opened this issue May 30, 2023 · 5 comments

Comments

@zhizhuaa

The combination of NeuralNetRegressor() and sklearn.ensemble.BaggingRegressor() does not run, but the combination of NeuralNetClassifier() and sklearn.ensemble.BaggingClassifier() does.

@BenjaminBossan
Collaborator

Could you give more details? Is it because you get the following error message?

ValueError: The target data shouldn't be 1-dimensional but instead have 2 dimensions ...

@zhizhuaa
Author

The code is as follows:

```python
from sklearn.datasets import make_regression
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetRegressor
from sklearn.ensemble import BaggingRegressor

# This is a toy dataset for regression, 1000 data points with 20 features each
X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)
X_regr = X_regr.astype(np.float32)
y_regr = y_regr.astype(np.float32) / 100
y_regr = y_regr.reshape(-1, 1)

class RegressorModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
    ):
        super(RegressorModule, self).__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(20, num_units)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = F.relu(self.dense1(X))
        X = self.output(X)
        return X

net_regr = NeuralNetRegressor(
    RegressorModule,
    max_epochs=20,
    lr=0.1,
    device='cuda',  # comment this out to train on the CPU
)

bagging = BaggingRegressor(estimator=net_regr, n_estimators=10, random_state=42)

bagging.fit(X_regr, y_regr)

# Making prediction for first 5 data points of X
y_pred = bagging.predict(X_regr[:5])
```

The brief error output is as follows:

```
raise ValueError(msg)
ValueError: The target data shouldn't be 1-dimensional but instead have 2 dimensions, with the second dimension having the same size as the number of regression targets (usually 1). Please reshape your target data to be 2-dimensional (e.g. y = y.reshape(-1, 1)).
```

@BenjaminBossan
Collaborator

Yes, I'm sorry about this annoyance; we should fix it in skorch.

Until then, here is a small hack that should resolve the issue:

```python
class MyNeuralNetRegressor(NeuralNetRegressor):
    def fit(self, X, y, **fit_params):
        if y.ndim == 1:
            y = y.reshape(-1, 1)
        return super().fit(X, y, **fit_params)
```

Replace NeuralNetRegressor in your code with MyNeuralNetRegressor and it should work. If it does, let's still leave this issue open as a reminder to fix the issue properly in skorch.
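
For illustration, here is roughly how the subclass slots into the script above (a sketch using the same data and module as before):

```python
net_regr = MyNeuralNetRegressor(
    RegressorModule,
    max_epochs=20,
    lr=0.1,
)

# BaggingRegressor squeezes y to 1d internally; the subclass reshapes it back
bagging = BaggingRegressor(estimator=net_regr, n_estimators=10, random_state=42)
bagging.fit(X_regr, y_regr)
y_pred = bagging.predict(X_regr[:5])
```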

@zhizhuaa
Author

Thank you very much. It worked.

BenjaminBossan added a commit that referenced this issue Jun 1, 2023
This change makes it possible to pass a 1-dimensional y to
`NeuralNetRegressor`.

Problem description

Right now, skorch requires the `y` passed to `NeuralNetRegressor.fit` to
be 2-dimensional, even if there is only one target, as is the most
common case. This problem has come up a few times in the past, but
mostly it's just an annoyance - just do `y.reshape(-1, 1)` and you're
good (the error message says as much).

There are, however, also cases where it's not so easy to solve. For
instance, in #972, a user reports that they cannot use skorch with
sklearn's `BaggingRegressor`. The problem is that even if `y` is
reshaped, once it is passed to the net from `BaggingRegressor`, it is 1d
again. I assume that `BaggingRegressor` internally squeezes `y` at some
point.
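
One plausible culprit (an assumption, not a traced call path) is sklearn's `column_or_1d` validation helper, which flattens a column vector back to 1d:

```python
import numpy as np
from sklearn.utils.validation import column_or_1d

y = np.zeros((1000, 1), dtype=np.float32)  # carefully reshaped to 2d ...
print(column_or_1d(y, warn=True).shape)    # ... and flattened back to (1000,)
```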

This PR lifts the 2d restriction check.

Initial motivation

Why does skorch require `y` to be 2d? I couldn't remember the initial
reasoning and did some archeology.

I found this comment:

(2f00e25#diff-66ed08bca4d171889565d0285a36b9b47e0e91e3b33d85c51352d8eb00faefac):

>         # The problem with 1-dim float y is that the pytorch DataLoader will
>         # somehow upcast it to DoubleTensor

This strange behavior should not be an issue anymore, so if that was the
only problem, we should be able to just remove the constraint, right?

Problems with removing the constraint

Unfortunately, it's not that easy. The issue comes down to the
following: When we remove the constraint and allow the target `y` to be
1d, but the prediction `y_pred` is still 2d, the criterion `nn.MSELoss`
will probably do the wrong thing. What exactly is wrong? Instead of
calculating the squared error for each matching sample pair, the criterion
will broadcast the vectors and calculate _all squared errors_ between
every pair of samples, then return the mean of that. To demonstrate, let's
remove the reduction step and look at the shape:

```python
>>> import torch
>>> criterion = torch.nn.MSELoss(reduction='none')
>>> y = torch.rand(100)
>>> y_pred = torch.rand((100, 1))
>>> y.shape, y_pred.shape
(torch.Size([100]), torch.Size([100, 1]))
>>> se = criterion(y_pred, y)
/home/vinh/anaconda3/envs/skorch/lib/python3.10/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([100])) that is different to the input size (torch.Size([100, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
>>> se.shape
torch.Size([100, 100])
```

As can be seen, PyTorch broadcasts the two tensors, leading to 100x100
errors being calculated. Thankfully, PyTorch warns about potential
issues with that.
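
For comparison, once the shapes match, the criterion computes one squared error per matching sample, as intended:

```python
>>> se = criterion(y_pred, y.reshape(-1, 1))
>>> se.shape
torch.Size([100, 1])
```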

The current solution is to accept this behavior and hope that the users
will indeed see the warning. If they don't see it or ignore it, it could
be a huge issue, because they still get a loss scalar and might even see
a small improvement in the loss during training. But the model will not
converge, and it's going to be a huge pain to debug, if the problem is even
identified as such.

Just to be clear, existing code, which uses 2d targets, will not be
affected by the change introduced in this PR and is still the preferred
way (IMO) to use regression in skorch.

Rejected solutions

I did consider the following solutions but rejected them.

Raising an error when shapes mismatch

This would remove the risk of users missing the warning. The problem
with this is that mismatching shapes can be okay in certain
circumstances. Some criteria don't expect target and prediction to have
the same shape, so we would need to check based on criterion. Moreover,
theoretically, users may indeed want to broadcast. Raising an error
would prevent that and users may have to resort to subclassing to
circumvent the error.
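
For instance, `nn.CrossEntropyLoss` with class-index targets expects exactly such a shape mismatch, so a blanket check would reject perfectly valid setups:

```python
>>> import torch
>>> criterion = torch.nn.CrossEntropyLoss()
>>> y_pred = torch.randn(8, 3)     # logits with shape (N, C)
>>> y = torch.randint(0, 3, (8,))  # class indices, 1d by design
>>> criterion(y_pred, y).shape
torch.Size([])
```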

Automatic reshaping

We could automatically add/remove dimensions if we see that they
mismatch. This has the same problems as the previous solution regarding
the dependence on the type of criterion. Furthermore, automatic
adjustment of the user's output is prone to run into issues in some edge
cases (e.g. when the broadcasting is actually desired).
ottonemo added a commit that referenced this issue Jun 26, 2023
* Allow regression with 1d targets


* Fix error when initializing BaggingRegressor

For Python 3.7, CI got:

TypeError: __init__() got an unexpected keyword argument 'estimator'

for BaggingRegressor. Probably it installs an older version of sklearn,
which uses a different argument name (base_estimator). Passing the
estimator as a positional arg should fix it.
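
Presumably the fix was along these lines (a sketch; older sklearn versions name the first parameter base_estimator, newer ones estimator, but passing it positionally works either way):

```python
# fails on older sklearn, where the parameter is called 'base_estimator':
# bagging = BaggingRegressor(estimator=net_regr, n_estimators=10)

# works regardless of the parameter's name:
bagging = BaggingRegressor(net_regr, n_estimators=10, random_state=42)
```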

* Reviewer comment: typo

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>

* Reviewer comment: typo

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>

---------

Co-authored-by: ottonemo <marian.tietz@ottogroup.com>
BenjaminBossan pushed a commit that referenced this issue Jun 26, 2023
Release text:

This release offers a new interface for scikit-learn to do zero-shot and
few-shot classification using open source large language models (Jump right into
the example notebook).

skorch.llm.ZeroShotClassifier and skorch.llm.FewShotClassifier allow the user to
do classification using open-source language models that are compatible with the
huggingface generation interface. This allows you to do all sorts of interesting
things in your pipelines, from simply plugging an LLM into your classification
pipeline to get preliminary results quickly, to using these classifiers to
generate training data candidates for downstream models. This is a first draft
of the interface, so it may well change in the future; please let us know about
any issues you run into.
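
A minimal sketch of the new interface (the model name is only an example and the exact call pattern may differ; see the skorch docs for the authoritative API):

```python
from skorch.llm import ZeroShotClassifier

X = ["A masterpiece, instant classic", "I was bored to tears"]

clf = ZeroShotClassifier("bigscience/bloomz-1b1")
clf.fit(X=None, y=["positive", "negative"])  # only candidate labels, no training data
y_pred = clf.predict(X)
```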

Other items of this release are

- the drop of Python 3.7 support - this version of Python has reached EOL and
  will not be supported anymore
- the NeptuneLogger now logs the skorch version thanks to AleksanderWWW
- NeuralNetRegressor can now be fitted with 1-dimensional y, which is necessary
  in some specific circumstances (e.g. in conjunction with sklearn's
  BaggingRegressor, see sklearn.ensemble.BaggingRegressor() #972); for this to
  work correctly, the output of the PyTorch module should also be
  1-dimensional; the existing default, i.e. having y and y_pred be
  2-dimensional, remains the recommended way of using NeuralNetRegressor
@BenjaminBossan
Collaborator

With the latest skorch release, this workaround should no longer be necessary, i.e. NeuralNetRegressor works with 1-dim y. Just make sure that the y_pred that is returned from the module is also 1-dim.
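
For completeness, a sketch of what the original example might look like with the new release (untested here; the key points are the 1-dim module output and the unreshaped y, reusing the imports and data from the script above):

```python
class RegressorModule1d(RegressorModule):
    def forward(self, X, **kwargs):
        # squeeze the trailing dimension so y_pred is 1-dim, matching y
        return super().forward(X, **kwargs).squeeze(-1)

net_regr = NeuralNetRegressor(RegressorModule1d, max_epochs=20, lr=0.1)
bagging = BaggingRegressor(net_regr, n_estimators=10, random_state=42)
bagging.fit(X_regr, y_regr.ravel())  # 1-dim y is now accepted
y_pred = bagging.predict(X_regr[:5])
```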
