Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpFromGraph in PyTorch backend #956

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OPT_O3,
OPT_STABILIZE,
OPT_UNSAFE,
PYTORCH,
twiecki marked this conversation as resolved.
Show resolved Hide resolved
AddDestroyHandler,
AddFeatureOptimizer,
Mode,
Expand Down
16 changes: 16 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import numpy as np
import torch
import torch.compiler

from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
Expand Down Expand Up @@ -150,6 +153,19 @@ def makevector(*x):
return makevector


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)

# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)

fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)


@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x):
Expand Down
19 changes: 18 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
Expand All @@ -14,7 +15,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
from pytensor.tensor.type import matrices, matrix, scalar, vector


torch = pytest.importorskip("torch")
Expand Down Expand Up @@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])
Loading