trainer.test() gives an error when using deepspeed and bf16. #16298

Closed
leao1995 opened this issue Jan 9, 2023 · 2 comments · Fixed by #16973
Labels: bug (Something isn't working), precision: amp (Automatic Mixed Precision), strategy: deepspeed

Comments

leao1995 commented Jan 9, 2023

Calling trainer.test() after trainer.fit() gives an error when using DeepSpeed and bf16:

ValueError: torch.float32 is enabled but the following parameters have dtype that is not torch.float32: [('module.model.weight', torch.bfloat16), ('module.model.bias', torch.bfloat16)]

deepspeed==0.6.5
pytorch-lightning==1.6.4

The following is a minimal example to reproduce the error:

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        output = self.model(batch)
        loss = torch.sum(output)
        return loss

    def test_step(self, batch, batch_idx):
        return self.model(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

train_data = DataLoader(RandomDataset(32, 64), num_workers=8, batch_size=8)
model = BoringModel()
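# DeepSpeed ZeRO stage 2 with bf16 enabled in the strategy config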
strategy = DeepSpeedStrategy(
    config = {
        "bf16": { "enabled": True },
        "prescale_gradients": False,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": False,
            "allgather_bucket_size": 5e8,
            "reduce_bucket_size": 5e8,
            "overlap_comm": True
        },
        "zero_allow_untested_optimizer": True
    }
)
trainer = Trainer(
    limit_train_batches=1,
    limit_test_batches=1,
    max_epochs=1,
    devices=1,
    accelerator="gpu",
    precision="bf16",
    strategy=strategy
)
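# Calling test() immediately after fit() reproduces the dtype-mismatch ValueError above.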
trainer.fit(model, train_data)
trainer.test(model, train_data)

Originally posted by @leao1995 in #16297

cc @awaelchli @carmocca @justusschock

Borda added the bug (Something isn't working) label Jan 9, 2023

Borda (Member) commented Jan 9, 2023

pytorch-lightning==1.6.4

Could you please validate whether this is still the case with the latest PL?

colehawkins (Contributor) commented

The source of this error appears to be a missed check. The current check covers only fp16 and should be extended to at least bf16, and possibly also bfloat16, which was an accepted key in earlier DeepSpeed versions: https://github.com/Lightning-AI/lightning/blob/c2b28a0e8cb3073d95b44bb94664a7ba22fc5d51/src/lightning/pytorch/strategies/deepspeed.py#L545

The current DeepSpeed version appears to support only the bf16 key: https://www.deepspeed.ai/docs/config-json/
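
For illustration only, here is a minimal sketch of the kind of extension described above, assuming the strategy copies the precision section of the training-time config into the config it builds for test/predict. forward_precision_config and its argument are hypothetical names, not the actual Lightning code; only the fp16/bf16/bfloat16 keys come from the DeepSpeed config schema.

from typing import Any, Dict

def forward_precision_config(train_config: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: copy the precision section of the training-time
    # DeepSpeed config into the config used for test/predict, so evaluation
    # keeps the same parameter dtype as training.
    inference_config: Dict[str, Any] = {}
    if "fp16" in train_config:
        # The only case the current check handles.
        inference_config["fp16"] = train_config["fp16"]
    elif "bf16" in train_config:
        # Missing piece: without this, DeepSpeed assumes fp32 at test time
        # and raises the dtype-mismatch ValueError shown above.
        inference_config["bf16"] = train_config["bf16"]
    elif "bfloat16" in train_config:
        # Legacy key accepted by older DeepSpeed releases; normalize to "bf16".
        inference_config["bf16"] = train_config["bfloat16"]
    return inference_config

With the config from the repro above this would yield {"bf16": {"enabled": True}}, keeping the bf16 parameters consistent at test time instead of expecting fp32.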
