
[fake tensor] Fake tensor is leaked out when using `new` with AOT Autograd #1448

Closed
anijain2305 opened this issue Oct 1, 2022 · 1 comment

@anijain2305 (Contributor)

Repro

import torch
from torch.nn import *

import torchdynamo


class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, argsort: torch.Tensor):
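        # `.new` allocates an uninitialized tensor with argsort's dtype and device;
        # this is the call whose result this issue is about (see the fix referenced below)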
        new = argsort.new(2, 12, 4096)
        x = torch.add(new, 2)
        return (
            new,
            x,
        )


mod = Bar().to(device="cpu")

inp = torch.randn((2, 12, 4096), device="cpu")


def fn(x):
    y = mod(x)
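    # print() causes a TorchDynamo graph break, splitting fn into multiple compiled graphs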
    print("break", y[0].shape)
    return [torch.sin(l) for l in y]


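# compile fn with TorchDynamo using the "aot_eager" backend (AOT Autograd tracing, eager execution)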
opt_mod = torchdynamo.optimize("aot_eager")(fn)
opt_mod(inp)

Error

  File "/scratch/anijain/work/torchdynamo/torchdynamo/output_graph.py", line 387, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/scratch/anijain/work/torchdynamo/torchdynamo/output_graph.py", line 420, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torchdynamo.exc.BackendCompilerFailed: compile_fn raised Exception: Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. Please convert all Tensors to FakeTensors first. Found in aten.copy_.default
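
For context (an editorial sketch, not part of the original report): FakeTensorMode refuses to run operators whose inputs mix fake and real tensors, and that is the check firing here during AOT Autograd tracing of aten.copy_.default. A minimal reproduction of that check alone, using the public torch._subclasses.fake_tensor.FakeTensorMode API:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
real = torch.randn(2, 12, 4096)        # plain eager tensor
fake = fake_mode.from_tensor(real)     # fake counterpart, as used during tracing

with fake_mode:
    torch.add(fake, fake)              # fine: every input is a FakeTensor
    torch.add(fake, real)              # raises the "non-Fake Tensor inputs" error above
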
@anijain2305 (Contributor, Author)

Related - #1447

eellison added a commit to pytorch/pytorch that referenced this issue Oct 3, 2022

Delete the special-cased handling of `new` in FakeTensor. Ever since the dispatch keys were updated to reflect the FakeTensor's device, the special-cased handling has not been needed.

Fixes pytorch/torchdynamo#1448

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this issue Oct 4, 2022

Delete the special-cased handling of `new` in FakeTensor. Ever since the dispatch keys were updated to reflect the FakeTensor's device, the special-cased handling has not been needed.

Fixes pytorch/torchdynamo#1448

Pull Request resolved: #86144
Approved by: https://github.com/ezyang
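
For background on the fix (an editorial sketch, not code from the PR): `Tensor.new(sizes)` allocates an uninitialized tensor that inherits the dtype and device of the source tensor. Once a FakeTensor's dispatch keys report its correct device, that default behavior already does the right thing under fake tracing, so the special-cased handling of `new` could be deleted. The plain eager semantics:

import torch

src = torch.randn(4, dtype=torch.float32, device="cpu")
fresh = src.new(2, 3)                  # uninitialized data, inherits dtype/device from src
assert fresh.dtype == src.dtype
assert fresh.device == src.device
assert fresh.shape == torch.Size([2, 3])
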
@eellison closed this as completed Oct 4, 2022