diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index d21035608a6cd..ecc338e2bcd8b 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -11,6 +11,32 @@ static std::unique_ptr mps_device; static c10::once_flag mpsdev_init; +static inline MTLLanguageVersion getMetalLanguageVersion(const id& device) { + #if defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13 + #else + #error "Metal is not available on the current platform." + #endif + + // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffer and function constants) + MTLLanguageVersion languageVersion; + if (@available(macOS 13.0, *)) { + languageVersion = MTLLanguageVersion3_0; + } else if (@available(macOS 12.0, *)) { + languageVersion = MTLLanguageVersion2_4; + } else if (@available(macOS 11.0, *)) { + languageVersion = MTLLanguageVersion2_3; + } else if (@available(macOS 10.15, *)) { + languageVersion = MTLLanguageVersion2_2; + } else if (@available(macOS 10.14, *)) { + languageVersion = MTLLanguageVersion2_1; + } else if (@available(macOS 10.13, *)) { + languageVersion = MTLLanguageVersion2_0; + } + + TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2"); + return languageVersion; +} + MPSDevice* MPSDevice::getInstance() { c10::call_once(mpsdev_init, [] { mps_device = std::unique_ptr(new MPSDevice()); @@ -22,8 +48,11 @@ assert(_mtl_device); NSError* error = nil; if (!_mtl_indexing_library) { + MTLCompileOptions *options = [MTLCompileOptions new]; + [options setLanguageVersion: getMetalLanguageVersion(_mtl_device)]; + [options setFastMathEnabled: YES]; _mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding] - options: nil + options: options error: &error]; TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]); } diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index 982b86fbb6473..9eb58bbc343d9 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -59,6 +59,10 @@ bool dispatchIndexSelectKernel(TensorIteratorBase& iter, IntArrayRef index_size, using namespace mps; @autoreleasepool { + if (iter.numel() == 0) { + return true; + } + const Tensor& inputTensor = iter.tensor(1); Tensor outputTensor = iter.tensor(0);