diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 449b9ee779..6faf38ae0d 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -4,6 +4,7 @@ import torch import torch.compiler +from pytensor.compile import PYTORCH from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph @@ -137,25 +138,13 @@ def makevector(*x): @pytorch_funcify.register(OpFromGraph) -def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): - _ = kwargs.pop("storage_map", None) +def pytorch_funcify_OpFromGraph(op, node, **kwargs): + kwargs.pop("storage_map", None) - fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) + # Apply inner rewrites + PYTORCH.optimizer(op.fgraph) - if len(op.fgraph.outputs) > 1: - - def inner(*args): - return fgraph_fn(*args) - else: - - def inner(*args): - return fgraph_fn(*args)[0] - - # Don't compile the inner function - # This is due torch failing to create - # guards when parent scoped closure variables - # are used in conditional statements. - # Instead of rewriting many portions of code - # this will allow for only this small section to - # not be compiled by the outer graph - return torch.compiler.disable(inner) + fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) + # Disable one step inlining to prevent torch from trying to import local functions + # defined in `pytorch_funcify` + return torch.compiler.disable(fgraph_fn, recursive=False) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 4b0790a129..1945790fda 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -68,9 +68,9 @@ def compare_pytorch_and_py( if len(fgraph.outputs) > 1: for j, p in zip(pytorch_res, py_res): - assert_fn(j.cpu(), p) + assert_fn(j.detach().cpu().numpy(), p) else: - assert_fn([pytorch_res[0].cpu()], py_res) + assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0]) return pytensor_torch_fn, pytorch_res