Cherry-picking upstream commit to get test_graph_make_graphed_callables to pass (#1144)

[Re-land] [CUDA graphs] Clear autocast amp cache (pytorch#81896)

Re-lands pytorch#81558, which was reverted due to failing tests.

The failure was caused by a poorly designed test. [The loop here](https://github.com/pytorch/pytorch/pull/81558/files#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3837) runs `cache_enabled=False` and then `cache_enabled=True`, so the graph captured in the previous iteration (the `False` case) conflicts with the one captured next (the `True` case). The test has been redesigned so that it does not loop: each argument combination now runs as a separate, parametrized test call (sketched below).
Pull Request resolved: pytorch#81896
Approved by: https://github.com/ngimel

Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>
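For illustration, a minimal sketch of the parametrized shape the redesigned test uses. The class name, method name, and tensor sizes here are hypothetical, and running it requires a CUDA-graphs-capable GPU; the real test is in test/test_cuda.py, shown in the diff below.

```python
# Illustrative sketch only -- hypothetical class and sizes, not the exact test from the PR.
import torch
from torch.testing._internal.common_utils import (
    TestCase, instantiate_parametrized_tests, parametrize, run_tests)


class MyCudaGraphTests(TestCase):
    # Each (with_amp, cache_enabled) combination is generated as a separate test
    # method, so no CUDA graph captured for one configuration is still alive
    # when another configuration runs.
    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), (True, True)])
    def test_graphed_callable(self, with_amp, cache_enabled):
        model = torch.nn.Linear(8, 8).cuda()
        x = torch.randn(4, 8, device='cuda')
        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
            graphed = torch.cuda.make_graphed_callables(model, (x,))
            out = graphed(torch.rand_like(x))
        self.assertEqual(out.shape, x.shape)


instantiate_parametrized_tests(MyCudaGraphTests)

if __name__ == '__main__':
    run_tests()
```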
jithunnair-amd and Aidyn-A authored Nov 28, 2022
1 parent 79e07b3 commit 3c91b12
Showing 2 changed files with 23 additions and 9 deletions.
26 changes: 17 additions & 9 deletions test/test_cuda.py
@@ -25,7 +25,8 @@
     _compare_trilu_indices, _compare_large_trilu_indices
 from torch.testing._internal.common_utils import TestCase, freeze_rng_state, run_tests, \
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_REMOTE_GPU, IS_SANDCASTLE, IS_WINDOWS, \
-    slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY
+    slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_WITH_ROCM, TEST_NUMPY, \
+    parametrize, instantiate_parametrized_tests
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -3699,7 +3700,10 @@ def test_graph_grad_scaling(self):
         self.assertEqual(scaler._growth_tracker, growth_tracker)
 
     @unittest.skipIf((not TEST_GRAPH), "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
-    def test_graph_make_graphed_callables(self):
+    @parametrize('with_amp,cache_enabled', [(False, False), (True, False), (True, True)],
+                 name_fn=lambda x, y: '{}{}'.format({True: "with_amp", False: "without_amp"}[x],
+                                                    {True: "_cache_enabled", False: "_cache_disabled"}[y] if x else ''))
+    def test_graph_make_graphed_callables(self, with_amp, cache_enabled):
         torch.manual_seed(5)
         torch.cuda.manual_seed(5)
 
@@ -3730,9 +3734,10 @@ def test_graph_make_graphed_callables(self):
         relu_control = torch.nn.functional.relu
 
         # This is a good stress test. It graphs four callables: two Modules and two python functions.
-        model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \
-            torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control),
-                                              ((x,), (h,), (y_pred,), (y_pred, y)))
+        with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+            model_graphed[0], model_graphed[1], relu_graphed, loss_fn_graphed = \
+                torch.cuda.make_graphed_callables((model_graphed[0], model_graphed[1], relu_control, loss_fn_control),
+                                                  ((x,), (h,), (y_pred,), (y_pred, y)))
 
         real_inputs = [torch.rand_like(x) for _ in range(10)]
         real_targets = [torch.rand_like(y) for _ in range(10)]
@@ -3747,10 +3752,11 @@ def test_graph_make_graphed_callables(self):
             torch.cuda.manual_seed(5)
             for data, target in zip(real_inputs, real_targets):
                 opt.zero_grad(set_to_none=True)
-                y_pred = m(data)
-                y_pred = relu(y_pred)
-                loss = loss_fn(y_pred, target)
-                loss.backward()
+                with torch.cuda.amp.autocast(with_amp, cache_enabled=cache_enabled):
+                    y_pred = m(data)
+                    y_pred = relu(y_pred)
+                    loss = loss_fn(y_pred, target)
+                    loss.backward()
                 opt.step()
 
         for p, pc in zip(model_graphed.parameters(), model_control.parameters()):
@@ -4247,5 +4253,7 @@ class TestNamedTupleInput_1(NamedTuple):
             cat = torch.cat((outputs[0][i].to('cpu'), outputs[1][i].to('cpu')))
             self.assertTrue(torch.equal(x, cat))
 
+instantiate_parametrized_tests(TestCuda)
+
 if __name__ == '__main__':
     run_tests()
6 changes: 6 additions & 0 deletions torch/cuda/graphs.py
@@ -269,6 +269,9 @@ def make_graphed_callables(callables, sample_args):
     # the safest approach is to capture all passes in the same order they'll run:
     # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
 
+    # Clear AMP autocast cache before capturing the graphs
+    torch.clear_autocast_cache()
+
     # Capture forward graphs
     per_callable_static_outputs = []
     per_callable_output_was_tensor = []
@@ -328,6 +331,9 @@ def make_graphed_callables(callables, sample_args):
     per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs))
     # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable.
 
+    # Clear AMP autocast cache after both forward and backward graphs are captured
+    torch.clear_autocast_cache()
+
     def make_graphed_autograd_function(fwd_graph,
                                        bwd_graph,
                                        module_params,
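For context, a usage sketch of the pattern this change supports: capturing graphed callables while the autocast cache is enabled, with `make_graphed_callables` clearing the cache before and after capture. This is not part of the PR; the model, shapes, and loop count are assumed for illustration.

```python
# Minimal usage sketch (assumed model/shapes; requires a CUDA-graphs-capable GPU).
import torch

model = torch.nn.Linear(64, 64).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 64, device='cuda')
y = torch.randn(8, 64, device='cuda')

# Capture under autocast with the amp cache enabled. With this change,
# make_graphed_callables() clears the autocast cache before and after capture,
# so cached casts from one capture do not leak into (or out of) the graphs.
with torch.cuda.amp.autocast(cache_enabled=True):
    graphed_model = torch.cuda.make_graphed_callables(model, (x,))

# Replay the graphed forward/backward in a short training loop, also under autocast.
for _ in range(3):
    opt.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(cache_enabled=True):
        pred = graphed_model(torch.rand_like(x))
        loss = (pred - y).pow(2).mean()
        loss.backward()
    opt.step()
```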
