Memory leak when using Metric with list state #4098
Comments
It is expected behavior, I would say, since the metric class is asked to keep track of the whole computed history instead of a single state that can be updated in place (like all other metrics do).
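To illustrate the difference, here is a minimal sketch (plain tensors, no library involved, all names made up) of a single updatable state versus a list state that keeps the whole history:

import torch

# single updatable state: memory stays constant no matter how many batches are seen
running_sum = torch.tensor(0.0)
n_obs = torch.tensor(0)

# list ("whole history") state: one entry appended per step, so memory grows with the number of steps
history = []

for step in range(3):
    batch_value = torch.rand(())
    running_sum += batch_value   # in-place update, still a single tensor
    n_obs += 1
    history.append(batch_value)  # every computed value is kept around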
@SkafteNicki Thanks for the explanations. I think I might not be getting part of the intended behavior. I understand that the memory footprint might be deceptively high, since whole graphs need to be kept in memory to be backpropagated through. However, I would have assumed that, since I call the metric on every step, the graphs from previous steps would eventually be freed. Thanks again for the clarifications on the recommended use of the Metric API.
@nathanpainchaud all loss functions are metrics, but not all metrics are loss functions. Two requirements apply to loss functions that do not necessarily apply to all metrics: a loss has to be differentiable, and it has to be computed per step so that it can be backpropagated.
When you call the metric directly, the current state gets cached, gets used to calculate the metric for the current input, and then gets restored afterwards. The whole idea of implementing it using the Metric class is that state can be accumulated over an epoch and synchronized across devices.
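As a rough illustration of that cache/compute/restore behavior, here is a simplified stand-in class (not the library's source; SimplifiedMetric and its mean-absolute-error state are made up for the example):

import torch

class SimplifiedMetric:
    """Illustrates the cache -> compute-on-current-input -> restore pattern described above."""

    def __init__(self):
        self.values = []  # accumulated (epoch-level) state

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.values.append((preds - target).abs().mean())

    def compute(self) -> torch.Tensor:
        return torch.stack(self.values).mean()

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # 1. add the current batch to the accumulated state
        self.update(preds, target)
        # 2. cache the accumulated state and recompute from this batch only
        cached = self.values
        self.values = []
        self.update(preds, target)
        batch_value = self.compute()
        # 3. restore the accumulated state and return the batch-level value
        self.values = cached
        return batch_value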
@SkafteNicki I do understand the difference between metrics and loss functions 😉 What I think I don't understand is what the step-level computation is supposed to take into account. As for my own use-case, here is roughly how I use the metric:
def __init__(self, ...):
    ...
    self._dice = DifferentiableDiceCoefficient(include_background=False)

def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
    x, y = batch
    y_hat = self.module(x)
    dice = self._dice(y_hat, y)
    loss = 1 - dice  # loss derived from the dice coefficient (reconstructed; the exact loss computation was lost)
    return loss
@nathanpainchaud just wanted to make sure that we were on the same page 😉 If the metric update just adds to a list, then yes, the memory footprint will just grow and grow. This is the very reason why most of our metrics (all except for explained_variance) keep a fixed-size state that gets updated in place instead of a growing list.
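Along those lines, here is a sketch of a dice metric with constant-size states (the class name, dice formula, and import path are assumptions, not the exact implementation recommended in the thread):

import torch
from torchmetrics import Metric  # at the time of this issue the base class lived in pytorch_lightning.metrics


class DiceWithFixedState(Metric):
    """Dice coefficient accumulated in two scalar states instead of a list."""

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        self.eps = eps
        self.add_state("dice_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_batches", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        intersection = (preds * target).sum()
        dice = (2 * intersection + self.eps) / (preds.sum() + target.sum() + self.eps)
        # note: if preds requires grad, dice_sum now carries the graph of every batch,
        # which is exactly the retention problem discussed further down in this thread
        self.dice_sum += dice
        self.n_batches += 1

    def compute(self) -> torch.Tensor:
        return self.dice_sum / self.n_batches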
@SkafteNicki Thanks a lot for the recommended implementation. I've just now gotten around to implementing it. However, I still get the CUDA OOM after the same number of steps. As I mentioned in one of my earlier messages, I suspect that behind the scenes PyTorch keeps the computational graphs of the previous steps because the results (which require grad) are saved as part of the state of the metric. Reading through the Metric class implementation, I didn't find anything in the built-in metrics that seemed to ensure this doesn't happen. Would you have any recommendation on which inner Metric state to monitor to understand why my metric behaves differently from the built-in ones? Thanks in advance, as well as for all your help until now!
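One way to test that suspicion is to check whether the stored states still reference an autograd graph after a few steps. The helper below is a hypothetical debugging snippet (not part of the library), where state_names are the names registered with add_state:

def inspect_metric_states(metric, state_names):
    """Print whether each registered state still references an autograd graph.

    A non-None grad_fn (or requires_grad=True) on a stored state means the whole
    computational graph of the step that produced it is being kept alive.
    """
    for name in state_names:
        state = getattr(metric, name)
        tensors = state if isinstance(state, list) else [state]
        for i, t in enumerate(tensors):
            print(f"{name}[{i}]: requires_grad={t.requires_grad}, grad_fn={t.grad_fn}")

Calling it every few hundred steps, e.g. inspect_metric_states(self._dice, ["dice_sum", "n_batches"]) for the sketch above, would show whether the states keep a grad_fn and hence a graph.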
@nathanpainchaud if all you want to do is to backpropagate the metric, why don't you just do that?
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
    x, y = batch
    y_hat = self.module(x)
    dice = self.differentiable_dice_score(y_hat, y)
    return dice
@SkafteNicki That is what I'm doing right now. But I wanted to see if my loss could be implemented as a Metric.
@nathanpainchaud we are reevaluating whether our metric API should support backprop, as it seems to cause trouble.
@SkafteNicki how about we call detach before caching for the epoch, but let the value be part of the graph for the step? That way you can backprop through the step value but detach from the graph before adding it to the state used to compute over the epoch and across GPUs.
@ananyahjha93 that was also what I was thinking. I will send a PR soon with the update.
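For illustration, a rough sketch of that idea as a standalone class (not the actual PR; the name StepDifferentiableDice and the dice formula are made up):

import torch


class StepDifferentiableDice:
    """The per-step value keeps its graph (usable as a loss); only a detached copy is accumulated."""

    def __init__(self, eps: float = 1e-7):
        self.eps = eps
        self.dice_sum = torch.tensor(0.0)
        self.n_batches = torch.tensor(0)

    def __call__(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        intersection = (preds * target).sum()
        dice = (2 * intersection + self.eps) / (preds.sum() + target.sum() + self.eps)
        # epoch-level accumulation uses a detached copy, so the step graph can be freed
        self.dice_sum += dice.detach()
        self.n_batches += 1
        # the returned value is still attached to the graph and can be backpropagated
        return dice

    def compute(self) -> torch.Tensor:
        return self.dice_sum / self.n_batches

The step-level return value can be used directly as a training loss, while compute() only ever sees detached tensors, so no graph is retained across steps.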
🐛 Bug
I tried implementing a custom Metric to use as a loss when training. It seems to compute the desired values fine; however, as the title states, the metric quickly consumes all the memory on my single GPU. Models that previously required less than half of my GPU memory now run into OOMs after less than one epoch. I tried replicating the issue in this Colab notebook, however the dataset and training procedure are too lightweight to replicate the resource consumption issue.
Can anyone confirm whether it's an issue with how I handle the list state myself or a bug with Metric itself? The only metric I found that uses a list state that I could draw inspiration from is explained_variance. However, I thought I should be able to do things differently and wouldn't need to store the targets and preds in the state, only the computed results.
To Reproduce
Colab notebook that tries to reproduce the issue (with no success) on a toy example.
If it can be of further help, my real use-case is the following, where I implement the dice loss as a metric with which to train my model:
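The snippet itself was not preserved here; purely as an illustration, this is a sketch of what a list-state implementation along those lines could look like (the internals are assumptions; only the class name DifferentiableDiceCoefficient comes from the thread):

import torch
from torchmetrics import Metric  # at the time of this issue: pytorch_lightning.metrics


class DifferentiableDiceCoefficient(Metric):
    """Stores only the per-batch dice values in a list state (internals are illustrative)."""

    def __init__(self, include_background: bool = False, eps: float = 1e-7):
        super().__init__()
        self.include_background = include_background
        self.eps = eps
        # one entry is appended per update call
        self.add_state("dice_values", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        if not self.include_background:
            preds, target = preds[:, 1:], target[:, 1:]  # drop the background channel
        intersection = (preds * target).sum()
        dice = (2 * intersection + self.eps) / (preds.sum() + target.sum() + self.eps)
        # if dice requires grad, each stored entry keeps its whole graph alive, so memory grows
        self.dice_values.append(dice)

    def compute(self) -> torch.Tensor:
        return torch.stack(self.dice_values).mean()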
Expected behavior
I would expect to be able to use metrics/losses implemented using the Metric API exactly like I would if they inherited from nn.Module.