
"element 0 of tensors does not require grad and does not have a grad_fn" when using AdamW from Hugging Face #18254

Closed
0x404 opened this issue Aug 8, 2023 · 14 comments · Fixed by #18268
Labels: 3rd party · bug · optimization · ver: 2.0.x

@0x404
Contributor

0x404 commented Aug 8, 2023

Bug description

When attempting to migrate my current model to Lightning, I encountered an error while using the AdamW optimizer provided by Hugging Face's Transformers library during training: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn."

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, length):
        self.len = length
        self.data = torch.randn(length, 31)
        self.label = torch.zeros(length, dtype=torch.long)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(31, 2)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        input = batch[0]
        label = batch[1]
        out = self.layer(input)
        loss = self.loss_fn(out, label)
        return loss

    def configure_optimizers(self):
        return AdamW(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32), batch_size=2)
    model = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Error messages and logs

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Environment

lightning==2.0.6
transformers==4.31.0

More info

Upon further analysis of Lightning's source code, I found that the issue stemmed from the use of the @torch.no_grad() decorator in the step function of AdamW provided by Hugging Face's Transformers library.

Source code of AdamW provided by Hugging Face's Transformers library:

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                state = self.state[p]

...

Lightning wraps training_step in a closure, and that closure is only executed inside optimizer.step. With the Transformers AdamW, step runs under @torch.no_grad(), so model.training_step and the loss computation happen with gradient tracking disabled. As a result, loss.requires_grad is False and the subsequent backward fails.
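
For illustration, the same failure can be reproduced with plain PyTorch (no Lightning involved): computing the loss inside a no_grad context leaves it without a grad_fn, so a later backward() raises exactly this error.

import torch

layer = torch.nn.Linear(31, 2)
x = torch.randn(4, 31)
y = torch.zeros(4, dtype=torch.long)

# Same effect as the @torch.no_grad() decorator on AdamW.step()
with torch.no_grad():
    loss = torch.nn.functional.cross_entropy(layer(x), y)

print(loss.requires_grad)  # False: no grad_fn was recorded
loss.backward()  # RuntimeError: element 0 of tensors does not require grad ...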

To address this problem, I made the following temporary modifications to precision_plugin.py:

def _temp_fix(self, closure, model, optimizer):
    # Run the Lightning closure (training_step, zero_grad, backward) here,
    # outside the optimizer's no_grad context, and hand the optimizer a
    # trivial closure that just returns the precomputed result.
    closure_result = closure()
    self._after_closure(model, optimizer)

    def wrap_closure():
        return closure_result

    return wrap_closure

def optimizer_step(  # type: ignore[override]
    self,
    optimizer: Steppable,
    model: "pl.LightningModule",
    closure: Callable[[], Any],
    **kwargs: Any,
) -> Any:
    """Hook to run the optimizer step."""
    # closure = partial(self._wrap_closure, model, optimizer, closure)
    closure = self._temp_fix(closure, model, optimizer)
    return optimizer.step(closure=closure, **kwargs)

This modification ensures that the internal training_step and related functions are executed before passing the actual closure to the optimizer. The closure result is then wrapped in a simple callable for the optimizer, allowing the optimizer to access closure_result.

I'm not certain whether this is the correct fix. I am new to the Lightning community and find Lightning very convenient. I would like to contribute: would it be possible to receive some guidance and open a pull request to fix this bug?

Thanks in advance!

cc @Borda

0x404 added the bug and needs triage labels on Aug 8, 2023
@0x404
Contributor Author

0x404 commented Aug 8, 2023

Maybe related to #18222.

@awaelchli
Contributor

Thank you so much for the analysis @0x404, this is very helpful and saves us a lot of time.

Your workaround is also good and works in your case, but doesn't support optimizers that require closures. And since it is a feature in Lightning, we can't just remove it.

Here is another proposal that might work. We could explicitly enable grad when entering our closure, overriding the no_grad() context set by the optimizer (only while the closure runs training_step, of course), so the rest of the HF optimizer code still benefits from the no_grad context. This makes sense to me, but needs to be validated.
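
A rough sketch of that idea (hypothetical, not an actual patch): the precision plugin's closure wrapper would force grad on just for the closure body, while the rest of the optimizer's step() keeps its no_grad context.

import torch

def _wrap_closure_sketch(model, optimizer, closure):
    # Hypothetical sketch: run the Lightning closure (training_step,
    # zero_grad, backward) with gradients enabled, even though the
    # optimizer's step() was entered under @torch.no_grad().
    with torch.enable_grad():
        closure_result = closure()
    # hooks such as on_before_optimizer_step would run here
    return closure_result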

awaelchli added the optimization and 3rd party labels and removed the needs triage label on Aug 8, 2023
awaelchli self-assigned this on Aug 8, 2023
awaelchli added this to the 2.0.x milestone on Aug 8, 2023
@0x404
Contributor Author

0x404 commented Aug 8, 2023

Thank you @awaelchli, I validated this approach and made the following modification in precision_plugin.py:

def _wrap_closure(
    self,
    model: "pl.LightningModule",
    optimizer: Optimizer,
    closure: Callable[[], Any],
) -> Any:
    """This double-closure makes sure the ``closure`` is executed before the
    ``on_before_optimizer_step`` hook is called.

    The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
    consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
    """
    # Remember whether grad mode was enabled before the optimizer's step()
    # (e.g. HF's AdamW disables it via @torch.no_grad()), then force it on
    # while the closure runs training_step, zero_grad, and backward.
    _grad_was_enabled = torch.is_grad_enabled()
    torch.set_grad_enabled(True)

    closure_result = closure()
    self._after_closure(model, optimizer)

    if not _grad_was_enabled:
        torch.set_grad_enabled(False)
    return closure_result

def optimizer_step(  # type: ignore[override]
    self,
    optimizer: Steppable,
    model: "pl.LightningModule",
    closure: Callable[[], Any],
    **kwargs: Any,
) -> Any:
    """Hook to run the optimizer step."""
    closure = partial(self._wrap_closure, model, optimizer, closure)
    return optimizer.step(closure=closure, **kwargs)

This seems to work fine.

@0x404
Contributor Author

0x404 commented Aug 8, 2023

Explicitly enabling grad only for training_step would lead to the same error ("element 0 of tensors does not require grad and does not have a grad_fn"), since backward also needs grad. Therefore, I explicitly enable grad for the whole closure, which consists of training_step, backward, and optimizer_zero_grad.
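
For context, the closure Lightning hands to optimizer.step() roughly performs these three steps, all of which need gradient tracking enabled (a simplified sketch with hypothetical names, not Lightning's actual implementation):

def lightning_like_closure(model, optimizer, batch, batch_idx):
    # Simplified view of Lightning's automatic-optimization closure:
    # training_step, zero_grad, and backward all run inside it.
    loss = model.training_step(batch, batch_idx)
    optimizer.zero_grad()
    loss.backward()
    return loss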

@awaelchli
Contributor

Basically, at the beginning of this closure, set grad enabled. What do you think?

@0x404
Contributor Author

0x404 commented Aug 8, 2023

@awaelchli Exactly! Would it be possible for me to submit a PR to address this? I'm relatively new to Lightning, so it might take me a day or two to become acquainted with the Lightning workflow. I am very interested in Lightning, so I think this is a good first issue for me.

@awaelchli
Contributor

Definitely, please give it a try. This is much appreciated, and I'm happy to help or answer questions.
We still need to validate that such a change doesn't have any unintended side effects, so I suggest submitting a PR with the change; then we can let the test suite run.

@0x404
Contributor Author

0x404 commented Aug 9, 2023

Hi @awaelchli, I encountered a few issues while attempting this:

  1. I made modifications to the corresponding code in src/pytorch_lightning/, but this directory seems to be ignored by git. How can I commit my changes to the repository?
  2. After completing my code changes, how should I build Lightning for testing purposes? (Currently, I'm using make test, but the generated package doesn't seem to include my modifications.)

Could you please point me to any relevant documentation that could assist me? Thank you.

@awaelchli
Contributor

You need to make the modifications under src/lightning/pytorch. Ignore src/pytorch_lightning/; it is only there to generate the pytorch_lightning package.

> After completing my code changes, how should I build Lightning for testing purposes? (Currently, I'm using make test, but the generated package doesn't seem to include my modifications.)

You can, but there are many tests and you won't be able to run all of them (which is what make test attempts to do). For a simple change like yours, I suggest just running pip install -r requirements/pytorch/test.txt and then running individual test files like so:

py.test -v tests/tests_pytorch/.../path/to/test_file.py

But before doing that, if I were you I would just submit the PR first; then the CI can run through the test suite once and we can see the output. If some tests fail, that's the time to investigate. Let me know if that works.

@0x404
Contributor Author

0x404 commented Aug 10, 2023

Thanks @awaelchli, I have submitted a PR: #18268.

@carmocca
Contributor

Shouldn't this be fixed in the optimizer definition, since it forgot to take the closure into consideration?
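
For comparison, PyTorch's built-in optimizers also disable grad in step() but re-enable it around the closure call, so a fix on the Transformers side could follow the same pattern (a sketch of that pattern, not the exact torch.optim source):

import torch

@torch.no_grad()
def step(self, closure=None):
    loss = None
    if closure is not None:
        # Re-enable grad just for the closure, which may run the
        # forward/backward pass.
        with torch.enable_grad():
            loss = closure()
    # ... parameter updates continue under no_grad ...
    return loss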

@0x404
Contributor Author

0x404 commented Aug 10, 2023

Hi @carmocca. I feel the same way, but I think handling gradient mode for the closure in Lightning is also a viable approach. It makes Lightning more compatible with various third-party optimizers (including those that forgot to take the closure into consideration). What do you think?

@carmocca
Contributor

I agree. It's okay to add the explicit fix here too.
