From c4b20eceef3ed174dba97934fbb685f0665fabe9 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Fri, 26 Jul 2024 20:59:57 -0700 Subject: [PATCH 1/7] Basic support for makeop --- pytensor/compile/__init__.py | 1 + pytensor/link/pytorch/dispatch/basic.py | 16 ++++++++++++++++ tests/link/pytorch/test_basic.py | 25 ++++++++++++++++++++++++- 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytensor/compile/__init__.py b/pytensor/compile/__init__.py index 04eba83290..9bd140d746 100644 --- a/pytensor/compile/__init__.py +++ b/pytensor/compile/__init__.py @@ -30,6 +30,7 @@ OPT_O3, OPT_STABILIZE, OPT_UNSAFE, + PYTORCH, AddDestroyHandler, AddFeatureOptimizer, Mode, diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c71e1606bf..3ebe5e60e4 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -3,6 +3,8 @@ import torch +from pytensor.compile import PYTORCH +from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph from pytensor.link.utils import fgraph_to_python @@ -132,3 +134,17 @@ def makevector(*x): return torch.tensor(x, dtype=torch_dtype) return makevector + + +@pytorch_funcify.register(OpFromGraph) +def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): + _ = kwargs.pop("storage_map", None) + + PYTORCH.optimizer(op.fgraph) + fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) + + def opfromgraph(*inputs, dim=op.fgraph.outputs): + res = fgraph_fn(*inputs) + return res[0] + + return opfromgraph diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 27c1b1bd6a..571cd640fb 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -5,6 +5,7 @@ import pytest import pytensor.tensor.basic as ptb +from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function from pytensor.compile.mode import get_mode from pytensor.compile.sharedvalue import SharedVariable, shared @@ -14,7 +15,7 @@ from pytensor.graph.op import Op from pytensor.raise_op import CheckAndRaise from pytensor.tensor import alloc, arange, as_tensor, empty, eye -from pytensor.tensor.type import matrix, scalar, vector +from pytensor.tensor.type import matrices, matrix, scalar, vector torch = pytest.importorskip("torch") @@ -301,3 +302,25 @@ def test_pytorch_MakeVector(): x_fg = FunctionGraph([], [x]) compare_pytorch_and_py(x_fg, []) + + +def test_pytorch_OpFromGraph(): + x, y, z = matrices("xyz") + ofg_1 = OpFromGraph([x, y], [x + y]) + OpFromGraph([x, y], [x * y, x - y]) + + # o1, o2 = ofg_2(y, z) + # out = ofg_1(x, o1) + o2 + + out = ofg_1(y, z) + + xv = np.ones((2, 2), dtype=config.floatX) + np.ones((2, 2), dtype=config.floatX) * 3 + zv = np.ones((2, 2), dtype=config.floatX) * 5 + + f = FunctionGraph([y, z], [out]) + import pytensor.printing + + pytensor.printing.debugprint(f) + + compare_pytorch_and_py(f, [xv, zv]) From 9c64320ed596195cb496383fd7f738bc04bcef0f Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Tue, 30 Jul 2024 13:32:44 -0700 Subject: [PATCH 2/7] Clean up num args based on graph --- pytensor/link/pytorch/dispatch/basic.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 3ebe5e60e4..637fbea4c8 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,9 +1,9 @@ from functools import singledispatch +from operator import itemgetter from types import NoneType import torch -from pytensor.compile import PYTORCH from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph @@ -140,11 +140,9 @@ def makevector(*x): def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) - PYTORCH.optimizer(op.fgraph) fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) - - def opfromgraph(*inputs, dim=op.fgraph.outputs): - res = fgraph_fn(*inputs) - return res[0] - - return opfromgraph + return ( + fgraph_fn + if len(op.fgraph.outputs) > 1 + else lambda *args: itemgetter(0)(fgraph_fn(*args)) + ) From fdd5d5c5ad027d9606bb83a9850cd98a76666739 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Tue, 6 Aug 2024 12:42:16 -0700 Subject: [PATCH 3/7] Disable torch dynamo --- pytensor/link/pytorch/dispatch/basic.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 637fbea4c8..285defc718 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -1,5 +1,4 @@ from functools import singledispatch -from operator import itemgetter from types import NoneType import torch @@ -140,9 +139,21 @@ def makevector(*x): def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) + # @todo: Torch compile doesn't capture the scope accounting + # for op.fgraph, leading to an import error. Disable the + # dynamo compile for these graphs + import torch._dynamo.config + + torch._dynamo.config.suppress_errors = True + fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) - return ( - fgraph_fn - if len(op.fgraph.outputs) > 1 - else lambda *args: itemgetter(0)(fgraph_fn(*args)) - ) + if len(op.fgraph.outputs) > 1: + + def inner(*args): + return fgraph_fn(*args) + else: + + def inner(*args): + return fgraph_fn(*args)[0] + + return inner From d98e68e4da07b5e0bbcff4866a3cca8dec82b5d1 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Tue, 6 Aug 2024 12:49:23 -0700 Subject: [PATCH 4/7] Update tests --- tests/link/pytorch/test_basic.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 571cd640fb..4b0790a129 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -307,20 +307,14 @@ def test_pytorch_MakeVector(): def test_pytorch_OpFromGraph(): x, y, z = matrices("xyz") ofg_1 = OpFromGraph([x, y], [x + y]) - OpFromGraph([x, y], [x * y, x - y]) + ofg_2 = OpFromGraph([x, y], [x * y, x - y]) - # o1, o2 = ofg_2(y, z) - # out = ofg_1(x, o1) + o2 - - out = ofg_1(y, z) + o1, o2 = ofg_2(y, z) + out = ofg_1(x, o1) + o2 xv = np.ones((2, 2), dtype=config.floatX) - np.ones((2, 2), dtype=config.floatX) * 3 + yv = np.ones((2, 2), dtype=config.floatX) * 3 zv = np.ones((2, 2), dtype=config.floatX) * 5 - f = FunctionGraph([y, z], [out]) - import pytensor.printing - - pytensor.printing.debugprint(f) - - compare_pytorch_and_py(f, [xv, zv]) + f = FunctionGraph([x, y, z], [out]) + compare_pytorch_and_py(f, [xv, yv, zv]) From 10a841f0e093995a9438ab4696a808d26bba2a80 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 18 Aug 2024 09:10:40 -0700 Subject: [PATCH 5/7] Disable the opfromgraph inner function from compiling --- pytensor/link/pytorch/dispatch/basic.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 285defc718..449b9ee779 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -2,6 +2,7 @@ from types import NoneType import torch +import torch.compiler from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp @@ -139,14 +140,8 @@ def makevector(*x): def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) - # @todo: Torch compile doesn't capture the scope accounting - # for op.fgraph, leading to an import error. Disable the - # dynamo compile for these graphs - import torch._dynamo.config - - torch._dynamo.config.suppress_errors = True - fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) + if len(op.fgraph.outputs) > 1: def inner(*args): @@ -156,4 +151,11 @@ def inner(*args): def inner(*args): return fgraph_fn(*args)[0] - return inner + # Don't compile the inner function + # This is due torch failing to create + # guards when parent scoped closure variables + # are used in conditional statements. + # Instead of rewriting many portions of code + # this will allow for only this small section to + # not be compiled by the outer graph + return torch.compiler.disable(inner) From b29be458b099cdea0eedbe65418b060261c6c965 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 1 Sep 2024 17:50:04 +0200 Subject: [PATCH 6/7] Only disable one level of inlining --- pytensor/link/pytorch/dispatch/basic.py | 29 ++++++++----------------- tests/link/pytorch/test_basic.py | 4 ++-- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 449b9ee779..6faf38ae0d 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -4,6 +4,7 @@ import torch import torch.compiler +from pytensor.compile import PYTORCH from pytensor.compile.builders import OpFromGraph from pytensor.compile.ops import DeepCopyOp from pytensor.graph.fg import FunctionGraph @@ -137,25 +138,13 @@ def makevector(*x): @pytorch_funcify.register(OpFromGraph) -def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): - _ = kwargs.pop("storage_map", None) +def pytorch_funcify_OpFromGraph(op, node, **kwargs): + kwargs.pop("storage_map", None) - fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) + # Apply inner rewrites + PYTORCH.optimizer(op.fgraph) - if len(op.fgraph.outputs) > 1: - - def inner(*args): - return fgraph_fn(*args) - else: - - def inner(*args): - return fgraph_fn(*args)[0] - - # Don't compile the inner function - # This is due torch failing to create - # guards when parent scoped closure variables - # are used in conditional statements. - # Instead of rewriting many portions of code - # this will allow for only this small section to - # not be compiled by the outer graph - return torch.compiler.disable(inner) + fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) + # Disable one step inlining to prevent torch from trying to import local functions + # defined in `pytorch_funcify` + return torch.compiler.disable(fgraph_fn, recursive=False) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 4b0790a129..1945790fda 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -68,9 +68,9 @@ def compare_pytorch_and_py( if len(fgraph.outputs) > 1: for j, p in zip(pytorch_res, py_res): - assert_fn(j.cpu(), p) + assert_fn(j.detach().cpu().numpy(), p) else: - assert_fn([pytorch_res[0].cpu()], py_res) + assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0]) return pytensor_torch_fn, pytorch_res From 0f18d8d741c07d26167837e9b3060ca4ce1a6fe1 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Tue, 17 Sep 2024 09:50:17 -0700 Subject: [PATCH 7/7] Lint --- pytensor/link/pytorch/dispatch/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 7e2a834916..e2edcf0fe4 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -165,7 +165,7 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs): # defined in `pytorch_funcify` return torch.compiler.disable(fgraph_fn, recursive=False) - + @pytorch_funcify.register(TensorFromScalar) def pytorch_funcify_TensorFromScalar(op, **kwargs): def tensorfromscalar(x):