[functorch] Disable autocast (pytorch/functorch#794)
* Disable autocast

* Add global flag

* Add a test
anijain2305 authored and zou3519 committed Jul 20, 2022
1 parent d8c020d commit ee941bf
Showing 2 changed files with 25 additions and 1 deletion.
8 changes: 8 additions & 0 deletions functorch/functorch/_src/aot_autograd.py
@@ -160,6 +160,9 @@ class CompiledFunction(torch.autograd.Function):
     @disable_torchdynamo
     def forward(ctx, *flat_tensor_args):
         nonlocal compiled_fw, compiled_bw, num_outs
+        # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+        # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+        old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
         if compiled_fw is None:
             with preserve_rng_state():
                 # Set input tensors that require grad to leaves
@@ -194,15 +197,20 @@ def forward(ctx, *flat_tensor_args):
                     compiled_bw = bw_compiler(bw_module, bw_args)
             else:
                 fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
+        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
         ctx.save_for_backward(*fw_outs[num_outs:])
         return tuple(fw_outs[0:num_outs])
 
     @staticmethod
     @disable_torchdynamo
     def backward(ctx, *flat_args):
+        # Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
+        # TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
+        old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
         contiguous_args = [t.contiguous() for t in flat_args]
         # contiguous_args = [t for t in flat_args]
         out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
+        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
         return tuple(out)
 
 return CompiledFunction
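For reference, the save/restore pattern added above could also be factored into a small context manager so the flag is restored even if the compiled graph raises. This is only an illustrative sketch and is not part of the commit; the helper name is hypothetical, and it relies on the behavior shown above, where torch._C._jit_set_autocast_mode returns the previous flag value.

import contextlib

import torch


@contextlib.contextmanager
def _jit_autocast_disabled():
    # Hypothetical helper (not in this commit): save the global JIT autocast
    # flag, disable it around the compiled call, then restore it on exit.
    old_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_flag)


# Example use mirroring forward()/backward() above:
# with _jit_autocast_disabled():
#     fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))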
18 changes: 17 additions & 1 deletion functorch/test/test_pythonkey.py
@@ -21,7 +21,7 @@
 from functorch.compile import (
     nnc_jit, compiled_function, compiled_module,
     min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
-    num_of_recompilations, default_partition, default_decompositions
+    num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion,
 )
 
 from torch.testing._internal.common_device_type import ops
@@ -564,6 +564,22 @@ def fn(x):
         assert torch.allclose(ref, res)
 
 
+class TestAutocast(TestCase):
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
+    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
+    def test_autocast(self):
+        mod = torchvision.models.resnet18().cuda()
+        mod.train()
+
+        x = torch.randn(16, 3, 32, 32, device="cuda")
+        aot_mod = memory_efficient_fusion(mod)
+
+        # Ensure that AOT Autograd works with AMP
+        with torch.cuda.amp.autocast(True):
+            res = aot_mod(x)
+            res.sum().backward()
+
+
 only_for = ("cpu")
 instantiate_device_type_tests(
     TestPythonKey,
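As a rough usage sketch (not part of this commit), the AMP-plus-AOT-Autograd path that the new test exercises can be reproduced without torchvision, assuming a CUDA device is available and substituting a toy module for resnet18:

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

# Hypothetical stand-in for the resnet18 model used in the committed test.
mod = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1)).cuda()
mod.train()

aot_mod = memory_efficient_fusion(mod)
x = torch.randn(16, 32, device="cuda")

# Forward and backward under autocast, as in test_autocast above.
with torch.cuda.amp.autocast(True):
    out = aot_mod(x)
    out.sum().backward()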
