Gpu memory leak with self.log on_epoch=True #4556

Closed
Vozf opened this issue Nov 6, 2020 · 7 comments · Fixed by #4592
Labels
bug Something isn't working help wanted Open to be worked on logger Related to the Loggers priority: 0 High priority task

Comments

@Vozf
Contributor

Vozf commented Nov 6, 2020

PyTorch Lightning 1.0.5

Using the new logging API, I want to log a metric in a LightningModule:

    self.log(";;;;;;;;;;;;;;;;;;;", 1, on_step=False, on_epoch=True)

This is a dummy example, but adding it to a LightningModule's training_step is enough to cause a GPU memory leak.
What could go wrong? The metric we log is not even a CUDA tensor, so how could it lead to a GPU memory leak?
Well, thanks to the magic of the epoch-level metric aggregation.
Let's dig in and take a look here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L550-L569

Here we run the batch, convert batch_output to epoch_end_outputs if on_epoch was set, and append epoch_end_outputs to epoch_output inside on_train_batch_end.
epoch_output is defined here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L540

Everything seems normal, but inside batch_output there is a surprise: the loss value, stored on the GPU.
I think you can guess by now what goes wrong if we store lots of separate CUDA tensors in a very long epoch_output list.
Yes, the GPU runs out of memory and you get the famous

RuntimeError: CUDA out of memory. Tried to allocate 114.00 MiB (GPU 1; 10.92 GiB total capacity; 9.39 GiB already allocated; 27.38 MiB free; 10.24 GiB reserved in total by PyTorch)

Where is the loss appended to the output? Here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/b3db197b43667ccf0f67a4d0d8093fc866080637/pytorch_lightning/trainer/training_loop.py#L396-L427

The first line of that block produces a clean result without the loss in it; then, on line 414, the loss gets appended, and the chain of events behind the memory leak begins.

How does this affect training? It can only cause an error during the first epoch. If you have enough memory to hold the list of GPU losses for the first epoch, there won't be any exceptions, since subsequent epochs build a list of the same length; if not, you'll hit the error somewhere in the middle of the first epoch. And, of course, the more steps an epoch has, the more memory this list of GPU losses requires, since one loss is stored per step.
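To make the per-epoch reasoning concrete, here is a toy model in plain Python, with no Lightning or PyTorch involved (FakeLoss and run_epochs are hypothetical names). Each step appends an output holding its loss to a per-epoch list, mirroring the epoch_output accumulation described above, and a weakref.WeakSet counts how many losses are still alive. It assumes CPython, whose reference counting frees an object as soon as its last reference disappears.

```python
import weakref


class FakeLoss:
    """Hypothetical stand-in for a per-step loss tensor living on the GPU."""

    def __init__(self, value):
        self.value = value


def run_epochs(num_epochs, steps_per_epoch, keep_loss=True):
    """Return, for each epoch, the peak number of FakeLoss objects alive at once."""
    live = weakref.WeakSet()  # holds weak refs only, so it never keeps a loss alive
    peaks = []
    for _ in range(num_epochs):
        epoch_output = []  # fresh list each epoch; the previous one is released
        epoch_peak = 0
        for step in range(steps_per_epoch):
            loss = FakeLoss(float(step))
            live.add(loss)
            # keep_loss=True mimics storing the loss in the per-step output;
            # keep_loss=False mimics not storing it at all
            batch_output = {"loss": loss if keep_loss else None}
            epoch_output.append(batch_output)
            del loss  # drop the loop's own reference
            epoch_peak = max(epoch_peak, len(live))
        peaks.append(epoch_peak)
    return peaks


print(run_epochs(3, 5))                   # [5, 5, 5]
print(run_epochs(3, 5, keep_loss=False))  # [0, 0, 0]
```

The peak equals the number of steps per epoch and does not grow across epochs, which matches the observation that the crash, if it happens at all, happens during the first epoch.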
Here is the comparison for my task. My GPU could hold 2k steps before running out of memory.
With self.log:
[image: GPU memory usage over time]
Without self.log:
[image: GPU memory usage over time]
Both show rapid growth in the first minute, as the model is loaded and fed the first batch.
The difference is in the subsequent minutes: in the former case the list of losses eats 7 GB of GPU memory and leads to a crash, while in the latter nothing happens and training goes on.
Pretty cool how a single self.log call can eat twice as much GPU memory as the actual training process.

@Vozf Vozf added bug Something isn't working help wanted Open to be worked on labels Nov 6, 2020
@edenlightning edenlightning added the logger Related to the Loggers label Nov 6, 2020
@AristoYU

AristoYU commented Nov 9, 2020

Same problem! However, I use self.log("log name", (scalar tensor).item()) to avoid that OOM problem. Maybe you can log the data in the tensor instead of the tensor itself.

@Vozf
Contributor Author

Vozf commented Nov 9, 2020

I'm logging just a plain Python 1, not a tensor, as you can see from the example.

@Vozf
Contributor Author

Vozf commented Nov 9, 2020

@tchaton, @Borda Any thoughts on this?

@Vozf
Contributor Author

Vozf commented Nov 9, 2020

For anyone having the same problem, I monkeypatched TrainLoop like this to avoid storing the loss:

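    from pytorch_lightning.trainer.training_loop import TrainLoop

    # Keep a reference to the original method so the patch can delegate to it
    old_process_train_step_outputs = TrainLoop.process_train_step_outputs

    def process_train_step_outputs_delete_loss(*args, **kwargs):
        results = old_process_train_step_outputs(*args, **kwargs)
        # Clear the loss (stored under `minimize`) on every result that will be
        # kept in epoch_output, so the GPU tensor is not held until epoch end
        for result in results:
            for res in result:
                res.minimize = None
        return results

    TrainLoop.process_train_step_outputs = process_train_step_outputs_delete_loss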
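The same patch can be written as a reusable decorator. The sketch below exercises it on hypothetical stand-ins (FakeResult, FakeTrainLoop) instead of Lightning's real TrainLoop, so it runs without Lightning installed; the only assumption carried over from the monkeypatch above is that the patched method returns a list of lists of dict-like results with a minimize attribute.

```python
import functools


def drop_minimize(method):
    """Wrap a method so every result it returns has its `minimize` field cleared."""

    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        results = method(*args, **kwargs)
        for result in results:      # outer list
            for res in result:      # inner list of dict-like results
                res.minimize = None
        return results

    return wrapper


class FakeResult(dict):
    """Hypothetical stand-in: a dict whose keys are also readable/writable as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err

    def __setattr__(self, key, value):
        self[key] = value


class FakeTrainLoop:
    """Hypothetical stand-in with a method named like the one patched above."""

    def process_train_step_outputs(self, losses):
        return [[FakeResult(minimize=loss, metric=1.0)] for loss in losses]


# Patch the class attribute, the same way the TrainLoop monkeypatch does
FakeTrainLoop.process_train_step_outputs = drop_minimize(
    FakeTrainLoop.process_train_step_outputs
)

outputs = FakeTrainLoop().process_train_step_outputs([0.5, 0.25])
print(outputs)  # [[{'minimize': None, 'metric': 1.0}], [{'minimize': None, 'metric': 1.0}]]
```

functools.wraps copies the original method's name and docstring onto the wrapper and exposes the original as __wrapped__, so the unpatched method stays reachable if the patch needs to be undone.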

@Vozf
Contributor Author

Vozf commented Nov 9, 2020

The validation loop has the same issue: CUDA tensors are stored in a list. Unlike in the training loop, they are detached, so the overhead isn't big, but it's still there. This can be fixed by calling loss.cpu() before returning it from validation_step, or by not returning anything at all.

@tchaton
Contributor

tchaton commented Nov 9, 2020

Hey @Vozf and @AristoYU,

I deeply apologise for this bug. Let me work on it as a priority!

Best regards,
Thomas Chaton.

@Borda Borda added the priority: 0 High priority task label Nov 9, 2020
@Svito-zar

I also had this issue.


6 participants