Implement OpFromGraph in PyTorch backend #956

Merged
merged 8 commits into from
Sep 17, 2024
Changes from 1 commit
1 change: 1 addition & 0 deletions pytensor/compile/__init__.py
@@ -30,6 +30,7 @@
OPT_O3,
OPT_STABILIZE,
OPT_UNSAFE,
PYTORCH,
twiecki marked this conversation as resolved.
AddDestroyHandler,
AddFeatureOptimizer,
Mode,
16 changes: 16 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -3,6 +3,8 @@

import torch

from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
@@ -132,3 +134,17 @@ def makevector(*x):
        return torch.tensor(x, dtype=torch_dtype)

    return makevector


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
    _ = kwargs.pop("storage_map", None)

    PYTORCH.optimizer(op.fgraph)
    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
Member

Do you need to compile the inner function? Is that a thing in PyTorch?

Contributor Author

I was following what Numba does, where it jits the inner function. We could remove the inner torch.compile and just return op.fgraph if that seems more reasonable. That will still lead to some C linker issues, fwiw.

Contributor Author

I removed the inner function. You only need to do indexing if the number of return values is more than 1.

Member

Numba can only have inner compiled functions. I don't know if that's a requirement in PyTorch, or whether it has any advantages. We don't do it for JAX.

Contributor Author

I do not see / know of any requirement to have an inner compiled function.
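
One reading of the earlier point about indexing only when there is more than one return value can be sketched in plain Python. This is an illustrative sketch, not the PR's code: `unwrap_single_output` and the lambdas are hypothetical names, plain numbers stand in for tensors, and the inner function is assumed to always return a tuple of outputs.

```python
def unwrap_single_output(inner_fn):
    """Wrap inner_fn so that a one-element result is returned bare.

    Hypothetical helper for illustration only; inner_fn is assumed to
    always return a tuple or list of outputs.
    """

    def wrapper(*inputs):
        res = inner_fn(*inputs)
        # A single output is handed back directly; several outputs stay
        # grouped so the caller can index or unpack them.
        if len(res) == 1:
            return res[0]
        return tuple(res)

    return wrapper


# Stand-ins for inner graphs with one and two outputs.
add = unwrap_single_output(lambda x, y: (x + y,))
mul_sub = unwrap_single_output(lambda x, y: (x * y, x - y))

single = add(1, 5)        # bare value, no indexing needed by the caller
multiple = mul_sub(3, 5)  # tuple, the caller indexes or unpacks
```

With this shape, only callers of the multi-output case ever index into the result.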


    def opfromgraph(*inputs, dim=op.fgraph.outputs):
Ch0ronomato marked this conversation as resolved.
        res = fgraph_fn(*inputs)
Contributor Author

The code that is generated for fgraph_fn doesn't have enough of the locals / globals atm to actually execute the function

torch._dynamo.exc.InternalTorchDynamoError: module 'pytensor.link.utils' has no attribute 'elemwise_fn'

I'm working on fixing it

Contributor Author
@Ch0ronomato, Jul 29, 2024

Hey @ricardoV94, I spent some time debugging this but I'm not able to get anywhere. Some things I've tried:

  1. I tried setting the local_env and global_env passed to compile_to_src to have those functions included in the callable that gets compiled by torch, but that didn't seem to help
  2. I enabled eager mode execution, and it threw an error from the pytensor C linker - should that be happening? I also disabled graph optimizations and completely removed torch.compile (in the PyTorch JITLinker)

Could you clarify for me if the c linker should be running, and if not what stuff I might be missing? Otherwise, if you happen to have suggestions on what to check lmk.
Also I tried joining the gitter, but the link seems broken 😓
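
The locals / globals problem described above can be illustrated without PyTensor or PyTorch. This is a minimal sketch under stated assumptions, not a diagnosis of the Dynamo error: `compile_from_source` is a hypothetical stand-in (not PyTensor's `compile_to_src`), and `helper` plays the role of a function such as `elemwise_fn` that exists only in the environment the generated code was executed in.

```python
import types

# Source for a generated function that refers to a helper by name.
SRC = """
def generated(x):
    return helper(x) + 1
"""


def compile_from_source(src, name, global_env):
    """Exec src in global_env and return the function called name.

    Hypothetical helper for illustration only.
    """
    local_env = {}
    exec(compile(src, "<generated>", "exec"), global_env, local_env)
    return local_env[name]


def helper(x):
    return x * 2


# The helper is supplied only through the globals dict given to exec.
fn = compile_from_source(SRC, "generated", {"helper": helper})
result = fn(3)

# The function finds the helper through its own __globals__ ...
in_function_globals = "helper" in fn.__globals__

# ... but the helper is not an attribute of any module object, so a tool
# that resolves names via module attributes would not find it there.
some_module = types.ModuleType("some_module")
on_module = hasattr(some_module, "helper")
```

Calling the function directly works because name lookup goes through `fn.__globals__`; anything that instead expects the name on a module has nothing to find.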

Member

Sorry, we don't have a gitter. Where did you find the link so we can remove it?

Member

No, C code shouldn't be involved, although I guess it could be called during optimization of the inner graph.

Contributor Author

got it - for the gitter it's under the community tag here https://pytensor.readthedocs.io/en/latest/index.html#pytensor-community.

Here is the C error in question, if it helps. I did try running these with PYTENSOR_FLAGS="optimizer=None" and with fast-compile, but I was still seeing the error:

>           raise CompileError(
                f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
E               pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1):
E               /opt/anaconda3/envs/pytensor-dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/numpy/core/include -I/opt/anaconda3/envs/pytensor-dev/include/python3.10 -I/Users/ischweer/dev/pytensor/pytensor/link/c/c_code -L/opt/anaconda3/envs/pytensor-dev/lib -fvisibility=hidden -o /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/medd18106a84c2e4f7b8beb5e52d5bb1d3f6f08a810be55e8472522fafe746dd7.so /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E                       V5_stride0, V5_stride1, 
E                       ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: note: insert an explicit cast to silence this issue
E                       V5_stride0, V5_stride1, 
E                       ^~~~~~~~~~
E                       static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E                       V5_stride0, V5_stride1, 
E                                   ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: note: insert an explicit cast to silence this issue
E                       V5_stride0, V5_stride1, 
E                                   ^~~~~~~~~~
E                                   static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V3_stride0, V3_stride1, 
E               ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: note: insert an explicit cast to silence this issue
E               V3_stride0, V3_stride1, 
E               ^~~~~~~~~~
E               static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V3_stride0, V3_stride1, 
E                           ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: note: insert an explicit cast to silence this issue
E               V3_stride0, V3_stride1, 
E                           ^~~~~~~~~~
E                           static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V1_stride0, V1_stride1
E               ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: note: insert an explicit cast to silence this issue
E               V1_stride0, V1_stride1
E               ^~~~~~~~~~
E               static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V1_stride0, V1_stride1
E                           ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: note: insert an explicit cast to silence this issue
E               V1_stride0, V1_stride1
E                           ^~~~~~~~~~
E                           static_cast<int>( )
E               6 errors generated.
E               
E               Apply node that caused the error: Add(*0-<Matrix(float64, shape=(?, ?))>, *1-<Matrix(float64, shape=(?, ?))>)
E               Toposort index: 0
E               Inputs types: [TensorType(float64, shape=(None, None)), TensorType(float64, shape=(None, None))]
E               
E               Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/runner.py", line 173, in pytest_runtest_call
E                   item.runtest()
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 1632, in runtest
E                   self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 513, in __call__
E                   return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_manager.py", line 120, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 162, in pytest_pyfunc_call
E                   result = testfunction(**testargs)
E                 File "/Users/ischweer/dev/pytensor/tests/link/pytorch/test_basic.py", line 309, in test_pytorch_OpFromGraph
E                   ofg_1 = OpFromGraph([x, y], [x + y])
E               
E               HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
E               HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
E               Apply node that caused the error: OpFromGraph{inline=False}(y, z)

Contributor Author

One thing I'm not sure about is whether those are just M1-setup-specific issues I'm having (some setups could be more permissive about warnings than others). If that's the case, then I think eager execution is probably fine.

Contributor Author
@Ch0ronomato, Jul 30, 2024

I've confirmed the above ^. If you set gcc__cxxflags=-Wno-c++11-narrowing and enable eager mode execution, you get this to work.

PYTENSOR_FLAGS="gcc__cxxflags=-Wno-c++11-narrowing" TORCHDYNAMO_SUPPRESS_ERRORS="True" pytest -k OpFromGraph tests/link/pytorch/test_basic.py
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/ischweer/dev/pytensor
configfile: pyproject.toml
plugins: cov-5.0.0, sphinx-0.6.3, mock-3.14.0, benchmark-4.0.0, xdist-3.6.1
collected 14 items / 13 deselected / 1 selected                                                                                                                    

tests/link/pytorch/test_basic.py .                                                                                                                           [100%]

========================================================================= warnings summary =========================================================================
pytensor/configdefaults.py:375
  /Users/ischweer/dev/pytensor/pytensor/configdefaults.py:375: DeprecationWarning: Use shutil.which instead of find_executable
    newp = find_executable(param)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================== 1 passed, 13 deselected, 1 warning in 4.64s ============================================================
(pytensor-dev) ➜  pytensor git:(makeopfromgraph) ✗ 

I'm not sure how to set that env variable only on macOS. Maybe I can update a doc or smth, or even add it to the new M1 environment YAML? Otherwise, this all works now.
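
One possible way to apply the flag only on macOS is to set `PYTENSOR_FLAGS` from Python before PyTensor is imported, for example at the top of a `conftest.py`. This is a sketch, not what the PR does: `pytensor_flags_for` is a hypothetical helper, and it assumes the flags are read from the environment when `pytensor` is first imported and are written as comma-separated `key=value` pairs.

```python
import os
import sys

# Flag taken from the comment above.
NARROWING_FLAG = "gcc__cxxflags=-Wno-c++11-narrowing"


def pytensor_flags_for(platform, existing=""):
    """Return a PYTENSOR_FLAGS value, adding the narrowing flag on macOS only.

    Hypothetical helper. It does not merge with an existing gcc__cxxflags
    entry; that case would need extra handling.
    """
    if platform != "darwin":
        return existing
    if NARROWING_FLAG in existing:
        return existing
    return f"{existing},{NARROWING_FLAG}" if existing else NARROWING_FLAG


# Assumed to run before `import pytensor` anywhere in the process.
flags = pytensor_flags_for(sys.platform, os.environ.get("PYTENSOR_FLAGS", ""))
if flags:
    os.environ["PYTENSOR_FLAGS"] = flags
```

Keeping the platform as a parameter makes the helper easy to check on any machine, not just on an M1.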

Contributor Author

I marked the import issue as a todo in the code, and was gonna open an issue. Is that okay @ricardoV94 ?

        return res[0]

    return opfromgraph
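
The registration pattern in this hunk can be sketched with the standard library alone. The sketch assumes `pytorch_funcify` follows `functools.singledispatch`-style registration, as the `.register(OpFromGraph)` decorator suggests. Every class and function name below is a hypothetical stand-in, and plain numbers replace tensors.

```python
from functools import singledispatch


class Op:
    """Stand-in base class for graph operations."""


class AddOp(Op):
    """Stand-in for an elementwise addition op."""


class WrappedGraphOp(Op):
    """Stand-in for an op wrapping an inner graph (here, a list of ops)."""

    def __init__(self, inner_ops):
        self.inner_ops = inner_ops


@singledispatch
def funcify(op, **kwargs):
    # Fallback for op types with no registered conversion.
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@funcify.register(AddOp)
def funcify_add(op, **kwargs):
    def add(x, y):
        return x + y

    return add


@funcify.register(WrappedGraphOp)
def funcify_wrapped(op, **kwargs):
    # Drop a keyword the inner conversion should not see, as the diff does.
    kwargs.pop("storage_map", None)
    # Convert each inner op by dispatching on its type again.
    inner_fns = [funcify(inner, **kwargs) for inner in op.inner_ops]

    def wrapped(x, y):
        return [fn(x, y) for fn in inner_fns]

    return wrapped


fn = funcify(WrappedGraphOp([AddOp()]), storage_map={})
outputs = fn(1, 5)
```

The wrapping op never needs to know how its inner ops are converted; it only re-enters the dispatcher.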
25 changes: 24 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -5,6 +5,7 @@
import pytest

import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
@@ -14,7 +15,7 @@
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import alloc, arange, as_tensor, empty, eye
-from pytensor.tensor.type import matrix, scalar, vector
+from pytensor.tensor.type import matrices, matrix, scalar, vector


torch = pytest.importorskip("torch")
@@ -301,3 +302,25 @@ def test_pytorch_MakeVector():
    x_fg = FunctionGraph([], [x])

    compare_pytorch_and_py(x_fg, [])


def test_pytorch_OpFromGraph():
    x, y, z = matrices("xyz")
    ofg_1 = OpFromGraph([x, y], [x + y])
    OpFromGraph([x, y], [x * y, x - y])

    # o1, o2 = ofg_2(y, z)
    # out = ofg_1(x, o1) + o2
Ch0ronomato marked this conversation as resolved.

    out = ofg_1(y, z)

    xv = np.ones((2, 2), dtype=config.floatX)
    np.ones((2, 2), dtype=config.floatX) * 3
    zv = np.ones((2, 2), dtype=config.floatX) * 5

    f = FunctionGraph([y, z], [out])
    import pytensor.printing

    pytensor.printing.debugprint(f)
Ch0ronomato marked this conversation as resolved.

    compare_pytorch_and_py(f, [xv, zv])
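
To make the expected values easier to follow, the test's arithmetic can be walked through with plain Python scalars in place of the 2x2 matrices. This is a sketch only: the two functions are stand-ins for the `OpFromGraph` instances, and `yv = 3.0` is an assumption inferred from the unassigned `np.ones(...) * 3` line.

```python
def ofg_1(a, b):
    # Stand-in for OpFromGraph([x, y], [x + y])
    return a + b


def ofg_2(a, b):
    # Stand-in for OpFromGraph([x, y], [x * y, x - y])
    return a * b, a - b


# Scalar versions of the test inputs (yv is an assumed value).
xv, yv, zv = 1.0, 3.0, 5.0

# Graph the test currently builds: out = ofg_1(y, z).
# The test passes [xv, zv] for the inputs [y, z], so y gets 1 and z gets 5.
active_out = ofg_1(xv, zv)

# Commented-out composition: o1, o2 = ofg_2(y, z); out = ofg_1(x, o1) + o2
o1, o2 = ofg_2(yv, zv)
nested_out = ofg_1(xv, o1) + o2
```

The nested case is the one that exercises a multi-output `OpFromGraph` feeding another one, which the active test does not cover.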