Only disable one level of inlining
ricardoV94 committed Sep 1, 2024
1 parent 10a841f commit b29be45
Showing 2 changed files with 11 additions and 22 deletions.
29 changes: 9 additions & 20 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -4,6 +4,7 @@
 import torch
 import torch.compiler
 
+from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
@@ -137,25 +138,13 @@ def makevector(*x):
 
 
 @pytorch_funcify.register(OpFromGraph)
-def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
-    _ = kwargs.pop("storage_map", None)
+def pytorch_funcify_OpFromGraph(op, node, **kwargs):
+    kwargs.pop("storage_map", None)
 
-    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
+    # Apply inner rewrites
+    PYTORCH.optimizer(op.fgraph)
 
-    if len(op.fgraph.outputs) > 1:
-
-        def inner(*args):
-            return fgraph_fn(*args)
-    else:
-
-        def inner(*args):
-            return fgraph_fn(*args)[0]
-
-    # Don't compile the inner function
-    # This is due torch failing to create
-    # guards when parent scoped closure variables
-    # are used in conditional statements.
-    # Instead of rewriting many portions of code
-    # this will allow for only this small section to
-    # not be compiled by the outer graph
-    return torch.compiler.disable(inner)
+    fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
+    # Disable one step inlining to prevent torch from trying to import local functions
+    # defined in `pytorch_funcify`
+    return torch.compiler.disable(fgraph_fn, recursive=False)
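
The key change is passing recursive=False to torch.compiler.disable: only the wrapper returned by pytorch_funcify is kept out of compilation, while the functions it calls can still be traced by torch.compile. A minimal standalone sketch of that behaviour (inner_op and wrapper are illustrative names, not PyTensor code):

import torch


def inner_op(x):
    # With a non-recursive disable on the caller, torch.compile is still
    # free to trace and optimize this callee.
    return torch.sin(x) + 1.0


def wrapper(x):
    # Only this frame is excluded from compilation.
    return inner_op(x) * 2.0


# Same call form as in the commit: disable a single level of inlining.
wrapper = torch.compiler.disable(wrapper, recursive=False)

compiled = torch.compile(wrapper)
print(compiled(torch.ones(3)))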
4 changes: 2 additions & 2 deletions tests/link/pytorch/test_basic.py
@@ -68,9 +68,9 @@ def compare_pytorch_and_py(
 
     if len(fgraph.outputs) > 1:
         for j, p in zip(pytorch_res, py_res):
-            assert_fn(j.cpu(), p)
+            assert_fn(j.detach().cpu().numpy(), p)
     else:
-        assert_fn([pytorch_res[0].cpu()], py_res)
+        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
 
     return pytensor_torch_fn, pytorch_res
 
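The updated assertions convert the PyTorch outputs with .detach().cpu().numpy() before comparing them to the Python-mode results, since a tensor that requires grad (or sits on a non-CPU device) cannot be handed to NumPy directly. A small sketch of that conversion, assuming a NumPy-based comparator such as np.testing.assert_allclose:

import numpy as np
import torch

# A tensor that requires grad must be detached (and moved to the CPU)
# before .numpy() will succeed.
torch_out = torch.ones(3, requires_grad=True) * 2.0
py_out = np.full(3, 2.0)

np.testing.assert_allclose(torch_out.detach().cpu().numpy(), py_out)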
