Subtle memory leak in _ReversibleModuleFunction #1

Closed
spezold opened this issue Aug 28, 2019 · 4 comments

spezold commented Aug 28, 2019

Hi, first of all: very nice work, and congratulations on your MICCAI paper!

I would like to point out a subtle memory leak in _ReversibleModuleFunction, which is due to not using ctx.save_for_backward() for storing x. The leak occurs under rare conditions, namely when a network output is not consumed by the loss term: that output is never backpropagated through, so _ReversibleModuleFunction.backward() is never called for it, and del ctx.y in _ReversibleModuleFunction.backward() therefore never happens (at least, this is my uneducated guess at the source of the leak).
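
To illustrate the pattern I mean, here is a minimal, made-up sketch of such a function (the names and the computation are invented, the real _ReversibleModuleFunction of course does more, and my explanation of the mechanism may well be incomplete):

import torch

class _LeakySketch(torch.autograd.Function):
    """Made-up reduction of the pattern; not the actual RevTorch code."""

    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        # the output is kept as a plain Python attribute on ctx ...
        ctx.y = y
        return y

    @staticmethod
    def backward(ctx, grad_y):
        y = ctx.y  # in RevTorch this would be used to reconstruct the input
        del ctx.y  # ... and only released here; backward() never runs for an
                   # output that is not backpropagated through
        return 2 * grad_y

x = torch.randn(4, requires_grad=True)
y = _LeakySketch.apply(x)  # if y never enters the loss, backward() never runs
                           # and ctx.y is never explicitly released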

Consider the following minimal example:

import torch
import torch.nn as nn
from revtorch import ReversibleBlock, ReversibleSequence  # adjust if your revtorch layout differs


class Net(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 2, kernel_size=1)
        # each ReversibleBlock operates on one half of the 2 channels (1 channel each)
        self.rev1 = ReversibleSequence(nn.ModuleList([
                ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                                nn.Conv2d(1, 1, kernel_size=1))]))
        self.rev2 = ReversibleSequence(nn.ModuleList(
                [ReversibleBlock(nn.Conv2d(1, 1, kernel_size=1),
                                 nn.Conv2d(1, 1, kernel_size=1))]))
    
    def forward(self, x):
        x = nn.functional.relu(self.conv(x))
        y1 = nn.functional.relu(self.rev1(x))
        y2 = nn.functional.relu(self.rev2(x))
        return y1, y2
    

def my_loss(output, unused):
    
    result = output.sum()  # + 0 * unused.sum()
    return result
    

if __name__ == "__main__":

    model = Net().to(torch.device("cpu"))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

    model.train()
    for i in range(10000):
        data = torch.randn(64, 2, 32, 32, dtype=torch.float32)
        y1, unused = model(data)  # the second output never enters the loss
        loss = my_loss(y1, unused)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

As you can see, the second network output, y2, is not used anywhere in the loss calculation, yet memory consumption keeps building up over the iterations. I found two ways to fix the leak:

  1. Make sure that all network outputs are consumed by the loss term (even if only by multiplying them with zero and adding them, see the code comment above).
  2. Use ctx.save_for_backward() and ctx.saved_tensors for storing and retrieving x, respectively, in _ReversibleModuleFunction (see the sketch below).
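
On the made-up sketch from above, option 2 would look roughly like this (I have not checked how well it plays with the rest of _ReversibleModuleFunction):

import torch

class _SavedSketch(torch.autograd.Function):
    """Made-up variant of the sketch above, using ctx.save_for_backward()."""

    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        # autograd manages the lifetime of tensors saved this way, so (as far
        # as I understand) they are released together with the graph even if
        # backward() is never called for this output
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (y,) = ctx.saved_tensors  # would be used to reconstruct the input
        return 2 * grad_y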

You may want to try to reproduce the memory leak, as I am not sure whether it depends on the PyTorch version and/or operating system (my setup is PyTorch 1.2.0 on Windows 10). You could then decide whether to change the implementation of _ReversibleModuleFunction or to point out to users that all network outputs need to be "consumed", as described above.

spezold commented Aug 28, 2019

Just an update: using ctx.y = x.detach() in _ReversibleModuleFunction.forward() also seems to do the trick and fixes the leak. So my uneducated guess from above was apparently wrong.
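
For concreteness, on the made-up sketch from my first comment, the change amounts to this (my assumption is that the detached copy carries no grad_fn and therefore does not keep the autograd graph alive, but I have not verified the exact mechanism):

import torch

class _DetachedSketch(torch.autograd.Function):
    """Made-up variant of the sketch above, storing a detached copy."""

    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        # same data as y, but the detached tensor has no grad_fn, so storing
        # it on ctx should not keep the rest of the autograd graph alive
        ctx.y = y.detach()
        return y

    @staticmethod
    def backward(ctx, grad_y):
        y = ctx.y  # still available for reconstructing the input
        del ctx.y
        return 2 * grad_y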

RobinBruegger (Owner) commented

Thank you very much for reporting this issue and also providing a fix! I am currently on vacation and will have a look at this as soon as I get back home.

spezold commented Aug 30, 2019

Hi Robin, thanks for your feedback! In the meantime I have been working with ctx.y = x.detach() in _ReversibleModuleFunction.forward(), as written above. It doesn't seem to have any side effects, and not using ctx.save_for_backward(), as in your current implementation, really does avoid another memory bottleneck. So maybe, if you can't find any side effects either, this should be the way to go. Anyway, enjoy your vacation!

RobinBruegger (Owner) commented

Hi @spezold, I have created version 0.2.0 with the fix that you suggested (ctx.y = x.detach() in _ReversibleModuleFunction.forward()). I did not observe any side effects with this change. Thank you very much for your contribution :)
