[Discussion] How to add MPS extension with custom kernel? #81103

Closed
@grimoire

Description

🚀 The feature, motivation and pitch

Hi,
I am working on adding an op for the MPS backend that uses a custom Metal kernel.
Here is an example:

https://github.com/grimoire/TorchMPSCustomOpsDemo

I am new to Metal, so I am not sure whether this is a good (or the right) way to add such an op. There are a few things I would like to discuss:

Device and CommandQueue

Since PyTorch does not expose its MPS-related API, I had to copy some headers from the torch csrc. The library is built with MPSDevice::getInstance()->device(), and the commands are committed to getCurrentMPSStream(). I am not sure whether I should flush on commit.

LibraryFromUrl vs LibraryFromSource

It seems that a Metal library cannot be linked together with the other object files, so I have to do one of the following:

Either load it at runtime, which raises the problem of how to find the location of the .metallib relative to the extension.

// load from url
NSURL* metal_url = [NSURL fileURLWithPath: utl_str];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL: metal_url error:&error];

Or build it from source at runtime, which might take a long time because the kernel has to be compiled on the fly.

// build library from source string
NSString* code_str = [NSString stringWithUTF8String: sources.c_str()];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource: code_str options: nil error:&error];
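If runtime compilation is kept, its cost can at least be paid only once per process by compiling lazily and reusing the result. The sketch below shows that caching idea in Python; compile_fn is a hypothetical stand-in for whatever binding ends up calling newLibraryWithSource:, and inside the extension the same effect would more likely come from a lazily initialized static in the Objective-C++ code.

```python
import hashlib
from typing import Callable, Dict


class LibraryCache:
    """Compile each distinct Metal source string at most once per process."""

    def __init__(self, compile_fn: Callable[[str], object]) -> None:
        # compile_fn is hypothetical: it stands in for the real library builder.
        self._compile_fn = compile_fn
        self._libraries: Dict[str, object] = {}

    def get(self, source: str) -> object:
        # Key on a hash of the source so edited kernels get recompiled.
        key = hashlib.sha256(source.encode("utf-8")).hexdigest()
        if key not in self._libraries:
            self._libraries[key] = self._compile_fn(source)
        return self._libraries[key]
```

Only the first request for a given source pays the compile cost; later calls return the cached library object.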

BuildExtension

If we do not build the Metal kernel at runtime, we need to set up the Metal compiler in setup.py.
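For reference, the offline build is two command-line steps: each .metal file is compiled to an intermediate .air file with `xcrun -sdk macosx metal -c`, and the .air files are then combined into one .metallib with `xcrun -sdk macosx metallib`. A minimal sketch of that as a standalone function; the name build_metallib and its parameters are hypothetical, and build_dir plus the output directory are assumed to exist already:

```python
import os
import subprocess
from typing import Callable, List, Sequence


def build_metallib(sources: Sequence[str], output: str, build_dir: str,
                   run: Callable[[List[str]], object] = subprocess.check_call) -> str:
    """Compile .metal sources to .air files, then link them into one .metallib.

    `run` defaults to subprocess.check_call; pass a stub to dry-run.
    """
    air_files = []
    for src in sources:
        # One intermediate .air per source; sources sharing a base name
        # would overwrite each other here, so keep base names unique.
        stem = os.path.splitext(os.path.basename(src))[0]
        air = os.path.join(build_dir, stem + ".air")
        run(["xcrun", "-sdk", "macosx", "metal", "-c", src, "-o", air])
        air_files.append(air)
    # Combine all .air files into a single loadable library.
    run(["xcrun", "-sdk", "macosx", "metallib", *air_files, "-o", output])
    return output
```

Such a function could, for instance, be called from an overridden build_ext.run() before the C++ sources are compiled, which keeps Metal out of the C compiler machinery entirely; whether that fits the project's packaging better is a judgment call.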

Since neither the build_ext provided by Python nor PyTorch's BuildExtension supports building Metal sources, I patched the UnixCCompiler in BuildExtension to add that support. Both compile and link need to be updated:

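        # compile: `.metal` sources go to `xcrun metal`, everything else to the C++ compiler.
        # `original_compile` is the unpatched self.compiler._compile saved by BuildExtension.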
        def darwin_wrap_single_compile(obj, src, ext, cc_args, extra_postargs,
                                       pp_opts) -> None:
            cflags = copy.deepcopy(extra_postargs)
            try:
                original_compiler = self.compiler.compiler_so

                if _is_metal_file(src):
                    # use xcrun metal to compile metal file to `.air`
                    metal = ['xcrun', 'metal']
                    self.compiler.set_executable('compiler_so', metal)
                    if isinstance(cflags, dict):
                        cflags = cflags.get('metal', [])
                    else:
                        cflags = []
                elif isinstance(cflags, dict):
                    cflags = cflags['cxx']

                original_compile(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                self.compiler.set_executable('compiler_so', original_compiler)
        
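        # link: `.air` objects are combined with `xcrun metallib`; other targets use the normal linker.
        # `original_link` is assumed to be the unpatched self.compiler.link, saved before patching.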
        def darwin_wrap_single_link(target_desc,
                                    objects,
                                    output_filename,
                                    output_dir=None,
                                    libraries=None,
                                    library_dirs=None,
                                    runtime_library_dirs=None,
                                    export_symbols=None,
                                    debug=0,
                                    extra_preargs=None,
                                    extra_postargs=None,
                                    build_temp=None,
                                    target_lang=None):
            if osp.splitext(objects[0])[1].lower() == '.air':
                for obj in objects:
                    assert osp.splitext(obj)[1].lower(
                    ) == '.air', f'Expect .air file, but get {obj}.'
                # link `.air` with xcrun metallib
                linker = ['xcrun', 'metallib']
                self.compiler.spawn(linker + objects + ['-o', output_filename])
            else:
                return original_link(target_desc, objects, output_filename,
                                     output_dir, libraries, library_dirs,
                                     runtime_library_dirs, export_symbols,
                                     debug, extra_preargs, extra_postargs,
                                     build_temp, target_lang)

The code looks... ugly. I hope there is a better way to do this.

So, any advice?

Alternatives

No response

Additional context

No response

cc @malfet @zou3519 @kulinseth @albanD

Metadata

Labels

enhancement, module: cpp-extensions, module: mps, topic: docs, triaged
