Skip to content

Commit

Permalink
Disable the opfromgraph inner function from compiling
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Schweer committed Aug 18, 2024
1 parent d98e68e commit 10a841f
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from types import NoneType

import torch
import torch.compiler

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
Expand Down Expand Up @@ -139,14 +140,8 @@ def makevector(*x):
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)

# @todo: Torch compile doesn't capture the scope accounting
# for op.fgraph, leading to an import error. Disable the
# dynamo compile for these graphs
import torch._dynamo.config

torch._dynamo.config.suppress_errors = True

fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) > 1:

def inner(*args):
Expand All @@ -156,4 +151,11 @@ def inner(*args):
def inner(*args):
return fgraph_fn(*args)[0]

return inner
# Don't compile the inner function
# This is due torch failing to create
# guards when parent scoped closure variables
# are used in conditional statements.
# Instead of rewriting many portions of code
# this will allow for only this small section to
# not be compiled by the outer graph
return torch.compiler.disable(inner)

0 comments on commit 10a841f

Please sign in to comment.