Skip to content

Commit

Permalink
Implement OpFromGraph in PyTorch backend (#956)
Browse files Browse the repository at this point in the history
Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
  • Loading branch information
3 people authored Sep 17, 2024
1 parent 3e55a20 commit ba4fcbe
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
1 change: 1 addition & 0 deletions pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OPT_O3,
OPT_STABILIZE,
OPT_UNSAFE,
PYTORCH,
AddDestroyHandler,
AddFeatureOptimizer,
Mode,
Expand Down
16 changes: 16 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import numpy as np
import torch
import torch.compiler

from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
Expand Down Expand Up @@ -150,6 +153,19 @@ def makevector(*x):
return makevector


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)

# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)

fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
# Disable one step inlining to prevent torch from trying to import local functions
# defined in `pytorch_funcify`
return torch.compiler.disable(fgraph_fn, recursive=False)


@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
def tensorfromscalar(x):
Expand Down
19 changes: 18 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
Expand All @@ -14,7 +15,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
from pytensor.tensor.type import matrices, matrix, scalar, vector


torch = pytest.importorskip("torch")
Expand Down Expand Up @@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])

0 comments on commit ba4fcbe

Please sign in to comment.