Implement OpFromGraph in PyTorch backend #956
Conversation
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))

def opfromgraph(*inputs, dim=op.fgraph.outputs):
    res = fgraph_fn(*inputs)
The code that is generated for fgraph_fn doesn't have enough of the locals / globals at the moment to actually execute the function:
torch._dynamo.exc.InternalTorchDynamoError: module 'pytensor.link.utils' has no attribute 'elemwise_fn'
I'm working on fixing it.
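For context, here is a minimal, standalone illustration of that failure mode; this is not PyTensor code, and the source string and helper name are made up. A function compiled from generated source can only resolve names that were placed in the globals it was executed with:

```python
# Minimal standalone illustration (hypothetical names): a function compiled
# from a source string can only resolve names that were put into the
# globals dict it was exec'd with.
src = """
def fgraph_fn(x):
    return elemwise_fn(x)  # helper the generated code expects to find
"""

global_env = {}  # 'elemwise_fn' was never added here
exec(compile(src, "<string>", "exec"), global_env)

try:
    global_env["fgraph_fn"](1.0)
except NameError as err:
    print(err)  # name 'elemwise_fn' is not defined
```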
Hey @ricardoV94, I spent some time debugging this but I'm not able to get anywhere. Some things I've tried:
- I tried setting the local_env and global_env passed to compile_to_src to have those functions included in the callable that gets compiled by torch, but that didn't seem to help.
- I enabled eager mode execution, and it threw an error from the pytensor C linker. Should that be happening? I disabled graph optimizations as well, and completely removed torch.compile (in the PyTorch JITLinker).

Could you clarify whether the C linker should be running, and if not, what I might be missing? Otherwise, if you happen to have suggestions on what to check, let me know.
Also I tried joining the gitter, but the link seems broken 😓
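For anyone trying to reproduce this, a hedged sketch of a minimal repro in the spirit of the test shown in the traceback below; the mode name "PYTORCH" and the exact variable setup are assumptions on my part, not taken from this PR:

```python
# Hedged repro sketch (mode name and setup are assumptions).
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x, y = pt.matrices("x", "y")
ofg = OpFromGraph([x, y], [x + y])
out = ofg(x, y)

# Targeting the PyTorch backend directly should keep the C linker out of
# the picture for the outer graph.
fn = pytensor.function([x, y], out, mode="PYTORCH")

a = np.ones((2, 2))
print(fn(a, a))
```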
Sorry, we don't have a Gitter; where did you find the link so we can remove it?
No, C code shouldn't be involved, although I guess it could be called during optimization of the inner graph.
Got it. For the Gitter, it's under the community section here: https://pytensor.readthedocs.io/en/latest/index.html#pytensor-community.
The C error in question is below, in case it helps. I did try to run these with PYTENSOR_FLAGS="optimizer=None" and fast_compile, but I was still seeing the error:
> raise CompileError(
f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
E pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1):
E /opt/anaconda3/envs/pytensor-dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/numpy/core/include -I/opt/anaconda3/envs/pytensor-dev/include/python3.10 -I/Users/ischweer/dev/pytensor/pytensor/link/c/c_code -L/opt/anaconda3/envs/pytensor-dev/lib -fvisibility=hidden -o /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/medd18106a84c2e4f7b8beb5e52d5bb1d3f6f08a810be55e8472522fafe746dd7.so /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V5_stride0, V5_stride1,
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: note: insert an explicit cast to silence this issue
E V5_stride0, V5_stride1,
E ^~~~~~~~~~
E static_cast<int>( )
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V5_stride0, V5_stride1,
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: note: insert an explicit cast to silence this issue
E V5_stride0, V5_stride1,
E ^~~~~~~~~~
E static_cast<int>( )
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V3_stride0, V3_stride1,
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: note: insert an explicit cast to silence this issue
E V3_stride0, V3_stride1,
E ^~~~~~~~~~
E static_cast<int>( )
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V3_stride0, V3_stride1,
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: note: insert an explicit cast to silence this issue
E V3_stride0, V3_stride1,
E ^~~~~~~~~~
E static_cast<int>( )
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V1_stride0, V1_stride1
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: note: insert an explicit cast to silence this issue
E V1_stride0, V1_stride1
E ^~~~~~~~~~
E static_cast<int>( )
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E V1_stride0, V1_stride1
E ^~~~~~~~~~
E /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: note: insert an explicit cast to silence this issue
E V1_stride0, V1_stride1
E ^~~~~~~~~~
E static_cast<int>( )
E 6 errors generated.
E
E Apply node that caused the error: Add(*0-<Matrix(float64, shape=(?, ?))>, *1-<Matrix(float64, shape=(?, ?))>)
E Toposort index: 0
E Inputs types: [TensorType(float64, shape=(None, None)), TensorType(float64, shape=(None, None))]
E
E Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E res = hook_impl.function(*args)
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/runner.py", line 173, in pytest_runtest_call
E item.runtest()
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 1632, in runtest
E self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 513, in __call__
E return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_manager.py", line 120, in _hookexec
E return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E res = hook_impl.function(*args)
E File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 162, in pytest_pyfunc_call
E result = testfunction(**testargs)
E File "/Users/ischweer/dev/pytensor/tests/link/pytorch/test_basic.py", line 309, in test_pytorch_OpFromGraph
E ofg_1 = OpFromGraph([x, y], [x + y])
E
E HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
E HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
E Apply node that caused the error: OpFromGraph{inline=False}(y, z)
One thing I'm not sure about is whether these are just M1-setup-specific issues I'm having (could be some warnings being more permissive than others). If that's the case, then I think eager execution is probably fine.
I've confirmed the above: if you set gcc__cxxflags=-Wno-c++11-narrowing and enable eager mode execution, this works:
PYTENSOR_FLAGS="gcc__cxxflags=-Wno-c++11-narrowing" TORCHDYNAMO_SUPPRESS_ERRORS="True" pytest -k OpFromGraph tests/link/pytorch/test_basic.py
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/ischweer/dev/pytensor
configfile: pyproject.toml
plugins: cov-5.0.0, sphinx-0.6.3, mock-3.14.0, benchmark-4.0.0, xdist-3.6.1
collected 14 items / 13 deselected / 1 selected
tests/link/pytorch/test_basic.py . [100%]
========================================================================= warnings summary =========================================================================
pytensor/configdefaults.py:375
/Users/ischweer/dev/pytensor/pytensor/configdefaults.py:375: DeprecationWarning: Use shutil.which instead of find_executable
newp = find_executable(param)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================== 1 passed, 13 deselected, 1 warning in 4.64s ============================================================
(pytensor-dev) ➜ pytensor git:(makeopfromgraph) ✗
I'm not sure how I can add that environment variable only for macOS; maybe I can update a doc or something, or even the new M1 environment YAML? Otherwise, this all works now.
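One hedged option (assuming these config attributes are settable at runtime, which I haven't verified) is to set the equivalent flags programmatically, e.g. in a conftest.py, so they only apply on macOS:

```python
# Hedged sketch: apply the workaround only on macOS, mirroring the
# PYTENSOR_FLAGS / TORCHDYNAMO_SUPPRESS_ERRORS settings used above.
import sys

import pytensor
import torch._dynamo

if sys.platform == "darwin":
    # Equivalent of PYTENSOR_FLAGS="gcc__cxxflags=-Wno-c++11-narrowing"
    pytensor.config.gcc__cxxflags = "-Wno-c++11-narrowing"

# Equivalent of TORCHDYNAMO_SUPPRESS_ERRORS="True": fall back to eager
# execution instead of raising when Dynamo hits something it can't handle.
torch._dynamo.config.suppress_errors = True
```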
I marked the import issue as a TODO in the code and was going to open an issue. Is that okay @ricardoV94?
Thanks for taking a stab; I'll have to take a look to understand the problem myself. No specific advice yet.
_ = kwargs.pop("storage_map", None)

PYTORCH.optimizer(op.fgraph)
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
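Pulling that hunk together, the dispatch being discussed has roughly this shape; the registration decorator, import paths, and signature are assumptions on my part rather than the merged code:

```python
# Sketch only: import paths and the registration signature are assumptions.
import torch

from pytensor.compile.builders import OpFromGraph
from pytensor.compile.mode import PYTORCH
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
    # storage_map belongs to the outer graph, so drop it before funcifying
    # the inner FunctionGraph.
    _ = kwargs.pop("storage_map", None)

    # Run the PyTorch-mode rewrites on the inner graph so its ops all have
    # PyTorch implementations (and no C code is needed).
    PYTORCH.optimizer(op.fgraph)

    # Whether this inner torch.compile is needed at all is debated below.
    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
    return fgraph_fn
```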
Do you need to compile the inner function? Is that a thing in PyTorch?
I was following what Numba does, where it JITs the inner function. We could remove the inner torch.compile and just return op.fgraph if that seems more reasonable; that will still lead to some C-linker issues, FWIW.
I removed the inner function; you only need to do indexing if the number of return values is more than 1.
Numba can only have inner compiled functions; I don't know if that's a requirement in PyTorch, or whether it has any advantages. We don't do it for JAX.
I do not see / know of any requirement to have an inner compiled function.
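A hedged illustration of that last point: a plain Python callable can be used from inside a torch.compile'd function without being compiled on its own, since Dynamo traces into it, so the inner fgraph function doesn't strictly need its own torch.compile.

```python
# Illustration only: Dynamo traces through the uncompiled inner callable.
import torch


def inner(x):
    # Plain Python/torch code, never passed to torch.compile itself.
    return x * 2


@torch.compile
def outer(x):
    # Dynamo inlines `inner` while compiling `outer`.
    return inner(x) + 1


print(outer(torch.ones(3)))  # tensor([3., 3., 3.])
```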
@ricardoV94 I noted what the change ended up being. I'm not sure if that is the right call, but at least for now the tests pass, so I'm going to put it into ready. It will fail on the build machine with this because we don't allow eager mode.
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@ Coverage Diff @@
## main #956 +/- ##
=======================================
Coverage 81.74% 81.74%
=======================================
Files 183 183
Lines 47724 47733 +9
Branches 11616 11616
=======================================
+ Hits 39011 39020 +9
Misses 6518 6518
Partials 2195 2195
# Instead of rewriting many portions of code
# this will allow for only this small section to
# not be compiled by the outer graph
return torch.compiler.disable(inner)
Is this because of the two inner functions?
It has something bizarre to do with the combination of fgraph_fn being a bunch of nested functions and this inner function being nested. The bigger issue is that the torch compiler isn't great at handling conditionals on user closure variables, at least in PyTensor; it would probably need a much deeper dive. It looks like it might be something that can happen with other functions too.
That's worrisome. What error did you get without this disabling?
I have the error above in a comment, but it essentially says the generated code from PyTensor can't find some functions (all the inner functions returned by the torch dispatch).
@Ch0ronomato I tried something different in a commit I just pushed. I'm disabling only one level of the stack, so the inner code is still compiled? Does that make sense? The problem we're facing seems to be with PyTorch trying to import the dynamically generated
I see. Yeah, using the fine-grained API does seem to be the only way. I like that yours doesn't do recursive disabling, where the API I used I think does. I like it!
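For reference, a hedged sketch of the two granularities being compared here, using torch.compiler.disable's recursive flag: the default recursively skips everything the wrapped callable calls, while recursive=False skips only that one frame so inner code can still be compiled.

```python
# Sketch of the two disabling granularities discussed above.
import torch


def inner(x):
    return x + 1


# Default: skip `inner` and everything it calls during compilation.
skip_recursive = torch.compiler.disable(inner)

# Fine-grained: skip only `inner`'s own frame; code it calls can still be
# traced/compiled by Dynamo.
skip_one_level = torch.compiler.disable(inner, recursive=False)

print(skip_recursive(torch.ones(2)), skip_one_level(torch.ones(2)))
```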
Force-pushed from af6ac6c to b29be45.
@ricardoV94 I do see the same issue (I think) with this impl: https://github.com/pymc-devs/pytensor/actions/runs/10763803733/job/29845771297?pr=988. One other thing I'm not sure I get: why does the error go away when we remove the if / else on things declared outside the inner function?
What do you mean? Can you show the code that doesn't error out vs. the one that does?
Seems like we need to resolve some trivial conflicts.
Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Torch OpFromGraph
Allows us to precompile a subset of the graph via torch.