GPU memory issues (leak?) #439
Comments
Thanks @ptrblck. I changed my code to only run |
Thanks for the update!

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from apex import amp
import copy

use_amp = False
clean_opt = False

device = 'cuda'
model = models.resnet18()
model.to(device)
state_dict = copy.deepcopy(model.state_dict())

optimizers = [optim.Adam(model.parameters(), lr=1e-3) for _ in range(3)]
if use_amp:
    model, optimizers = amp.initialize(model, optimizers, opt_level='O1')

dataset = datasets.FakeData(transform=transforms.ToTensor())
loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    shuffle=True
)

criterion = nn.CrossEntropyLoss()

print('Memory allocated {:.3f}MB'.format(
    torch.cuda.memory_allocated() / 1024**2))

for opt_idx, optimizer in enumerate(optimizers):
    # reset the model to its initial weights
    model.load_state_dict(state_dict)

    # train with the current optimizer
    for epoch in range(5):
        for data, target in loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if use_amp:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            optimizer.step()

        print('OptIdx {}, epoch {}, loss {}, mem allocated {:.3f}MB'.format(
            opt_idx, epoch, loss.item(), torch.cuda.memory_allocated() / 1024**2))

    # optionally drop the finished optimizer so its state can be freed
    if clean_opt:
        optimizers[opt_idx] = None
```

If you leave the default settings as

Thanks for reporting! We'll have a look. |
Thanks again @ptrblck. I am actually doing what you are suggesting, namely deleting the optimizers at the end, but as you correctly state, it does not seem to work with amp (and it might similarly not work for other variables). Hope you can take a look. |
This is an important bug for me as well. I am testing multiple hyperparameter settings sequentially, which means that in each iteration I set everything up from scratch: the model and optimizer are recreated each time, and thus amp.initialize is also called on each iteration. After a model has been trained, I try to clear caches by running

```python
torch.cuda.empty_cache()
gc.collect()
```

Unfortunately, to no avail. After some iterations I run into a CUDA OOM error when using AMP, so any form of parameter tuning doesn't seem possible. (For reference: I could test 5 parameter settings when fine-tuning RoBERTa before I got an OOM on an RTX 2080 Ti.) I was running the same process (but fine-tuning another model) on 2x V100s. That process is still running (they have 16GB of RAM, while the RTX 2080 Ti only has 11GB), but looking at its memory usage via
|
@BramVanroy would it be possible to reuse the model and optimizer and just reinitialize them? |
@ptrblck Thank you for your time. Re-using is not possible in my case, since depending on the parameters the architecture can change (e.g. no intermediate linear layer). I distilled the scripts that I use into a working example that only uses DistilBERT, and included test data of 2000 sentences for train/dev. You can find the repo here. Documentation is bad, but the only thing you need to use is the

Training shouldn't take too long (the dataset is quite small). The script will train on four different parameter settings. You can see the difference between just running

Initialising amp happens here.

Below you can find the memory usage in both cases.

Without AMP:

With AMP:

The important part is not the absolute numbers, but the relative increase in memory usage of GPU:0 at each new parameter setting. With AMP it increases by 10% every time, whereas without AMP it only increases by 1%. |
@ptrblck According to my findings and your example above, re-using the model does not help, as deleting e.g. the optimizer does not work. |
I've been casually reading through the source code, and I was wondering how hard it would be to implement a 'destroy' method that frees all AMP objects, ready for garbage collection. It seems that most things go through the AmpHandle, _amp_state, and master_params. Is there a way to clear those out? |
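As a very rough illustration of that idea (purely a sketch: apex provides no such destroy API, and the import path and attribute handling below are assumptions about its internals, not documented behaviour):

```python
# Hypothetical "destroy" helper -- NOT an apex API; this only sketches the idea
# of dropping the references apex keeps in its module-level amp state.
import gc
import torch
from apex.amp import _amp_state as amp_state_module  # assumed import path


def destroy_amp_state():
    state = amp_state_module._amp_state    # module-level AmpState instance (assumption)
    for name in list(vars(state)):         # drop handle, loss scalers, etc. held globally
        delattr(state, name)
    gc.collect()                           # let the now-unreferenced tensors be collected
    torch.cuda.empty_cache()               # release cached blocks back to the driver
```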
I ran into this issue as well; sadly it results in memory leaks that are pretty hard to debug. I'd suggest showing a warning when amp.initialize is called more than once. In my case I can actually reuse the model, but it's not convenient: I run hyperparameter optimization and the whole model initialization is encapsulated away from the main optimization loop, so passing around the model and optimizer states is cumbersome. |
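A minimal sketch of such a guard (an illustrative wrapper only, not part of apex; the function and flag names are made up):

```python
# Illustrative wrapper, not part of apex: warn when amp.initialize is called
# more than once in the same process, since repeated calls leak GPU memory.
import warnings
from apex import amp

_amp_initialized = False


def checked_amp_initialize(model, optimizer, **kwargs):
    global _amp_initialized
    if _amp_initialized:
        warnings.warn(
            "amp.initialize has already been called in this process; "
            "calling it again keeps internal references alive (see #439)."
        )
    _amp_initialized = True
    return amp.initialize(model, optimizer, **kwargs)
```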
this issue seriously affects me |
Are you using O1 or O2 when you see these memory leak issues? O1 does not create any additional parameters, so it should not leak memory. |
I am using O1; the memory usage keeps going up when I recreate the model and optimizer. |
Also, if I turn off fp16 mode, there is no memory issue, so it must be an apex problem. |
Is the memory consumption going up on the CPU, GPU, or both? |
I am not monitoring the CPU; for the GPU, it is roughly 10% more memory each time I recreate the model and optimizer. |
This bug actually hinders me from using Apex most of the time. I am constantly running multiple models in single notebooks / scripts. |
Just in case you have trouble reproducing this, this is what I used for debugging:

```python
from collections import Counter
import gc

import torch
import torch.nn as nn
from apex import amp

cached = []
allocated = []


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.linear = nn.Linear(1024, 1024)

    def forward(self, x):
        x = self.linear(x)
        return x


def get_tensor_sizes():
    """Count live tensors (by shape) that the garbage collector can still see."""
    ctr = Counter()
    gc.collect()
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                ctr.update([tuple(obj.size())])
        except OSError:
            pass
        except RuntimeError:
            pass
    return ctr


for _ in range(1000):
    gc.collect()
    print(get_tensor_sizes())

    # recreate the model and optimizer from scratch on every iteration
    model = Net().to("cuda")
    optimizer = torch.optim.SGD(model.parameters(), 0.1)
    model(torch.rand(1024).to("cuda"))

    amp.initialize(
        model,
        optimizer,
        opt_level="O1",
    )

    allocated.append(torch.cuda.memory_allocated())
    cached.append(torch.cuda.memory_cached())
    print(f"Cached: {cached}")
    print(f"Allocated: {allocated}")
```

As you will see, CUDA memory usage rises over time. Even though the model clearly should be deleted (which should be enforced by the

I am not 100% sure if this example could be further minimized, but the |
Please see my previous comment #439 (comment) where I provide a repo to test. The numbers given are the memory usage increase with O1. |
As documented here (and in the official documentation) NVIDIA/apex#439, we shouldn't call apex.initialize twice. To avoid this, we retain the original model, loading the state dicts of new optimizers and models for each run.
apex creates a memory leak (see NVIDIA/apex#439) where models cannot be garbage collected. This makes apex very difficult to use at the moment during hyper-parameter tuning.
Hi! Bumping this post: I am also experiencing the exact same issue with O1 mode. I am using the hyperopt library and making sequential calls to initialize new models with different hyperparameters, which are then wrapped in amp.initialize. GPU memory growth is roughly 10% per iteration. Thank you for your help! |
I also have this problem with 'O1' and 'O2' modes. |
I think the scenario of @MrHuff is common: hyperparameter optimization requires multiple initializations. I expect that most of the time on Apex is currently dedicated to integrating it upstream, but it would be great if the amp initialization were NOT indestructible by default. |
This is still a major bug; plenty of people are complaining about the same thing. |
It's not a bug per se, I think. AMP was simply never intended to be initialised more than once per session. That being said, it would be nice to see an improvement on this aspect when it gets upstream support. |
Just for the record: we hit the same memory-leak problem with apex. Without apex, the same code runs without leaks, as everyone said above. |
For the record as well: I hit the same problem when I use AMP in the training steps, and there seems to be a memory leak when I run evaluation. |
Same issue using Ray Tune for hyperparameter optimization. |
I have the same issue on Kaggle when testing a few models. Can someone tell us the best practice? |
AMP is now available in PyTorch's nightlies. The brave experimentalists can try it out now. Documentation looks good already! https://pytorch.org/docs/master/notes/amp_examples.html#amp-examples |
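For anyone who wants to try it, the pattern from the linked examples boils down to roughly the following (a minimal sketch; model, optimizer, loader and criterion are assumed to be defined elsewhere):

```python
# Minimal native-AMP training step following the linked upstream examples.
# model, optimizer, loader and criterion are assumed to be defined elsewhere.
import torch

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # run the forward pass in mixed precision
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()        # backprop on the scaled loss
    scaler.step(optimizer)               # unscales gradients, skips the step if they are inf/nan
    scaler.update()                      # adjusts the scale for the next iteration
```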
If someone testing it could check whether it fixes this issue, I would be grateful. |
It's probably a bit too early to make exhaustive comments on the functionality: it has only just been introduced as part of a nightly release. |
I have been testing it a bit, and it looks way better. Could not replicate the memory issues yet. As you say, too early to say, but looks promising. |
Glad you're trying the upstream API!! I anticipate it'll solve any memory leak issues associated with

One thing that may not be obvious from the upstream docs:

```python
scaler = torch.cuda.amp.GradScaler()  # replaces the old scaler instance
```

I could add a

It's also permissible to have multiple

(in other words, there is a 1-1 correspondence between convergence runs and |
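In loop form, that guidance looks roughly like this (a sketch; hyperparameter_configs, build_model_and_optimizer and train_one_run are placeholder names, not real APIs):

```python
# One GradScaler per convergence run: recreate the scaler together with the
# model and optimizer for every hyperparameter configuration.
import torch

for config in hyperparameter_configs:                       # placeholder iterable of configs
    model, optimizer = build_model_and_optimizer(config)    # placeholder helper
    scaler = torch.cuda.amp.GradScaler()                    # fresh scaler for this run
    train_one_run(model, optimizer, scaler, config)         # placeholder training loop
```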
Hey @mcarilli, thanks for the reply and your work on integrating AMP upstream! A reset method seems useful, but (to me) only if it is faster than creating a new instance from scratch; otherwise recreating the object is just as easy. Finally, the docs might benefit from a pros/cons comparison. Discussing the cons (if any) might be useful, because people might wonder what the downside is if AMP is not enabled by default. Thanks again! |
it's so lightweight (a few ints/floats/dicts and a couple of one-element tensors) that it shouldn't make a difference either way.
which docs? native docs? do you mean pros and cons of Apex Amp vs native amp? I'd rather not mention Apex at all in the native docs, I want native docs to be self-contained. |
After thinking about it a bit more, I think a reset method would be useful after all if you initialised the scaler outside the loop. Yes, the native docs. But I meant a comparison of using the native AMP vs. not using AMP at all. So a bit more general. This can be one simple line, but it would help beginners to understand the benefit of AMP and any downsides (if any). The latter would be important because one may wonder: if there are no downsides, then why is it not enabled by default? |
@mcarilli Thanks a lot for your contribution! This is a life saver for many! I think reinitializing the |
Why would overwriting the local reference named

```python
scaler = torch.cuda.amp.GradScaler()  # replaces scaler in current scope
```

would only overwrite it in that scope, while
You're right about the second one: you should only call update() on iterations where you actually step()ed. I'll add gradient accumulation to the examples. |
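For reference, that accumulation pattern looks roughly like this (a sketch, assuming accum_steps evenly divides the number of iterations; model, optimizer, loader and criterion are assumed to exist):

```python
# Gradient accumulation with GradScaler: scale and backprop every iteration,
# but step() and update() only on accumulation boundaries.
scaler = torch.cuda.amp.GradScaler()
accum_steps = 4

for i, (data, target) in enumerate(loader):
    with torch.cuda.amp.autocast():
        output = model(data.cuda())
        loss = criterion(output, target.cuda()) / accum_steps  # average over the window
    scaler.scale(loss).backward()       # gradients accumulate across iterations
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)          # unscale and step only here
        scaler.update()                 # update the scale only here as well
        optimizer.zero_grad()
```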
I think that it is cleaner. For instance, assume you have a Trainer class with a self.scaler property that you set in the constructor; then using a reset in a method is a lot cleaner than reassigning, imo.

```python
class Trainer:
    def __init__(self):
        self.scaler = torch.cuda.amp.GradScaler()

    def my_loop(self):
        for config in configs:
            # train models with different configurations
            self.scaler.reset()  # the proposed reset(); seems a lot cleaner here than self.scaler = torch.cuda.amp.GradScaler()
```

|
Summary: Several people have asked me about proper Amp usage with gradient accumulation. In particular, it's [unclear to people](NVIDIA/apex#439 (comment)) that you should only call `scaler.unscale_()` (if desired) and `scaler.update()` in iterations where you actually plan to step. This PR adds a minimal accumulation example. I built the docs locally and it looks free from sphinx errors, at least. Pull Request resolved: #36601 Differential Revision: D21082295 Pulled By: ngimel fbshipit-source-id: b2faa6c02b9f7e1972618a0f1d5360a03f0450ac
Same here. The original code runs successfully with fp32, but fails with apex mixed precision because of insufficient GPU memory. |
Hello, I have the same problem. But I find that there is no reset() for GradScaler() according to
|
So what is the memory-leak solution for the original amp? Should amp.initialize() be put in the training loop to initialize the model and optimizer in each training iteration? Currently I only call amp.initialize() once at the beginning, after the model is constructed, and I get the out-of-memory problem as well. (By the way, I want to train on multiple GPUs with the memory spread evenly, because the model is bigger than the GPU memory of an individual card, but I could not make it work because of the out-of-memory problem.) |
Use native amp instead (torch 1.6 or higher is needed).
|
I am running a loop where I initialize a new model in each iteration and train it. I am using NVIDIA Apex for mixed-precision training. My current issue is that there seem to be some unwanted memory allocations across different steps of the loop: the GPU memory accumulates, and after a few steps CUDA memory runs out.
I have debugged everything, monitored memory, and deleted every single thing possible. Only after removing apex does the memory allocation seem to be consistent. I am doing nothing other than adding the three lines of code from the tutorial for initializing and backward passing.
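Concretely, the lines in question are essentially the standard apex pattern (a sketch; model, optimizer and loss come from the surrounding training code):

```python
# The standard apex mixed-precision pattern referred to above (O1 as an example).
from apex import amp

# once, after the model and optimizer are constructed:
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# in the training loop, instead of loss.backward():
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```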
Any ideas?