Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OpFromGraph in PyTorch backend #956

Merged
merged 8 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytensor/compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
OPT_O3,
OPT_STABILIZE,
OPT_UNSAFE,
PYTORCH,
twiecki marked this conversation as resolved.
Show resolved Hide resolved
AddDestroyHandler,
AddFeatureOptimizer,
Mode,
Expand Down
27 changes: 27 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from types import NoneType

import torch
import torch.compiler

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
Expand Down Expand Up @@ -132,3 +134,28 @@ def makevector(*x):
return torch.tensor(x, dtype=torch_dtype)

return makevector


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)

fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to compile the inner function? Is that a thing in PyTorch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was following what numba does where it jits the inner function - we could remove the inner torch.compile and just return op.fgraph if that seems more reasonable. That will still lead to some c-linker issues fwiw.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the inner function, you only need to do indexing if the number of return values is more than 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Numba can only have inner compiled functions, I don't know if that's a requirement in pytorch, and whether it has any advantages. We don't do it for JAX

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see / know of any requirement to have an inner compiled function.


if len(op.fgraph.outputs) > 1:

def inner(*args):
return fgraph_fn(*args)
else:

def inner(*args):
return fgraph_fn(*args)[0]

# Don't compile the inner function
# This is due torch failing to create
# guards when parent scoped closure variables
# are used in conditional statements.
# Instead of rewriting many portions of code
# this will allow for only this small section to
# not be compiled by the outer graph
return torch.compiler.disable(inner)
Copy link
Member

@ricardoV94 ricardoV94 Aug 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this because of the two inner functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has something bizarre to do with the combination of fgraph_fn being a bunch of nested functions, and this inner function being nested. The bigger part of that torch compiler isn't super great at handling conditionals user closure variables, at least in pytensor. It would probably need a much deeper dive. It looks like it might be something that can happen with other functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's worrisome. What error did you get without this disabling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have the error above in a comment, but it's essentially going to say the generated code from pytensor can't find some functions (all the inner functions returned in torch dispatch)

19 changes: 18 additions & 1 deletion tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
Expand All @@ -14,7 +15,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
from pytensor.tensor.type import matrix, scalar, vector
from pytensor.tensor.type import matrices, matrix, scalar, vector


torch = pytest.importorskip("torch")
Expand Down Expand Up @@ -301,3 +302,19 @@ def test_pytorch_MakeVector():
x_fg = FunctionGraph([], [x])

compare_pytorch_and_py(x_fg, [])


def test_pytorch_OpFromGraph():
x, y, z = matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y])
ofg_2 = OpFromGraph([x, y], [x * y, x - y])

o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2

xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5

f = FunctionGraph([x, y, z], [out])
compare_pytorch_and_py(f, [xv, yv, zv])
Loading