diff --git a/functorch/_src/aot_autograd.py b/functorch/_src/aot_autograd.py
index 223df6a66..d0da1f24b 100644
--- a/functorch/_src/aot_autograd.py
+++ b/functorch/_src/aot_autograd.py
@@ -160,6 +160,9 @@ class CompiledFunction(torch.autograd.Function):
         @disable_torchdynamo
         def forward(ctx, *flat_tensor_args):
             nonlocal compiled_fw, compiled_bw, num_outs
+            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
             if compiled_fw is None:
                 with preserve_rng_state():
                     # Set input tensors that require grad to leaves
@@ -194,15 +197,20 @@ def forward(ctx, *flat_tensor_args):
                     compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
             ctx.save_for_backward(*fw_outs[num_outs:])
             return tuple(fw_outs[0:num_outs])
 
         @staticmethod
         @disable_torchdynamo
         def backward(ctx, *flat_args):
+            # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+            # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+            old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
             contiguous_args = [t.contiguous() for t in flat_args]
             # contiguous_args = [t for t in flat_args]
             out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+            torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
             return tuple(out)
 
     return CompiledFunction
diff --git a/test/test_pythonkey.py b/test/test_pythonkey.py
index faa836d48..16b3add73 100644
--- a/test/test_pythonkey.py
+++ b/test/test_pythonkey.py
@@ -21,7 +21,7 @@ from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module,
     decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions
+    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion,
 )
 
 from torch.testing._internal.common_device_type import ops
@@ -564,6 +564,22 @@ def fn(x):
     assert torch.allclose(ref, res)
 
 
+class TestAutocast(TestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_autocast(self):
+        mod = torchvision.models.resnet18().cuda()
+        mod.train()
+
+        x = torch.randn(16, 3, 32, 32, device="cuda")
+        aot_mod = memory_efficient_fusion(mod)
+
+        # Ensure that AOT Autograd works with AMP
+        with torch.cuda.amp.autocast(True):
+            res = aot_mod(x)
+            res.sum().backward()
+
+
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,