Skip to content

Commit

Permalink
Aot Autograd - Dont fallback with new op (#1493)
Browse files Browse the repository at this point in the history
* Aot Autograd - Dont fallback with new op

* Bring back moco - Moco fails with einsum float16 error

* CI
  • Loading branch information
anijain2305 authored Oct 6, 2022
1 parent 9b69076 commit 7aa0e60
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 24 deletions.
17 changes: 0 additions & 17 deletions test/test_aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,6 @@ def fn(param, y):
aot_fn(x, y)
self.assertTrue(not is_safe[0])

def test_new(self):
# https://github.com/pytorch/torchdynamo/issues/1448
def fn(argsort: torch.Tensor):
new = argsort.new(2, 12, 4096)
x = torch.add(new, 2)
return (
new,
x,
)

x = torch.randn((2, 12, 4096))
is_safe = [True]
compiler_fn = functools.partial(compiler_safe_fn, is_safe=is_safe)
aot_fn = torchdynamo.optimize(compiler_fn)(fn)
aot_fn(x)
self.assertTrue(not is_safe[0])

def test_negative_testing(self):
def fn(x, y):
return torch.sin(x).add_(y)
Expand Down
8 changes: 1 addition & 7 deletions torchdynamo/optimizations/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def is_aot_autograd_safe_to_run(gm, example_inputs):
Issues
1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
2) LSTM - https://github.com/pytorch/functorch/issues/586
3) New op - https://github.com/pytorch/torchdynamo/issues/1448
4) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
"""

def raise_or_warn(reason):
Expand All @@ -52,11 +51,6 @@ def raise_or_warn(reason):
if submod.__class__.__name__ == "LSTM":
return raise_or_warn("LSTM")

# 2) new does not work with fake tensor and aot autograd
for node in gm.graph.nodes:
if node.op == "call_method" and node.target == "new":
return raise_or_warn("new operator")

# 2) Mutation in the graph
mutated = False
try:
Expand Down

0 comments on commit 7aa0e60

Please sign in to comment.