
[Discussion] How to add MPS extension with custom kernel? #81103

Closed
grimoire opened this issue Jul 8, 2022 · 6 comments
Assignees: kulinseth
Labels: enhancement, module: cpp-extensions, module: mps, topic: docs, triaged

Comments


grimoire commented Jul 8, 2022

🚀 The feature, motivation and pitch

Hi,
I am working on adding MPS op for MPS backend with a custom kernel.
Here is an example:

https://github.com/grimoire/TorchMPSCustomOpsDemo

I am new to Metal. I am not sure if this is a good way (or the right way) to add such an op. There are some things I want to discuss:

Device and CommandQueue

Since PyTorch has not exposed the MPS-related API, I have to copy some headers from torch/csrc. The library is built with MPSDevice::getInstance()->device(), and commands are committed to getCurrentMPSStream(). I am not sure if I should flush on commit or not.

LibraryFromUrl vs LibraryFromSource

It seems that a Metal library cannot be linked together with the other object files. So I have to:

Either load it at runtime, which leads to the problem of how to find the relative location of the .metallib.

// load from url
NSURL* metal_url = [NSURL fileURLWithPath: utl_str];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL: metal_url error:&error];

Or build it at runtime, which might take a long time to compile the kernels.

// build library from source string; stringWithUTF8String: avoids the
// deprecated, encoding-less stringWithCString:
NSString* code_str = [NSString stringWithUTF8String: sources.c_str()];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource: code_str options: nil error:&error];
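For the first option, one way to sidestep the relative-location problem is to resolve the .metallib path against the Python package that ships it. A minimal sketch, assuming the build step copies the compiled library into the package directory (the helper and default file name here are illustrative, not part of any PyTorch API):

```python
import os

def find_metallib(package_dir: str, name: str = "custom_ops.metallib") -> str:
    """Resolve a .metallib shipped alongside a Python package.

    Callers would pass os.path.dirname(os.path.abspath(__file__)) from
    inside the package. Raises if the library is missing so a bad install
    fails loudly instead of at first kernel launch.
    """
    path = os.path.join(package_dir, name)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Metal library not found: {path}")
    return path
```

The resulting absolute path can then be handed to newLibraryWithURL: on the C++ side.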

BuildExtension

If we do not build the Metal kernels at runtime, we need to set up a Metal compiler in setup.py.

Since the build_ext provided by Python and PyTorch does not support building Metal, I patched the UnixCCompiler in BuildExtension to add the support. Both compile and link need to be updated:

        # compile
        def darwin_wrap_single_compile(obj, src, ext, cc_args, extra_postargs,
                                       pp_opts) -> None:
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so

                if _is_metal_file(src):
                    # use xcrun metal to compile metal file to `.air`
                    metal = ['xcrun', 'metal']
                    self.compiler.set_executable('compiler_so', metal)
                    if isinstance(cflags, dict):
                        cflags = cflags.get('metal', [])
                    else:
                        cflags = []
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                self.compiler.set_executable('compiler_so', original_compiler)
        
        # link
        def darwin_wrap_single_link(target_desc,
                                    objects,
                                    output_filename,
                                    output_dir=None,
                                    libraries=None,
                                    library_dirs=None,
                                    runtime_library_dirs=None,
                                    export_symbols=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    build_temp=None,
                                    target_lang=None):
            if osp.splitext(objects[0])[1].lower() == '.air':
                for obj in objects:
                    assert osp.splitext(obj)[1].lower() == '.air', \
                        f'Expect .air file, but get {obj}.'
                # link `.air` with xcrun metallib
                linker = ['xcrun', 'metallib']
                self.compiler.spawn(linker + objects + ['-o', output_filename])
            else:
                return original_link(target_desc, objects, output_filename,
                                     output_dir, libraries, library_dirs,
                                     runtime_library_dirs, export_symbols,
                                     debug, extra_preargs, extra_postargs,
                                     build_temp, target_lang)

The code looks ... ugly. Hope there is a better way to do that.

So ... any advice?

Alternatives

No response

Additional context

No response

cc @malfet @zou3519 @kulinseth @albanD

@malfet malfet added topic: docs topic category module: mps Related to Apple Metal Performance Shaders framework module: cpp-extensions Related to torch.utils.cpp_extension enhancement Not as big of a feature, but technically not a bug. Should be easy to fix labels Jul 8, 2022

malfet commented Jul 8, 2022

One way to work around the problem is to dynamically compile the Metal kernel; for example, see the following snippet from #78619:

static id<MTLComputePipelineState> compileMetalShader(id<MTLDevice> device) {
  static id<MTLComputePipelineState> rc = nil;
  if (rc != nil) {
    return rc;
  }
  // (snippet truncated in the original comment; the rest compiles the shader
  // source at runtime and caches the pipeline state in `rc`; see #78619 for
  // the full version)
}

@cpuhrsch cpuhrsch added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 8, 2022
grimoire (Author) commented:

@malfet I see. Is there a roadmap for the custom Metal extension?

kulinseth (Collaborator) commented:

> @malfet I see. Is there a roadmap for the custom Metal extension?

Hi @grimoire , thanks for starting this thread and adding support for Custom kernels.
I am currently working on exposing the MPSEvent and MPSStream API through Python (through csrc/mps), which will make it easy to add custom ops.

On a related note, which ops and corresponding applications are you planning to target using the custom extensions?

grimoire (Author) commented:

@kulinseth We are planning to add M1 support to the custom ops in MMCV. It would be cool if we could follow the same approach PyTorch uses to add Metal kernels.

philipturner commented:
@kulinseth I mentioned in #77764 (comment) that JIT-compiling a Metal kernel is a good path to go down. This has since been attempted in a PR by @malfet, #82307, implementing the same operation I described in the issue comment.

You mentioned in #78619 (comment) that the best path would be pre-compiling the Metal shaders offline. I don't think that is needed, and I can provide benchmarks proving my statement. As long as the custom Metal code base is small, the benefits of JIT-compiling outweigh the downsides. For example, we don't need to modify the PyTorch build system to accommodate .metallib files. I do have one caveat: to remove compile-time and runtime overhead, consider dynamically typing stuff on the GPU whenever possible. That means passing metadata into the GPU shader about which data type it's supposed to read.

Although the type is decided at runtime, I have plenty of benchmarks showing it has no measurable impact on performance. I would explain why in greater depth, but I won't do so until someone asks me to.
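The dynamic-typing idea can be sketched host-side: instead of compiling one pipeline per dtype, a single compiled kernel receives a small metadata buffer telling it which type to reinterpret its untyped input pointer as. A Python sketch of packing such metadata (the dtype codes and buffer layout below are made up for illustration, not any PyTorch or Metal convention):

```python
import struct

# Hypothetical dtype tags the single compiled kernel would branch on.
DTYPE_CODES = {"float32": 0, "float16": 1, "int32": 2}

def pack_kernel_metadata(dtype: str, numel: int) -> bytes:
    # Two little-endian uint32s: a dtype tag and the element count.
    # The GPU side reads the tag and casts its raw device pointer
    # to the matching type before looping over numel elements.
    return struct.pack("<II", DTYPE_CODES[dtype], numel)

def unpack_kernel_metadata(buf: bytes) -> tuple:
    # Mirror of the layout above, as the shader would decode it.
    return struct.unpack("<II", buf)
```

The 8-byte buffer would be bound as a constant argument of the kernel dispatch, so only one pipeline state ever needs to be compiled and cached.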

@kulinseth kulinseth self-assigned this Oct 5, 2022
kulinseth (Collaborator) commented:

We have added that support:
https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_pytorch_operation
