Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable autocast #794

Merged
merged 4 commits into from
Jun 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions functorch/_src/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ class CompiledFunction(torch.autograd.Function):
@disable_torchdynamo
def forward(ctx, *flat_tensor_args):
nonlocal compiled_fw, compiled_bw, num_outs
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
if compiled_fw is None:
with preserve_rng_state():
# Set input tensors that require grad to leaves
Expand Down Expand Up @@ -194,15 +197,20 @@ def forward(ctx, *flat_tensor_args):
compiled_bw = bw_compiler(bw_module, bw_args)
else:
fw_outs = normalize_as_list(compiled_fw(*flat_tensor_args))
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
ctx.save_for_backward(*fw_outs[num_outs:])
return tuple(fw_outs[0:num_outs])

@staticmethod
@disable_torchdynamo
def backward(ctx, *flat_args):
# Disable the JIT Autocast flag to prevent re-autocasting of jitted graph.
# TODO - Remove when https://github.com/pytorch/functorch/pull/794 is fixed.
old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
contiguous_args = [t.contiguous() for t in flat_args]
# contiguous_args = [t for t in flat_args]
out = normalize_as_list(compiled_bw(*ctx.saved_tensors, *contiguous_args))
torch._C._jit_set_autocast_mode(old_jit_autocast_flag)
return tuple(out)

return CompiledFunction
Expand Down
18 changes: 17 additions & 1 deletion test/test_pythonkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functorch.compile import (
nnc_jit, compiled_function, compiled_module,
min_cut_rematerialization_partition, aot_function, aot_module, decomposition_table, nop,
num_of_recompilations, default_partition, default_decompositions
num_of_recompilations, default_partition, default_decompositions, memory_efficient_fusion,
)

from torch.testing._internal.common_device_type import ops
Expand Down Expand Up @@ -564,6 +564,22 @@ def fn(x):
assert torch.allclose(ref, res)


class TestAutocast(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_autocast(self):
mod = torchvision.models.resnet18().cuda()
mod.train()

x = torch.randn(16, 3, 32, 32, device="cuda")
aot_mod = memory_efficient_fusion(mod)

# Ensure that AOT Autograd works with AMP
with torch.cuda.amp.autocast(True):
res = aot_mod(x)
res.sum().backward()


only_for = ("cpu")
instantiate_device_type_tests(
TestPythonKey,
Expand Down