Skip to content

Commit

Permalink
Ensure optimizer state references are cleared (pytorch#100282)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlazos authored and valentinandrei committed May 2, 2023
1 parent 062ea9a commit b84bcc7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 2 deletions.
31 changes: 30 additions & 1 deletion test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -2987,10 +2987,39 @@ def f(x):

self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))

def test_optim_state_references_cleared(self):
    """Regression test: compiling an optimizer step with dynamo must not
    keep strong references to optimizer state tensors.

    After the compiled step has run once, dropping the optimizer should
    allow its state tensors to be garbage-collected; a surviving weakref
    would indicate a leak (e.g. via dynamo's side-effect keepalive list).
    """
    model = torch.nn.Linear(2048, 2048, bias=False)
    x = torch.ones(2048)

    optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

    def opt_step():
        optimizer.step()

    compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)

    def compiled_model_step(x):
        optimizer.zero_grad()
        y = model(x)
        torch.sum(y).backward()
        compiled_opt_step()

    # Run once so Adadelta lazily initializes its per-param state.
    compiled_model_step(x)

    # Picked "square_avg" arbitrarily to check that
    # optimizer state tensors are deallocated
    state_ref = weakref.ref(
        optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
    )
    # Drop the only strong reference to the optimizer (and thus its state).
    optimizer = None

    # If anything (e.g. dynamo internals) still held the state tensor,
    # the weakref would resolve to a live object instead of None.
    self.assertIsNone(state_ref())

def test_grad_references_cleared(self):
model = torch.nn.Linear(2048, 2048, bias=False)
x = torch.ones(2048)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)

def opt_step():
optimizer.step()
Expand Down
2 changes: 1 addition & 1 deletion torch/_dynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def cleanup(self) -> None:
del node.meta["grapharg"]
self.real_value_cache.clear()
self.input_name_to_proxy.clear()
self.side_effects.keepalive = []
self.side_effects.clear()

def create_proxy(
self,
Expand Down
4 changes: 4 additions & 0 deletions torch/_dynamo/side_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,7 @@ def is_empty(self):
any(map(self.is_modified, self.id_to_variable.values()))
or self.save_for_backward
)

def clear(self):
    """Release every reference this tracker holds.

    Empties both the keepalive list and the id->variable map in place,
    so objects captured during tracing can be garbage-collected.
    """
    for tracked in (self.keepalive, self.id_to_variable):
        tracked.clear()

0 comments on commit b84bcc7

Please sign in to comment.