diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index 04eba83290..9bd140d746 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -30,6 +30,7 @@ OPT_O3, OPT_STABILIZE, OPT_UNSAFE, + PYTORCH, AddDestroyHandler, AddFeatureOptimizer, Mode, diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 2cbb3631a9..e2edcf0fe4 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -3,7 +3,10 @@ import numpy as np import torch +import torch.compiler +from pytensor.compile import PYTORCH +from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python @@ -150,6 +153,19 @@ def makevector(*x): return makevector +@pytorch_funcify.register(OpFromGraph) +def pytorch_funcify_OpFromGraph(op, node, **kwargs): + kwargs.pop("storage_map", None) + + # Apply inner rewrites + PYTORCH.optimizer(op.fgraph) + + fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) + # Disable one step inlining to prevent torch from trying to import local functions + # defined in `pytorch_funcify` + return torch.compiler.disable(fgraph_fn, recursive=False) + + @pytorch_funcify.register(TensorFromScalar) def pytorch_funcify_TensorFromScalar(op, **kwargs): def tensorfromscalar(x): diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 89e6d8553d..1be74faf17 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -5,6 +5,7 @@ import pytest import pytensor.tensor.basic as ptb +from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared @@ -14,7 +15,7 @@ from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise from pytensor.tensor import alloc, arange, as_tensor, empty, eye -from pytensor.tensor.type import matrix, scalar, vector +from pytensor.tensor.type import matrices, matrix, scalar, vector torch = pytest.importorskip("torch") @@ -301,3 +302,19 @@ def test_pytorch_MakeVector(): x_fg = FunctionGraph([], [x]) compare_pytorch_and_py(x_fg, []) + + +def test_pytorch_OpFromGraph(): + x, y, z = matrices("xyz") + ofg_1 = OpFromGraph([x, y], [x + y]) + ofg_2 = OpFromGraph([x, y], [x * y, x - y]) + + o1, o2 = ofg_2(y, z) + out = ofg_1(x, o1) + o2 + + xv = np.ones((2, 2), dtype=config.floatX) + yv = np.ones((2, 2), dtype=config.floatX) * 3 + zv = np.ones((2, 2), dtype=config.floatX) * 5 + + f = FunctionGraph([x, y, z], [out]) + compare_pytorch_and_py(f, [xv, yv, zv])