Memory leak when using Metric with list state #4098

Closed
nathanpainchaud opened this issue Oct 12, 2020 · 12 comments · Fixed by #4313
Labels
help wanted Open to be worked on

Comments

@nathanpainchaud
Contributor

🐛 Bug

I tried implementing a custom Metric to use as a loss when training. It seems to compute the desired values fine; however, as the title states, the metric quickly consumes all the memory on my single GPU. Models that previously required less than half of my GPU memory now run into OOMs after less than one epoch.

I tried replicating the issue in this Colab notebook; however, the dataset and training procedure are too lightweight to reproduce the resource consumption issue.

Can anyone confirm whether it's an issue with how I handle the list state myself or a bug with Metric itself? The only metric I found that uses a list state, and that I could draw inspiration from, is explained_variance. However, I thought I should be able to do things differently and wouldn't need to store the targets and preds in the state, only the computed results.

To Reproduce

Colab notebook that tries to reproduce the issue (with no success) on a toy example.

If it can be of further help, my real use-case is the following, where I implement the dice loss as a metric with which to train my model:

from typing import Any, Optional

import torch
from pytorch_lightning.metrics import Metric

# differentiable_dice_score is the author's own functional implementation (not shown here)


class DifferentiableDiceCoefficient(Metric):
    """Computes a differentiable version of the dice coefficient."""

    def __init__(
        self,
        include_background: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = "elementwise_mean",
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
        )
        self.include_background = include_background
        self.nan_score = nan_score
        self.no_fg_score = no_fg_score
        assert reduction in ("elementwise_mean", "none")
        self.reduction = reduction

        self.add_state("dice_by_steps", [])

    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        self.dice_by_steps += [
            differentiable_dice_score(
                input=input,
                target=target,
                bg=self.include_background,
                nan_score=self.nan_score,
                no_fg_score=self.no_fg_score,
                reduction=self.reduction,
            )
        ]

    def compute(self) -> torch.Tensor:
        return torch.mean(torch.stack(self.dice_by_steps), 0)

Expected behavior

I would expect to be able to use metrics/losses implemented using the Metric API exactly like I would if they inherited from nn.Module.

@nathanpainchaud nathanpainchaud added bug Something isn't working help wanted Open to be worked on labels Oct 12, 2020
@SkafteNicki
Member

It is expected behavior, I would say, since the metric class is asked to keep track of the whole computed history instead of a single state that can be updated (as all other metrics except ExplainedVariance do).
I have added a warning in ExplainedVariance just because of this:
https://github.com/PyTorchLightning/pytorch-lightning/blob/09c2020a9325850bc159d2053b30c0bb627e5bbb/pytorch_lightning/metrics/regression/explained_variance.py#L89-L91
That said, could it be an option for you to:

  • Call detach() on the output of differentiable_dice_score (i.e., are you going to backprop the computed value)? A minimal sketch of this follows below.
  • If all you are doing in the end is calling mean on the stacked list self.dice_by_steps, you could instead initialize the state as a tensor that you simply add to during update, while also logging how many times update has been called, so that the mean can be computed in compute. That should be much more memory efficient.
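
A minimal sketch of the first option, assuming the metric from the issue description: detaching in update keeps the list state from holding onto the autograd graph (at the cost of no longer being able to backprop through the accumulated values).

    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        # detach so the stored value no longer references the autograd graph
        self.dice_by_steps += [
            differentiable_dice_score(
                input=input,
                target=target,
                bg=self.include_background,
                nan_score=self.nan_score,
                no_fg_score=self.no_fg_score,
                reduction=self.reduction,
            ).detach()
        ]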

@nathanpainchaud
Contributor Author

@SkafteNicki Thanks for the explanations. I think I might not be getting what Metric is designed to be used for. I can't use detach() because I indeed want to backprop through the metric, i.e. I implement my loss using the Metric API. Is this not recommended?

And I understand that the memory footprint might be deceptively high, since whole graphs need to be kept in memory to be backpropagated through. However, I would have assumed that since I call it with compute_on_step=True, backprop should be performed at each step, followed by a reset of the state that clears the graph from memory. Or is there no guarantee about when memory will be freed once I use a Metric?

Thanks again for the clarifications on the recommended use of Metrics!

@edenlightning
Contributor

@ananyahjha93

@SkafteNicki
Member

@nathanpainchaud all loss functions are metrics, but not all metrics are loss functions. Two requirements for loss functions that do not necessarily apply to all metrics:

  1. Be differentiable. Accuracy is a good example of a metric that is non-differentiable and therefore not a loss function (see the snippet after this list).
  2. Be decomposable into individual terms per sample (at least necessary if we want to do minibatching). ExplainedVariance is a good example of a metric where the value is not a simple sum/mean over the individual samples.
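
To illustrate the first point, a minimal, made-up example: accuracy goes through argmax and an equality comparison, neither of which lets gradients flow back to the predictions, so it cannot serve as a training loss.

import torch

preds = torch.randn(8, 4, requires_grad=True)   # made-up (N, C) logits
target = torch.randint(0, 4, (8,))              # made-up (N,) class indices
acc = (preds.argmax(dim=1) == target).float().mean()
print(acc.requires_grad)  # False: the graph is cut at argmax/==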

When you set compute_on_step=True, it is correct that the value returned from forward is the metric calculated on the current input; however, it is also stored in the state. Please see this part of the code:

https://github.com/PyTorchLightning/pytorch-lightning/blob/0474464c454a000e4cacfe188c86a0b8317288d5/pytorch_lightning/metrics/metric.py#L148-L165

As you can see, the current state gets cached, the metric is computed for the current input, and then the state gets restored afterwards.
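
Roughly, and paraphrasing the linked code rather than quoting it (attribute names may differ slightly), forward does something like this when compute_on_step=True:

def forward(self, *args, **kwargs):
    # accumulate into the global (across-batch) state
    self.update(*args, **kwargs)
    if self.compute_on_step:
        # cache the accumulated state, then reset and compute on this batch only
        cache = {attr: getattr(self, attr) for attr in self._defaults}
        self.reset()
        self.update(*args, **kwargs)
        batch_value = self.compute()
        # restore the accumulated state
        for attr, val in cache.items():
            setattr(self, attr, val)
        return batch_value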

The whole idea of implementing using the Metric class is that you want to compute a value over multiple batches, not just a single one. I would have to see more of your code to judge whether using the Metric class is right for your use case (what does differentiable_dice_score look like, and how do you call and use the metric?).

@nathanpainchaud
Contributor Author

@SkafteNicki I do understand the difference between metrics and loss functions 😉

What I think I don't understand is how to take self._cache into account. I now understand the caching mechanism (thanks for the explanation!), but I'm not sure how to work with it to implement a loss with a reasonable memory footprint. I don't see where self._cache gets flushed, so should I assume that the cache only grows during training, even across epochs? Would that mean the whole computational graph MUST be kept in memory in the case of a loss that is not detached from the graph?

As for my own use case, differentiable_dice_score closely resembles the functional dice_score provided in the Metrics package (it has the same API), only it is the loss version of the dice coefficient that is differentiable and can be backpropagated through to train models. And I use it as a regular metric, i.e. I create the object in the model's init and call it in the training loop, kind of like below:

def __init__(self, ...):
    ...
    self._dice = DifferentiableDiceCoefficient(include_background=False)
    
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
    x, y = batch
    y_hat = self.module(x)
    dice = self._dice(y_hat, y)
    return dice

@SkafteNicki
Member

@nathanpainchaud just wanted to make sure that we were on the same page 😉

If the metric update just adds to a list, then yes, the memory footprint will just grow and grow. This is the very reason why most of our metrics (all except ExplainedVariance) have one or two tensors which they simply add to. It seems to me that in your use case you could do something similar:

class DifferentiableDiceCoefficient(Metric):
    """Computes a differentiable version of the dice coefficient."""

    def __init__(
        self,
        include_background: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = "elementwise_mean",
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group
        )
        self.include_background = include_background
        self.nan_score = nan_score
        self.no_fg_score = no_fg_score
        assert reduction in ("elementwise_mean", "none")
        self.reduction = reduction

        self.add_state("sum_dice_score", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("counter", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, input: torch.Tensor, target: torch.Tensor) -> None:
        self.sum_dice_score += differentiable_dice_score(
                input=input,
                target=target,
                bg=self.include_background,
                nan_score=self.nan_score,
                no_fg_score=self.no_fg_score,
                reduction=self.reduction,
            )
        self.counter += 1

    def compute(self) -> torch.Tensor:
        return self.sum_dice_score / self.counter


@nathanpainchaud
Contributor Author

@SkafteNicki Thanks a lot for the recommended implementation. I've just now gotten around to implementing it. However, I still get the CUDA OOM after the same number of steps.

Like I mentioned in one of my earlier messages, I suspect that behind the scenes PyTorch keeps the computational graphs of the previous steps because the results (which require grad) are saved as part of the metric's state. Reading through the Metric class implementation, I didn't find anything in the built-in metrics that seems to ensure this doesn't happen. Would you have any recommendation on what inner Metric state to monitor, to understand why my metric seems to behave differently from the built-in ones?

Thanks in advance, as well as for all your help until now!

@edenlightning edenlightning added this to the 1.0.3 milestone Oct 19, 2020
@edenlightning edenlightning added the priority: 0 High priority task label Oct 20, 2020
@SkafteNicki
Member

@nathanpainchaud if all you want to do is backpropagate the metric, why don't you just do that?

def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
    x, y = batch
    y_hat = self.module(x)
    dice = self.differentiable_dice_score(y_hat, y)
    return dice

@edenlightning edenlightning removed priority: 0 High priority task bug Something isn't working labels Oct 22, 2020
@edenlightning edenlightning removed this from the 1.0.x milestone Oct 22, 2020
@nathanpainchaud
Contributor Author

@SkafteNicki That is what I'm doing right now. But I wanted to see if my loss could be implemented as a Metric instead, to benefit from all the built-in functionalities the Metric API offers. But it seems that it might be more trouble than it's worth at this point...

@SkafteNicki
Member

@nathanpainchaud we are reevaluating whether our metric API should support backprop, as it seems to cause trouble.
You are completely right that there is nothing in the built-in Metric class that prevents the computational graph from being stored.
Maybe the compromise would be to internally call detach before values are stored in the cache. With this:

val = metric(pred, target)  # this can be backpropagated
val = metric.compute()      # this cannot be backpropagated

@ananyahjha93
Contributor

@SkafteNicki how about we call detach before caching for the epoch, but let it be part of the graph for the step? That way you can backprop for the step, but detach from the graph before adding it to compute over the epoch and across GPUs.
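
To make that concrete, a minimal sketch of the idea, hypothetical and not the actual fix in #4313 (_compute_on_batch and _accumulate are made-up helper names for illustration only): the value returned on the step keeps its graph, while whatever is accumulated for the epoch-level compute() is detached first.

def forward(self, *args, **kwargs):
    batch_value = self._compute_on_batch(*args, **kwargs)  # made-up helper: metric on this batch only, graph intact
    self._accumulate(batch_value.detach())                 # made-up helper: epoch/GPU accumulation, graph dropped
    return batch_value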

@SkafteNicki
Member

@ananyahjha93 that was also what I was thinking. I will send a PR soon with the update.
