Move the module to the precision dtype #12591
Comments
Thanks @yuvalkirstain! I investigated this as it definitely seemed strange, however I think there may be a misunderstanding of the behaviour. Using AMP, the model remains in FP32 but operations that can run in FP16 are auto-cast. The same applies to BF16 (see this example):

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        assert self.layer.weight.dtype == torch.float32  # the weights are in FP32!
        x = self.layer(x)
        assert x.dtype == torch.bfloat16  # output is bfloat16 as the operation is bf16-compatible
        return x

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        max_epochs=1,
        enable_model_summary=False,
        precision="bf16",
        gpus=1,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()
```

I then checked FairSeq, which converts the entire model into BF16 even in AMP mode: https://github.com/pytorch/fairseq/blob/main/fairseq/trainer.py#L105-L106 This might be why you see the memory improvements. If this is something you'd like to do, you can simply do the same in your LightningModule:

```python
class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2).bfloat16()
```

Let me know if this helps! We might need to improve our documentation to explain this case. |
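For comparison, here is a minimal sketch (assuming the BoringModel and RandomDataset definitions from the example above) of the FairSeq-style approach, where the whole module is converted up front with the standard torch.nn.Module.bfloat16() call rather than converting a single layer:

```python
# Sketch: convert the entire module to bf16 before fitting (FairSeq-style),
# instead of converting one layer inside __init__.
# Note: the float32 assert in the example forward() above would need to be
# removed for this variant, since the weights are now stored in bf16.
model = BoringModel().bfloat16()
assert all(p.dtype == torch.bfloat16 for p in model.parameters())

trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=1,
    max_epochs=1,
    enable_model_summary=False,
    precision="bf16",
    gpus=1,
)
trainer.fit(model, train_dataloaders=DataLoader(RandomDataset(32, 64), batch_size=2))
```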
Thank you so much for the detailed answer! I will try it out and update :) |
Nope, thank you @yuvalkirstain! Let me know how it goes! You also bring up a great point, which is whether we should support converting the pl_module internally for users. I think it is a potentially good idea; however, I wonder what the API for this would be. |
This would be just a flag on the precision plugin (name up for debate):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

Trainer(plugins=NativeMixedPrecisionPlugin(convert_modules=True))
```

Another question is whether this should be enabled by default. As per the test added by @rohitgr7 in https://github.com/PyTorchLightning/pytorch-lightning/pull/12508/files#diff-3e387bfea0892d3f7033341769075f9e8b17dbe2fbe4a44a68e3d450bdba2e0eR1223, the layers are also moved when using DeepSpeed |
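Until something like that flag exists, a rough user-side sketch of keeping the cast inside the Trainer configuration would be a small callback (the CastToBF16 name here is hypothetical, not an existing Lightning API):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class CastToBF16(Callback):  # hypothetical helper, not part of Lightning
    def on_fit_start(self, trainer, pl_module):
        # torch.nn.Module.bfloat16() casts all floating-point parameters and buffers.
        pl_module.bfloat16()


trainer = Trainer(precision="bf16", callbacks=[CastToBF16()])
```

Converting the model before calling trainer.fit (as in the earlier snippet) achieves the same result; the callback merely keeps the conversion alongside the Trainer setup.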
@SeanNaren Yes, doing so results in less memory on the GPU with identical results, thank you!

T5-XL (3B parameters), inference on the SQuAD dataset:
- 8058MiB / 24268MiB (with model = model.bfloat16())
- 13196MiB / 24268MiB (without)

Regarding converting the pl_module internally for users: definitely, it makes more sense IMO that the trainer takes care of that rather than the model. |
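For anyone wanting to reproduce such a comparison programmatically: the numbers above are nvidia-smi readings, but a rough sketch using torch.cuda.max_memory_allocated (which only tracks PyTorch's own allocations, so absolute values will be lower) could look like this:

```python
import torch


def peak_memory_mib(model, batch):
    """Sketch: peak allocator memory (MiB) for one no-grad forward pass under bf16 autocast."""
    torch.cuda.reset_peak_memory_stats()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad():
        model(batch)
    return torch.cuda.max_memory_allocated() / 2**20


# Hypothetical usage, with `model` and `batch` already on the GPU:
# print(peak_memory_mib(model, batch))             # weights kept in fp32, bf16 autocast
# print(peak_memory_mib(model.bfloat16(), batch))  # weights converted to bf16
```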
@yuvalkirstain I'm glad it worked! Hopefully we'll get the feature in soon :) |
I think yes because:
issues:
but if there are no other side effects, I think enabling it by default should be good. |
Just to be clear, we're talking about just BF16 precision? DeepSpeed's handling of AMP is very different and should be treated as such. |
For DeepSpeed in particular, we let it make the choice; at the moment that means moving the module too. But I was thinking this should apply to all precision values. |
🐛 Bug
@SeanNaren
When I use bf-16 and check the dtype of the model, it seems like the model's precision is fp32 (and I do not see the memory gains I expect). On other frameworks that support bf-16 (like fairseq) the model's dtype is torch.bfloat16. Is there a simple example that "proves" that this feature reduces the memory consumption as it should? I suspect that there might be something wrong (but of course, I might be wrong). Thank you!
To Reproduce
Launch any job with precision=bf16 and compare with precision=32.
Expected behavior
This feature should save 30-50% memory but I do not see such gains in lightning.
Environment
Additional context
BF-16 is a very important feature. It is usually more stable than fp16 and lightning should support it effectively (models that are pretrained with bf-16 should not be used with fp-16) :)
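To make the stability point concrete, a small illustrative check (not from the original report): bf16 shares fp32's exponent range, while fp16 overflows above 65504, which is why bf16-pretrained weights often break when run in fp16:

```python
import torch

# bf16 keeps fp32's exponent range (with fewer mantissa bits), so activations
# that fit in fp32 also fit in bf16; fp16's max of 65504 is where overflows start.
print(torch.finfo(torch.bfloat16).max)  # ~3.39e+38
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.float32).max)   # ~3.40e+38
```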
cc @Borda @tchaton @rohitgr7 @carmocca @justusschock @awaelchli @akihironitta