
Implement OpFromGraph in PyTorch backend #956

Merged
merged 8 commits on Sep 17, 2024

Conversation

Ch0ronomato
Contributor

Torch OpFromGraph

Allows us to precompile a subset of the graph via torch.
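A minimal usage sketch of what this enables (hedged: it assumes the PyTorch backend is selected via mode="PYTORCH" and mirrors the shape of the test added in this PR):

import numpy as np
import pytensor.tensor as pt
from pytensor import function
from pytensor.compile.builders import OpFromGraph

# Wrap a small subgraph in an OpFromGraph and evaluate it through the torch linker
x = pt.matrix("x")
y = pt.matrix("y")
ofg = OpFromGraph([x, y], [x + y])

a = pt.matrix("a")
b = pt.matrix("b")
f = function([a, b], ofg(a, b), mode="PYTORCH")
print(f(np.ones((2, 2)), np.ones((2, 2))))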

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@Ch0ronomato Ch0ronomato marked this pull request as draft July 27, 2024 04:01
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))

def opfromgraph(*inputs, dim=op.fgraph.outputs):
    res = fgraph_fn(*inputs)
Contributor Author

The code that is generated for fgraph_fn doesn't have enough of the locals/globals at the moment to actually execute the function:

torch._dynamo.exc.InternalTorchDynamoError: module 'pytensor.link.utils' has no attribute 'elemwise_fn'

I'm working on fixing it

Contributor Author
@Ch0ronomato Ch0ronomato Jul 29, 2024

Hey @ricardoV94; I spent some time debugging this but wasn't able to get anywhere. Some things I've tried:

  1. I tried setting the local_env and global_env passed to compile_to_src so that those functions are included in the callable that gets compiled by torch, but that didn't seem to help.
  2. I enabled eager-mode execution, and it threw an error from the PyTensor C linker - should that be happening? I also disabled graph optimizations, and even removed torch.compile completely (in the PyTorch JITLinker).

Could you clarify whether the C linker should be running, and if not, what I might be missing? Otherwise, if you have any suggestions on what to check, let me know.
I also tried joining the Gitter, but the link seems broken 😓

Member

Sorry, we don't have a Gitter - where did you find the link so we can remove it?

Member

No, C code shouldn't be involved, although I guess it could be called during optimization of the inner graph.

Contributor Author

Got it - for the Gitter, it's listed under the community section here: https://pytensor.readthedocs.io/en/latest/index.html#pytensor-community.

Here is the C error in question, in case it helps - note that I did try running with PYTENSOR_FLAGS="optimizer=None" and with fast-compile, but I was still seeing the error:

>           raise CompileError(
                f"Compilation failed (return status={status}):\n{' '.join(cmd)}\n{compile_stderr}"
E               pytensor.link.c.exceptions.CompileError: Compilation failed (return status=1):
E               /opt/anaconda3/envs/pytensor-dev/bin/clang++ -dynamiclib -g -O3 -fno-math-errno -Wno-unused-label -Wno-unused-variable -Wno-write-strings -DNPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION -fPIC -undefined dynamic_lookup -I/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/numpy/core/include -I/opt/anaconda3/envs/pytensor-dev/include/python3.10 -I/Users/ischweer/dev/pytensor/pytensor/link/c/c_code -L/opt/anaconda3/envs/pytensor-dev/lib -fvisibility=hidden -o /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/medd18106a84c2e4f7b8beb5e52d5bb1d3f6f08a810be55e8472522fafe746dd7.so /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E                       V5_stride0, V5_stride1, 
E                       ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:9: note: insert an explicit cast to silence this issue
E                       V5_stride0, V5_stride1, 
E                       ^~~~~~~~~~
E                       static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E                       V5_stride0, V5_stride1, 
E                                   ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:585:21: note: insert an explicit cast to silence this issue
E                       V5_stride0, V5_stride1, 
E                                   ^~~~~~~~~~
E                                   static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V3_stride0, V3_stride1, 
E               ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:1: note: insert an explicit cast to silence this issue
E               V3_stride0, V3_stride1, 
E               ^~~~~~~~~~
E               static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V3_stride0, V3_stride1, 
E                           ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:586:13: note: insert an explicit cast to silence this issue
E               V3_stride0, V3_stride1, 
E                           ^~~~~~~~~~
E                           static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V1_stride0, V1_stride1
E               ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:1: note: insert an explicit cast to silence this issue
E               V1_stride0, V1_stride1
E               ^~~~~~~~~~
E               static_cast<int>( )
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: error: non-constant-expression cannot be narrowed from type 'ssize_t' (aka 'long') to 'int' in initializer list [-Wc++11-narrowing]
E               V1_stride0, V1_stride1
E                           ^~~~~~~~~~
E               /Users/ischweer/.pytensor/compiledir_macOS-14.3-arm64-arm-64bit-arm-3.10.14-64/tmpqcvbz8_y/mod.cpp:587:13: note: insert an explicit cast to silence this issue
E               V1_stride0, V1_stride1
E                           ^~~~~~~~~~
E                           static_cast<int>( )
E               6 errors generated.
E               
E               Apply node that caused the error: Add(*0-<Matrix(float64, shape=(?, ?))>, *1-<Matrix(float64, shape=(?, ?))>)
E               Toposort index: 0
E               Inputs types: [TensorType(float64, shape=(None, None)), TensorType(float64, shape=(None, None))]
E               
E               Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/runner.py", line 173, in pytest_runtest_call
E                   item.runtest()
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 1632, in runtest
E                   self.ihook.pytest_pyfunc_call(pyfuncitem=self)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_hooks.py", line 513, in __call__
E                   return self._hookexec(self.name, self._hookimpls.copy(), kwargs, firstresult)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_manager.py", line 120, in _hookexec
E                   return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/pluggy/_callers.py", line 103, in _multicall
E                   res = hook_impl.function(*args)
E                 File "/opt/anaconda3/envs/pytensor-dev/lib/python3.10/site-packages/_pytest/python.py", line 162, in pytest_pyfunc_call
E                   result = testfunction(**testargs)
E                 File "/Users/ischweer/dev/pytensor/tests/link/pytorch/test_basic.py", line 309, in test_pytorch_OpFromGraph
E                   ofg_1 = OpFromGraph([x, y], [x + y])
E               
E               HINT: Use a linker other than the C linker to print the inputs' shapes and strides.
E               HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
E               Apply node that caused the error: OpFromGraph{inline=False}(y, z)

Contributor Author

One thing I'm not sure about is whether these are just M1-setup-specific issues on my machine (it could be that some warnings are treated more strictly than others) - if that's the case, then I think eager execution is probably fine.

Contributor Author
@Ch0ronomato Ch0ronomato Jul 30, 2024

I've confirmed the above ^: if you set gcc__cxxflags=-Wno-c++11-narrowing and enable eager-mode execution, this works.

PYTENSOR_FLAGS="gcc__cxxflags=-Wno-c++11-narrowing" TORCHDYNAMO_SUPPRESS_ERRORS="True" pytest -k OpFromGraph tests/link/pytorch/test_basic.py
======================================================================= test session starts ========================================================================
platform darwin -- Python 3.10.14, pytest-8.2.2, pluggy-1.5.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/ischweer/dev/pytensor
configfile: pyproject.toml
plugins: cov-5.0.0, sphinx-0.6.3, mock-3.14.0, benchmark-4.0.0, xdist-3.6.1
collected 14 items / 13 deselected / 1 selected                                                                                                                    

tests/link/pytorch/test_basic.py .                                                                                                                           [100%]

========================================================================= warnings summary =========================================================================
pytensor/configdefaults.py:375
  /Users/ischweer/dev/pytensor/pytensor/configdefaults.py:375: DeprecationWarning: Use shutil.which instead of find_executable
    newp = find_executable(param)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================== 1 passed, 13 deselected, 1 warning in 4.64s ============================================================
(pytensor-dev) ➜  pytensor git:(makeopfromgraph) ✗ 

I'm not sure how to set that environment variable only for macOS; maybe I can update a doc or something, or even add it to the new M1 environment YAML? Otherwise, this all works now.

Contributor Author

I marked the import issue as a TODO in the code and was going to open an issue. Is that okay, @ricardoV94?

@Ch0ronomato Ch0ronomato changed the title from "Basic support for makeop" to "Basic support for opfromgraph" Jul 29, 2024
@Ch0ronomato Ch0ronomato mentioned this pull request Jul 29, 2024
Member
@ricardoV94 ricardoV94 left a comment

Thanks for taking a stab; I'll have to take a look to understand the problem myself. No specific advice yet.

_ = kwargs.pop("storage_map", None)

PYTORCH.optimizer(op.fgraph)
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
Member

Do you need to compile the inner function? Is that a thing in PyTorch?

Contributor Author

I was following what Numba does, where it JITs the inner function - we could remove the inner torch.compile and just return the funcified op.fgraph if that seems more reasonable. That will still lead to some C-linker issues, FWIW.

Contributor Author

I removed the inner function; you only need to do indexing if the number of return values is more than 1.

Member

Numba can only have inner compiled functions; I don't know if that's a requirement in PyTorch, or whether it has any advantages. We don't do it for JAX.
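For reference, a rough sketch of the JAX-style dispatch alluded to here (hedged - written from memory, so the actual code in pytensor/link/jax/dispatch may differ in its details): the inner fgraph is funcified with no inner jit, and the tuple returned by the funcified graph is unwrapped for the single-output case.

from pytensor.compile.builders import OpFromGraph
from pytensor.link.jax.dispatch.basic import jax_funcify

@jax_funcify.register(OpFromGraph)
def jax_funcify_OpFromGraph(op, node=None, **kwargs):
    kwargs.pop("storage_map", None)
    # Funcify the inner graph directly; no nested jax.jit is used
    fgraph_fn = jax_funcify(op.fgraph, **kwargs)

    if len(op.fgraph.outputs) == 1:

        def opfromgraph(*inputs):
            # The funcified fgraph returns a tuple; unwrap the single output
            return fgraph_fn(*inputs)[0]

        return opfromgraph

    return fgraph_fn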

Contributor Author

I do not see / know of any requirement to have an inner compiled function.

@Ch0ronomato
Contributor Author

Ch0ronomato commented Jul 30, 2024

@ricardoV94 I noted what the change ended up being. I'm not sure if that is the right call, but at least for now the tests pass, so I'm going to mark it as ready. It will fail on the build machine with this, because we don't enable eager mode:

FAILED tests/link/pytorch/test_basic.py::test_pytorch_OpFromGraph - torch._dynamo.exc.InternalTorchDynamoError: module 'pytensor.link.utils' has no attribute 'elemwise_fn'

@Ch0ronomato Ch0ronomato marked this pull request as ready for review July 30, 2024 20:34
@ricardoV94 ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels Aug 3, 2024
@ricardoV94 ricardoV94 changed the title from "Basic support for opfromgraph" to "Implement OpFromGraph in PyTorch backend" Aug 3, 2024

codecov bot commented Aug 6, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.74%. Comparing base (3e55a20) to head (0f18d8d).
Report is 2 commits behind head on main.

Additional details and impacted files


@@           Coverage Diff           @@
##             main     #956   +/-   ##
=======================================
  Coverage   81.74%   81.74%           
=======================================
  Files         183      183           
  Lines       47724    47733    +9     
  Branches    11616    11616           
=======================================
+ Hits        39011    39020    +9     
  Misses       6518     6518           
  Partials     2195     2195           
Files with missing lines Coverage Δ
pytensor/link/pytorch/dispatch/basic.py 93.93% <100.00%> (+0.60%) ⬆️

# Instead of rewriting many portions of code
# this will allow for only this small section to
# not be compiled by the outer graph
return torch.compiler.disable(inner)
Member
@ricardoV94 ricardoV94 Aug 21, 2024

Is this because of the two inner functions?

Contributor Author

It has something bizarre to do with the combination of fgraph_fn being a bunch of nested functions and this inner function also being nested. The bigger issue is that the torch compiler isn't great at handling conditionals on user closure variables, at least in PyTensor; it would probably need a much deeper dive. It looks like it might be something that can happen with other functions too.

Member

That's worrisome. What error did you get without this disabling?

Contributor Author

I posted the error above in a comment, but it essentially says that the code generated by PyTensor can't find some functions (all the inner functions returned in the torch dispatch).

@ricardoV94
Member

@Ch0ronomato I tried something different in a commit I just pushed. I'm disabling only one level of the stack, so the inner code is still compiled? Does that make sense?

The problem we're facing seems to be PyTorch trying to import the dynamically generated elemwise_fn from the module where fgraph_to_python is defined, but this is nonsensical, as that function is defined in an inner scope, not at the module level. I also checked that allow_in_graph works, but according to the torch docs this disables safety checks and shouldn't be used with functions that can mutate the inputs, so I don't think it's safe to do in our case?
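A sketch of the idea described above (illustrative; the wrapper and its names are assumptions, not necessarily the exact committed change): torch.compiler.disable(recursive=False) skips Dynamo only for the wrapper frame, so the funcified inner fgraph it calls can still be compiled.

import torch

def wrap_inner(fgraph_fn):
    # Skip Dynamo for this frame only; recursive=False leaves the functions
    # it calls (the funcified inner graph) eligible for compilation.
    @torch.compiler.disable(recursive=False)
    def opfromgraph(*inputs):
        return fgraph_fn(*inputs)

    return opfromgraph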

@Ch0ronomato
Contributor Author

I see - yeah, using the fine-grained API does seem to be the only way. I like that yours doesn't do recursive disabling, whereas the API I used I think does. I like it!

@Ch0ronomato
Contributor Author

@ricardoV94 - I do see the same issue (I think) with this impl: https://github.com/pymc-devs/pytensor/actions/runs/10763803733/job/29845771297?pr=988.

One other thing I'm not sure I get - why does the error go away when we remove the if/else on things declared outside the inner function?

@ricardoV94
Member

why does the error go away when we remove the if/else on things declared outside the inner function?

What do you mean? Can you show the code that doesn't error out vs. the one that does?

@ricardoV94
Member

Seems like we need to resolve some trivial conflicts.

@ricardoV94 ricardoV94 merged commit 47478e6 into pymc-devs:main Sep 17, 2024
59 of 60 checks passed
HelloBroBro pushed a commit to HelloBroBro/pytensor that referenced this pull request Sep 18, 2024
Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Ch0ronomato added a commit to Ch0ronomato/pytensor that referenced this pull request Nov 2, 2024
Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>
Ch0ronomato added a commit to Ch0ronomato/pytensor that referenced this pull request Nov 2, 2024
Co-authored-by: Ian Schweer <ischweer@riotgames.com>
Co-authored-by: Ricardo Vieira <ricardo.vieira1994@gmail.com>