[Discussion] How to add MPS extension with custom kernel? #81103
Comments
One way to work around the problem is to dynamically compile the Metal kernel; for an example, see the snippet from #78619 at pytorch/aten/src/ATen/native/mps/operations/RangeFactories.mm, lines 36 to 40, commit 3f06d17.
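The caching pattern that snippet relies on (compile each unique Metal source string once, then reuse the compiled library) can be sketched in Python. This is only an analogue of the Objective-C++ code, which calls newLibraryWithSource:; `get_library` and the `compile_fn` callback are hypothetical stand-ins for the Metal runtime compiler.

```python
# Sketch of the JIT-compile-and-cache pattern. `compile_fn` stands in
# for the (expensive) Metal source compiler; each distinct source
# string is compiled exactly once.
_library_cache: dict = {}

def get_library(source: str, compile_fn):
    lib = _library_cache.get(source)
    if lib is None:
        lib = compile_fn(source)  # JIT compile on first use
        _library_cache[source] = lib
    return lib

calls = []

def fake_compiler(src):
    # Records invocations so we can observe that caching works.
    calls.append(src)
    return "compiled-library"

get_library("kernel void f() {}", fake_compiler)
get_library("kernel void f() {}", fake_compiler)
```

With this shape, repeated dispatches of the same kernel pay the compile cost only on the first call, which is why a small JIT-compiled code base can be competitive with ahead-of-time compilation.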
@malfet I see. Is there a roadmap for the custom Metal extension?
Hi @grimoire, thanks for starting this thread and adding support for custom kernels. On a related note, what ops and corresponding applications are you planning to target using the custom extensions?
@kulinseth We are planning to add M1 support for custom ops in MMCV. It would be great if we could follow the same approach PyTorch uses to add Metal kernels.
@kulinseth I mentioned in #77764 (comment) that JIT-compiling a Metal kernel is a good path to go down. This has since been attempted in a PR by @malfet, #82307, implementing the same operation I described in the issue comment. You mentioned in #78619 (comment) that the best path would be pre-compiling the Metal shaders offline. I don't think that is needed, and I can provide benchmarks supporting my position. As long as the custom Metal code base is small, the benefits of JIT-compiling outweigh the downsides. For example, we don't need to modify the PyTorch build system to accommodate it. Although the type is decided at runtime, I have plenty of benchmarks showing this has no measurable impact on performance. I can explain why in greater depth if anyone asks.
We have added that support: |
🚀 The feature, motivation and pitch
Hi,
I am working on adding MPS op for MPS backend with a custom kernel.
Here is an example:
https://github.com/grimoire/TorchMPSCustomOpsDemo
I am new to Metal, so I am not sure if this is a good way (or the right way) to add such an op. There are some things I want to discuss:
Device and CommandQueue
Since PyTorch does not expose the MPS-related API, I have to copy some headers from torch/csrc. The library is built with MPSDevice::getInstance()->device(), and commands are committed to getCurrentMPSStream(). I am not sure whether I should flush on commit or not.
LibraryFromUrl vs LibraryFromSource
It seems that a Metal library cannot be linked together with the other object files. So I have to:
Either load it at runtime, which leads to the problem of how to find the relative location of the .metallib,
Or build it at runtime, which might take a long time to compile the kernel.
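For the load-at-runtime option, one common way to find the .metallib is to resolve it relative to the Python package rather than the working directory. A minimal sketch, assuming the library is shipped next to the package's module file; `find_metallib` and the file name `custom_ops.metallib` are hypothetical:

```python
from pathlib import Path

def find_metallib(anchor_file: str, name: str = "custom_ops.metallib") -> str:
    # Resolve the .metallib next to a known module file (typically the
    # package's __init__.py, passed in as __file__), so loading works
    # regardless of the current working directory.
    return str(Path(anchor_file).resolve().parent / name)
```

The resulting path can then be handed to the native side, e.g. to -[MTLDevice newLibraryWithURL:error:], when the library is loaded at runtime.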
BuildExtension
If we do not build the Metal kernel at runtime, we need to set up a compiler for the Metal kernel in setup.py. Since the build_ext provided by Python and PyTorch does not support building Metal, I patched the UnixCCompiler in BuildExtension to add that support. Both compile and link need to be updated. The code looks ... ugly. I hope there is a better way to do that.
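For illustration, the per-file dispatch such a patch needs can be sketched as plain command construction. This is a sketch only: `compile_cmd` and `link_metallib` are hypothetical helpers, the flags shown are illustrative, and a real patch has to thread equivalent logic through UnixCCompiler's compile and link hooks.

```python
import os

def compile_cmd(src: str, obj: str):
    # Route Metal sources to Apple's Metal front end; every other
    # source type falls through to the normal C/C++ compiler.
    if os.path.splitext(src)[1] == ".metal":
        # `xcrun -sdk macosx metal -c` compiles .metal -> .air (Metal IR)
        return ["xcrun", "-sdk", "macosx", "metal", "-c", src, "-o", obj]
    return ["clang++", "-c", src, "-o", obj]

def link_metallib(objs, out):
    # `xcrun metallib` links .air files into a loadable .metallib;
    # this must be a separate link step from the host object files,
    # since the .metallib cannot be linked into the extension binary.
    return ["xcrun", "-sdk", "macosx", "metallib", *objs, "-o", out]
```

Keeping the Metal link step separate from the host link mirrors the constraint above: the .metallib is produced and shipped alongside the extension rather than linked into it.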
So ... any advice?
Alternatives
No response
Additional context
No response
cc @malfet @zou3519 @kulinseth @albanD