Skip to content

Commit

Permalink
Address PR comments (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 authored and kulinseth committed Aug 16, 2022
1 parent da50800 commit d6a6221
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 179 deletions.
156 changes: 91 additions & 65 deletions aten/src/ATen/mps/IndexKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,102 +7,128 @@ static const char * indexing_metal_shaders = R"INDEX_METAL(
#include <metal_stdlib>
using namespace metal;
constant int64_t storage_offset [[function_constant(0)]];
constant uint32_t num_indices [[function_constant(1)]];
constant uint32_t num_indices [[function_constant(0)]];
struct IndexAB {
// Allow up to 30 indices
metal::array<device void *, 30> indexArray [[ id(0) ]];
// Allow up to 16 indices
metal::array<constant const void *, 16> indexArray [[ id(0) ]];
};
template<typename T>
kernel void index_select(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
kernel void index_select(
constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
device const int64_t * index_sizes = (device const int64_t *)indexSizes;
device const int64_t * index_strides = (device const int64_t *)indexStrides;
constant const int64_t * index_sizes = (constant const int64_t *)indexSizes;
constant const int64_t * index_strides = (constant const int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
int64_t index = ((device const int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
int64_t index = ((constant const int64_t*)(indexAB.indexArray[i]))[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
}
offset += index * index_strides[i];
}
device T * out = (device T*)((device char*)outputData + offsets[thread_index].x);
device T * in = (device T*)((device char*)inputData + offsets[thread_index].y + offset + storage_offset * sizeof(T));
constant const T * in = (constant const T*)((constant const char*)inputData + offsets[thread_index].y + offset);
*out = *in;
}
template
[[host_name("index_select_float")]]
kernel void index_select<float>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
kernel void index_select<float>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_half")]]
kernel void index_select<half>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
kernel void index_select<half>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_long")]]
kernel void index_select<long>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int32")]]
kernel void index_select<int32_t>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
[[host_name("index_select_int")]]
kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int64")]]
kernel void index_select<int64_t>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
[[host_name("index_select_short")]]
kernel void index_select<short>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int16")]]
kernel void index_select<int16_t>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
[[host_name("index_select_char")]]
kernel void index_select<char>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_uint8")]]
kernel void index_select<uint8_t>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
[[host_name("index_select_uchar")]]
kernel void index_select<uchar>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_bool")]]
kernel void index_select<bool>(device const IndexAB & indexAB [[buffer(0)]],
device const void * indexSizes [[buffer(1)]],
device const void * indexStrides [[buffer(2)]],
device const uint3 * offsets [[buffer(3)]],
device const void * inputData [[buffer(4)]],
kernel void index_select<bool>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
)INDEX_METAL";
// Computes, for every thread, the per-operand offsets that correspond to its
// position in the iteration space.
//
// thread_index is decomposed over iter_shape with dimension 0 varying fastest:
// at each dimension the remainder of idx by that dimension's extent is the
// coordinate, and idx is then divided by the extent for the next dimension.
// Each coordinate is multiplied by that dimension's stride for every operand
// slot and accumulated into data_offsets[thread_index].
//
// Parameters:
//   strides        - one packed_uint3 per dimension; component k is the stride
//                    applied to offset slot k.
//   data_offsets   - one uint3 per thread; updated in place with +=.
//   iter_shape     - extent of each iterated dimension.
//   num_dimensions - number of entries read from strides / iter_shape.
//   num_offsets    - number of uint3 components to update per thread.
//
// NOTE(review): the kernel only accumulates, it never clears data_offsets, so
// the result is correct only if the host supplies a zero-filled buffer --
// confirm against the caller.
// NOTE(review): num_offsets indexes the components of a uint3 and is therefore
// presumably expected to be at most 3 -- confirm against the caller.
kernel void kernel_index_offsets(constant const packed_uint3 * strides         [[buffer(0)]],
                                 device uint3                 * data_offsets    [[buffer(1)]],
                                 constant const uint          * iter_shape      [[buffer(2)]],
                                 constant const uint          & num_dimensions  [[buffer(3)]],
                                 constant const uint          & num_offsets     [[buffer(4)]],
                                 uint thread_index [[thread_position_in_grid]]) {
    // Resolve this thread's output slot once instead of re-indexing the
    // buffer on every inner-loop iteration.
    device uint3 & thread_offsets = data_offsets[thread_index];

    uint32_t idx = thread_index;
    for (uint32_t dim = 0; dim < num_dimensions; dim++) {
        // Read the extent once; it is used for both the coordinate and the carry.
        const uint32_t extent = iter_shape[dim];
        const uint32_t remainder = idx % extent;
        idx /= extent;

        for (uint32_t offset = 0; offset < num_offsets; offset++) {
            thread_offsets[offset] += remainder * strides[dim][offset];
        }
    }
}
)INDEX_METAL";
}
}
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TORCH_API MPSDevice {
return _mtl_device;
}

MTLFunction_t metalFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);
MTLFunction_t metalIndexingFunction(const std::string &kernel, MTLFunctionConstantValues_t constantValues);

~MPSDevice();

Expand Down
19 changes: 13 additions & 6 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
return mps_device.get();
}

id<MTLFunction> MPSDevice::metalFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
id<MTLFunction> MPSDevice::metalIndexingFunction(const std::string& kernel, MTLFunctionConstantValues* constantValues) {
assert(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
Expand All @@ -55,20 +55,28 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}

id<MTLFunction> indexFunction = [_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String:kernel.c_str()]
constantValues: constantValues
error: &error];
id<MTLFunction> indexFunction = nil;
if (constantValues) {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]
constantValues: constantValues
error: &error] autorelease];
} else {
indexFunction = [[_mtl_indexing_library newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
}

TORCH_CHECK(indexFunction, "Failed to create specialized function state object: ", kernel, ", error: ", [[error description] UTF8String]);

return indexFunction;
}

// Gives up ownership of the Metal objects held by this wrapper and clears the
// member pointers so that no stale reference is left behind.
MPSDevice::~MPSDevice() {
  // Metal device handle.
  [_mtl_device release];
  _mtl_device = nil;

  // Indexing shader library. The member may still be nil at this point;
  // sending -release to nil is a no-op in Objective-C, so no guard is needed.
  [_mtl_indexing_library release];
  _mtl_indexing_library = nil;
}

MPSDevice::MPSDevice(): _mtl_device(nil) {
MPSDevice::MPSDevice(): _mtl_device(nil), _mtl_indexing_library(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 12.3+
// which is used by MPS backend.
Expand All @@ -91,7 +99,6 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
break;
}
}
_mtl_indexing_library = nil;
assert(_mtl_device);
}

Expand Down
56 changes: 56 additions & 0 deletions aten/src/ATen/native/mps/operations/Indexing.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/TensorFactory.h>
#include <c10/core/ScalarType.h>
#include <torch/library.h>
#include <unordered_map>

#ifdef __OBJC__
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#endif

using namespace at::mps;

namespace at {
namespace native {
namespace mps {

/**
 * Returns the Metal scalar type name that corresponds to an ATen ScalarType.
 *
 * The returned name is the suffix appended to a kernel-name prefix by
 * getIndexFunctionName (e.g. "float", "half", "uchar").
 *
 * Declared `inline` because this function is defined in a header: without it,
 * every translation unit that includes the header emits its own external
 * definition, which violates the one-definition rule and fails at link time.
 *
 * @param scalar_type  The tensor element type to translate.
 * @return The Metal type name, or an empty string when the type has no
 *         mapping here (callers receive a bare prefix in that case).
 */
inline std::string getMetalScalarType(ScalarType scalar_type) {
  switch (scalar_type) {
    case ScalarType::Float:
      return "float";
    case ScalarType::Half:
      return "half";
    case ScalarType::Long:
      return "long";
    case ScalarType::Int:
      return "int";
    case ScalarType::Short:
      return "short";
    case ScalarType::Char:
      return "char";
    case ScalarType::Byte:
      return "uchar";
    case ScalarType::Bool:
      return "bool";
    default:
      // Unmapped types intentionally produce an empty name, as before.
      return "";
  }
}

/**
 * Builds the name of the indexing kernel to look up for a given element type
 * and operation.
 *
 * Declared `inline` because this function is defined in a header; a non-inline
 * definition would be emitted once per including translation unit and clash
 * at link time.
 *
 * @param scalar_type   Element type; its Metal name is appended as a suffix.
 * @param index_select  When true the "index_select_" prefix is used and
 *                      `accumulate` is ignored.
 * @param accumulate    When true (and index_select is false) selects the
 *                      "index_put_accumulate_" prefix, except for kBool,
 *                      which always uses the plain "index_put_" prefix.
 * @return The prefix followed by getMetalScalarType(scalar_type).
 */
inline std::string getIndexFunctionName(ScalarType scalar_type, bool index_select, bool accumulate) {
  // Plain index_put is the fallback when neither special case applies.
  const char* prefix = "index_put_";
  if (index_select) {
    prefix = "index_select_";
  } else if (accumulate && scalar_type != kBool) {
    prefix = "index_put_accumulate_";
  }

  return prefix + getMetalScalarType(scalar_type);
}

}
}
}
Loading

0 comments on commit d6a6221

Please sign in to comment.