
performance loss from 1.0.8 to 1.1.* when using 16 bit precision #5159

Closed
immanuelweber opened this issue Dec 16, 2020 · 8 comments · Fixed by #5191
Labels: bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)
Milestone: 1.1.x

Comments

@immanuelweber

🐛 Bug

After updating pytorch-lightning from 1.0.8 to 1.1.0/1.1.1, using 16-bit precision destroys performance.
In my actual object-detection code, losses at the start of training are about 4x larger than with 32-bit precision, or with 16-bit precision on pl 1.0.8.
They converge to a much higher value, and the resulting model loses its detection capabilities completely.
To replicate this, I tested the pl notebooks: 06-cifar10-baseline.ipynb also shows the problem, and classification accuracy drops to random guessing when switching from 32 to 16 bit.
I integrated it into the BoringModel notebook, and the problem also occurs on Google Colab.

Please reproduce using the BoringModel and post here

https://colab.research.google.com/drive/1FqXG9Xw9gVZxnwiGnjsHpAtb-vUqFaob?usp=sharing

To Reproduce

Expected behavior

Same performance for 32 and 16 bit.
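In other words, a parity check like the sketch below (BoringModel and the dataloader are stand-ins for the model and data in the linked notebook) should produce near-identical losses for precision=32 and precision=16:

from pytorch_lightning import Trainer, seed_everything

def fit_with_precision(model_cls, dataloader, precision):
    # identical seed and data for both runs; only the precision flag changes
    seed_everything(42)
    model = model_cls()
    trainer = Trainer(gpus=1, max_epochs=1, precision=precision, weights_summary=None)
    trainer.fit(model, dataloader)
    return model

# e.g. model_32 = fit_with_precision(BoringModel, train_loader, precision=32)
#      model_16 = fit_with_precision(BoringModel, train_loader, precision=16)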

Environment

  • CUDA:
    • GPU:
      • Tesla P100-PCIE-16GB
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.18.5
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.1
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

@immanuelweber added the bug (Something isn't working) and help wanted (Open to be worked on) labels Dec 16, 2020
@Borda added the priority: 0 (High priority task) label Dec 16, 2020
@Borda added this to the 1.1.x milestone Dec 16, 2020
@SeanNaren
Contributor

I can verify I see this issue; setting enable_pl_optimizer=False in the Trainer seems to fix convergence. Investigating now!

You can use enable_pl_optimizer=False as a temporary hotfix in this case.
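For reference, a minimal sketch of the workaround (the other Trainer arguments are placeholders):

from pytorch_lightning import Trainer

# temporary workaround until the fix lands: bypass the LightningOptimizer
# wrapper so AMP gradients are only unscaled once
trainer = Trainer(
    gpus=1,
    precision=16,
    enable_pl_optimizer=False,
)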

@SeanNaren
Contributor

Here is a script to reproduce the underlying issue; the behaviour seems to differ when using the pl optimizer:

import os
import torch
from torch.utils.data import Dataset
from pytorch_lightning import Trainer, LightningModule, seed_everything


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        """
        Testing PL Module

        Use as follows:
        - subclass
        - modify the behavior for what you want

        class TestModel(BaseTestModel):
            def training_step(...):
                # do your own thing

        or:

        model = BaseTestModel()
        model.training_epoch_end = None

        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self.layer(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        print(loss)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


#  NOTE: If you are using a cmd line to run your script,
#  provide the cmd line as below.
#  opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
#  parser = ArgumentParser()
#  args = parser.parse_args(opt)

def run_test():
    seed_everything(42)

    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    before = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=1,
        limit_train_batches=4,
        limit_val_batches=0,
        weights_summary=None,
        gpus=1,
        precision=16,
    )
    trainer.fit(before, train_data, val_data)

    seed_everything(42)
    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    after = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        max_epochs=1,
        limit_train_batches=4,
        limit_val_batches=0,
        weights_summary=None,
        gpus=1,
        precision=16,
        enable_pl_optimizer=False
    )
    trainer.fit(after, train_data, val_data)

    # Under the bug this assertion fails: the run that used the LightningOptimizer
    # wrapper ends up with different weights than the run with it disabled.
    for p_before, p_after in zip(before.parameters(), after.parameters()):
        assert torch.equal(p_before, p_after), 'Model parameters are different'


if __name__ == '__main__':
    run_test()

We should expect the same trained model in both cases (you can confirm this by disabling the pl optimizer in the first run as well).
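If you prefer to inspect the divergence rather than hit the assertion, a small helper like this (not part of the script above, just a sketch) reports the largest parameter gap between the two runs:

import torch

def max_param_diff(model_a: torch.nn.Module, model_b: torch.nn.Module) -> float:
    # largest absolute element-wise difference across all parameter tensors
    return max(
        (a - b).abs().max().item()
        for a, b in zip(model_a.parameters(), model_b.parameters())
    )

# after the two fits in run_test():
# print(max_param_diff(before, after))  # ~0.0 expected; clearly nonzero under the bug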

@mees
Contributor

mees commented Dec 16, 2020

I can confirm this issue with my network too. My model does not converge with fp16 and PL 1.1.* and using enable_pl_optimizer=False fixes the issue.

@Borda
Member

Borda commented Dec 17, 2020

Shall we add a vanilla AMP loop for parity testing?
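One possible shape for such a parity baseline, as a sketch only (the model and data mirror the BoringModel above and are assumptions, not an existing test):

import torch
from torch.cuda.amp import GradScaler, autocast

def vanilla_amp_loop(steps: int = 4, device: str = "cuda"):
    torch.manual_seed(42)
    model = torch.nn.Linear(32, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = GradScaler()

    for _ in range(steps):
        batch = torch.randn(1, 32, device=device)
        optimizer.zero_grad()
        with autocast():
            out = model(batch)
            loss = torch.nn.functional.mse_loss(out, torch.ones_like(out))
        # scale the loss, backprop, then let the scaler unscale exactly once
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        print(loss.item())

The per-step losses from this loop could then be compared against a Lightning run with precision=16 on the same seed.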

@tchaton
Contributor

tchaton commented Dec 17, 2020

Hey @mees @egonuel,

We found the bug; a fix should be merged soon!

We apologise for the inconvenience.

Best regards,
T.C

@mees
Contributor

mees commented Dec 17, 2020

Thanks @tchaton @SeanNaren! What was the bug?

@tchaton
Contributor

tchaton commented Dec 17, 2020

Hey @mees,

Sneaky bug :) The gradients were being unscaled twice by the scaler. The scaler keeps per-optimizer state to track whether it still needs to unscale. However, we were calling _unscale on the LightningOptimizer wrapper and then performing step on the unwrapped optimizer. Since the "already unscaled" flag was set for the LightningOptimizer but not for the unwrapped optimizer, the gradients were unscaled a second time.

Simple fix: unscale on the unwrapped optimizer.
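
Roughly, the pattern looks like this (an illustrative, self-contained sketch only, not Lightning's actual code; OptimizerWrapper is a hypothetical stand-in for LightningOptimizer, and it needs a CUDA device):

import torch
from torch.cuda.amp import GradScaler, autocast

class OptimizerWrapper:
    # hypothetical stand-in for LightningOptimizer: a distinct Python object
    # that forwards attribute access to the real optimizer underneath
    def __init__(self, optimizer):
        self._optimizer = optimizer

    def __getattr__(self, name):
        return getattr(self._optimizer, name)

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
wrapped = OptimizerWrapper(optimizer)
scaler = GradScaler()

x = torch.randn(4, 32, device="cuda")
with autocast():
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, torch.ones_like(out))
scaler.scale(loss).backward()

# GradScaler keys its bookkeeping on id(optimizer), so unscaling through the
# wrapper only marks `wrapped` as unscaled ...
scaler.unscale_(wrapped)
grad_once = model.weight.grad.clone()

# ... while step() with the unwrapped optimizer finds no such record and
# unscales the same gradients a second time before stepping
scaler.step(optimizer)
scaler.update()
print((model.weight.grad / grad_once).mean().item())  # ≈ 1 / initial scale (65536)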

Best,
T.C

@siddk

siddk commented Dec 19, 2020

Has this been resolved?
