Description
🚀 The feature, motivation and pitch
Hi,
I am working on adding an op with a custom kernel to the MPS backend.
Here is an example:
https://github.com/grimoire/TorchMPSCustomOpsDemo
I am new to Metal, and I am not sure whether this is a good (or the right) way to add such an op. There are a few things I would like to discuss:
Device and CommandQueue
Since PyTorch does not expose the MPS-related API, I have to copy some headers from torch csrc. The library is built with MPSDevice::getInstance()->device() and the commands are committed to getCurrentMPSStream(). I am not sure whether I should flush on commit or not.
LibraryFromUrl vs LibraryFromSource
It seems that a Metal library cannot be linked together with the other object files. So I have to:
Either load it at runtime, which leads to the problem of how to find the relative location of the .metallib.
// load from url
NSURL* metal_url = [NSURL fileURLWithPath: utl_str];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithURL: metal_url error:&error];
Or build it from source at runtime, which might take a long time because the kernel has to be compiled first.
// build library from source string
NSString* code_str = [NSString stringWithUTF8String: sources.c_str()];
library->_library = [at::mps::MPSDevice::getInstance()->device() newLibraryWithSource: code_str options: nil error:&error];
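For the first option, one way to sidestep the relative-path problem is to resolve the .metallib on the Python side and hand an absolute path to the extension. The following is only a minimal sketch: the helper name `find_metallib` and the file name `kernels.metallib` are hypothetical, and it assumes the .metallib is installed next to the Python module that loads the extension (for example via package data in setup.py).

```python
from pathlib import Path


def find_metallib(module_file: str, name: str = "kernels.metallib") -> Path:
    """Return the absolute path of a .metallib stored next to `module_file`.

    `name` is a hypothetical file name; use whatever the build step produces.
    """
    # Resolve symlinks so the result is stable no matter how the package
    # was imported (site-packages, editable install, etc.).
    path = Path(module_file).resolve().parent / name
    if not path.is_file():
        raise FileNotFoundError(f"Metal library not found: {path}")
    return path
```

A module in the package could then call `find_metallib(__file__)` and pass `str(path)` to the C++ side, which would feed it to `fileURLWithPath:` as in the snippet above.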
BuildExtension
If we do not build the Metal kernel at runtime, we need to set up the compiler for the Metal kernel in setup.py.
Since the build_ext provided by Python and PyTorch does not support building Metal, I patched the UnixCCompiler in BuildExtension to add the support. Both compile and link need to be updated:
# compile
def darwin_wrap_single_compile(obj, src, ext, cc_args, extra_postargs,
                               pp_opts) -> None:
    cflags = copy.deepcopy(extra_postargs)
    try:
        original_compiler = self.compiler.compiler_so
        if _is_metal_file(src):
            # use xcrun metal to compile metal file to `.air`
            metal = ['xcrun', 'metal']
            self.compiler.set_executable('compiler_so', metal)
            if isinstance(cflags, dict):
                cflags = cflags.get('metal', [])
            else:
                cflags = []
        elif isinstance(cflags, dict):
            cflags = cflags['cxx']
        original_compile(obj, src, ext, cc_args, cflags, pp_opts)
    finally:
        # always restore the original C/C++ compiler
        self.compiler.set_executable('compiler_so', original_compiler)
# link
def darwin_wrap_single_link(target_desc,
                            objects,
                            output_filename,
                            output_dir=None,
                            libraries=None,
                            library_dirs=None,
                            runtime_library_dirs=None,
                            export_symbols=None,
                            debug=0,
                            extra_preargs=None,
                            extra_postargs=None,
                            build_temp=None,
                            target_lang=None):
    if osp.splitext(objects[0])[1].lower() == '.air':
        for obj in objects:
            assert osp.splitext(obj)[1].lower() == '.air', \
                f'Expect .air file, but get {obj}.'
        # link `.air` with xcrun metallib
        linker = ['xcrun', 'metallib']
        self.compiler.spawn(linker + objects + ['-o', output_filename])
    else:
        return original_link(target_desc, objects, output_filename,
                             output_dir, libraries, library_dirs,
                             runtime_library_dirs, export_symbols,
                             debug, extra_preargs, extra_postargs,
                             build_temp, target_lang)
The code looks ... ugly. Hope there is a better way to do that.
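For reference, the two steps the patch performs (compile each .metal file to .air, then link the .air files into a .metallib) can also be written as standalone helpers that do not touch the compiler object. This is only a sketch under assumptions: all function names are hypothetical, `is_metal_file` is a guess at what `_is_metal_file` does (a `.metal` extension check), and the `-sdk macosx` argument follows Apple's documented command line rather than the patch above, which calls `xcrun metal` directly.

```python
import os.path as osp
import subprocess
from typing import Dict, List, Optional, Sequence, Union

Flags = Union[Sequence[str], Dict[str, Sequence[str]], None]


def is_metal_file(path: str) -> bool:
    """Assumed behaviour of `_is_metal_file`: check for a `.metal` extension."""
    return osp.splitext(path)[1].lower() == '.metal'


def select_cflags(src: str, extra_postargs: Flags) -> List[str]:
    """Pick flags like the patched compile step does.

    A dict is expected to look like {'cxx': [...], 'metal': [...]};
    a plain list is treated as C++ flags only.
    """
    if is_metal_file(src):
        if isinstance(extra_postargs, dict):
            return list(extra_postargs.get('metal', []))
        return []
    if isinstance(extra_postargs, dict):
        return list(extra_postargs['cxx'])
    return list(extra_postargs or [])


def metal_compile_cmd(src: str, air: str,
                      cflags: Optional[Sequence[str]] = None) -> List[str]:
    """Command that compiles one .metal source into an .air file."""
    return ['xcrun', '-sdk', 'macosx', 'metal', '-c', src, '-o', air,
            *(cflags or [])]


def metallib_link_cmd(airs: Sequence[str], output: str) -> List[str]:
    """Command that links .air files into a single .metallib."""
    for air in airs:
        if osp.splitext(air)[1].lower() != '.air':
            raise ValueError(f'Expected an .air file, but got {air}.')
    return ['xcrun', '-sdk', 'macosx', 'metallib', *airs, '-o', output]


def build_metallib(sources: Sequence[str], output: str,
                   extra_postargs: Flags = None) -> None:
    """Run both steps; requires the Xcode command line tools."""
    airs = []
    for src in sources:
        air = osp.splitext(src)[0] + '.air'
        cmd = metal_compile_cmd(src, air, select_cflags(src, extra_postargs))
        subprocess.check_call(cmd)
        airs.append(air)
    subprocess.check_call(metallib_link_cmd(airs, output))
```

Something like `build_metallib` could be called from a custom build step in setup.py before the C++ extension is built, which would leave `compile` and `link` unpatched. Whether that integrates cleanly with BuildExtension is untested here.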
So ... any advice?