Gpu memory leak with self.log on_epoch=True #4556
Comments
Same problem! However, I use

I'm logging just a python
For anyone having the same problem, I monkeypatched it like this to avoid setting the loss:
The validation loop has the same issue: CUDA tensors are stored in a list. They are detached, though, unlike in the non-detached train loop, so the overhead isn't big, but it's still there. This can be fixed by calling `loss.cpu()` before returning it in `validation_step`, or by not returning anything at all.
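A minimal sketch of that workaround, assuming a typical `validation_step` (module and metric names are illustrative, not from the original comment):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("val_loss", loss)
        # Move the (already detached) loss to CPU so the list of per-step
        # validation outputs does not hold GPU memory; alternatively,
        # return nothing at all.
        return loss.cpu()
```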
I also had this issue.
PyTorch Lightning 1.0.5

Using the new logging API, I want to log a metric in a LightningModule. This is a dummy example, but adding it to the LightningModule's `training_step` is sufficient to cause a memory leak on the GPU.
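A sketch of what such a dummy `training_step` could look like (module and metric names are illustrative, not from the original report):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # The logged value is a plain Python float, not a CUDA tensor,
        # yet on_epoch=True is enough to trigger the epoch aggregation path.
        self.log("my_metric", 0.5, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```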
What could go wrong? We want to log a metric which is not even a CUDA tensor. How could it lead to a GPU memory leak? Well, thanks to the magic of the metric epoch aggregation machinery.
Let's dig in and take a look here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L550-L569
Here we run the batch, convert `batch_output` to `epoch_end_outputs` if `on_epoch` was set, and append `epoch_end_outputs` to `epoch_output` inside `on_train_batch_end`. `epoch_output` is defined here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L540
Everything seems normal, but inside `batch_output` there is a surprise: the loss value, stored on the GPU. I think you can guess by now what could go wrong if we store a lot of separate CUDA tensors in a long, long `epoch_output`. Yeah, the GPU memory is going to run out and you'll get the famous CUDA out-of-memory error.
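A toy illustration of that chain of events, assuming a CUDA device is available (this mimics the accumulation pattern; it is not the actual Lightning code):

```python
import torch

# Toy loop (not the actual Lightning code) mimicking the accumulation above:
# each step's CUDA loss is appended to a list that lives for the whole epoch.
model = torch.nn.Linear(512, 512).cuda()
epoch_output = []  # stands in for Lightning's per-epoch list of step outputs

for step in range(200):
    x = torch.randn(64, 512, device="cuda")
    loss = model(x).pow(2).mean()        # scalar loss tensor stored on the GPU
    epoch_output.append({"loss": loss})  # non-detached: keeps its graph alive too
    if step % 50 == 0:
        mib = torch.cuda.memory_allocated() / 2**20
        print(f"step {step}: {mib:.1f} MiB allocated")
```

The printed allocation keeps climbing because every stored, non-detached loss keeps its inputs and intermediate activations alive.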
Where is the loss appended to the output? Here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L396-L427
In the first line we get a pretty `result` without the loss in it, and in line 414 the loss gets appended, and we start our memory-leak chain of events.

How does it affect training? It can lead to an error only during the first epoch. If you've got enough memory to hold a list of GPU losses during the 1st epoch, there won't be any exceptions, since subsequent epochs will have the same list of losses; if not, you'll get the error somewhere in the middle of the 1st epoch. And of course, the more steps you have in an epoch, the more memory this list of GPU losses will require, since one loss is stored per step.
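One simple way to observe this growth (an illustrative check, not part of the original report) is a callback that prints the allocated CUDA memory after every training batch:

```python
import torch
import pytorch_lightning as pl


class CudaMemoryMonitor(pl.Callback):
    """Print allocated CUDA memory after every training batch."""

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        if torch.cuda.is_available():
            mib = torch.cuda.memory_allocated() / 2**20
            print(f"step {trainer.global_step}: {mib:.1f} MiB allocated")
```

Attach it via `pl.Trainer(callbacks=[CudaMemoryMonitor()])`; if the number keeps climbing step after step throughout the first epoch, something is holding per-step GPU tensors alive.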
Here is the comparison for my task. My GPU could hold 2k steps before the memory error.

With `self.log`:

Without `self.log`:
You can see rapid growth in the first minute in both cases, as the model is loaded and fed the first batch. The difference is in the subsequent minutes: in the former case the list of losses eats 7 GB of GPU memory and leads to a crash, while in the latter nothing happens and training goes on.
Pretty cool how one `self.log` call can eat 2 times more GPU memory than the actual training process.