Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,22 @@ class Placeholder {
return _value == nullptr;
}

// Allocates `_viewOutput` through at::native::empty_mps using the sizes and
// scalar type of `src`, targeting kMPS.
// Precondition (checked with assert): `_viewOutput` holds no elements yet,
// i.e. this is expected to be called at most once per Placeholder.
void allocateViewTensor(const at::Tensor& src)
{
  assert(_viewOutput.numel() == 0);

  // The three optional arguments are deliberately left unset.
  _viewOutput = at::native::empty_mps(src.sizes(),
                                      src.scalar_type(),
                                      c10::nullopt,
                                      kMPS,
                                      c10::nullopt,
                                      c10::nullopt);
}

private:
MPSGraphTensor* _placeholder;
MPSGraphTensorData* _value;
Tensor _viewOutput;
};

void resize_tensor(Tensor* output);
Expand Down
119 changes: 73 additions & 46 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -265,73 +265,100 @@ void printTensorNDArray(const Tensor& t) {
[tdata printNDArray];
}

id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
assert (!src.is_contiguous());
// Looks up the graph cached for the strided view described by `src`.
// The lookup key is built by getStridedKey from the tensor together with its
// sizes, strides and storage offset. The raw LookUp result is returned as-is;
// callers in this file treat a null result as "no cached graph for this view".
MPSCachedGraph* _getCachedGraph(const at::Tensor& src) {
  MPSGraphCache* graphCache = MPSGraphCache::getInstance();
  string viewKey = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
  return graphCache->LookUp(viewKey);
}

// Runs a previously cached strided-view graph to gather `src` into `output`.
//   src            - view tensor; its sizes and scalar type describe the result data.
//   sourceBuffer   - MTLBuffer that is fed to the graph's input tensor.
//   mpsCachedGraph - cached graph for this view; must be non-nil (enforced by TORCH_CHECK).
//   output         - destination tensor; the graph result is bound to its storage buffer.
// Returns the MTLBuffer obtained from `output`'s storage.
id<MTLBuffer> _gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer, MPSCachedGraph* mpsCachedGraph, Tensor& output) {
TORCH_CHECK(mpsCachedGraph != nil);

id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* stream = getCurrentMPSStream();

struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
IntArrayRef size_;
IntArrayRef stride_;
int64_t storage_offset_;
};

CachedGraph* cachedGraph = static_cast<CachedGraph *>(mpsCachedGraph);

@autoreleasepool {
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
IntArrayRef size_;
IntArrayRef stride_;
int64_t storage_offset_;
MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
shape: [inputTensor shape]
dataType: [inputTensor dataType]] autorelease];
id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
shape: getMPSShape(src.sizes())
dataType: getMPSDataType(src.scalar_type())] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputTensor : inputTensorData
};

MPSGraphCache* cache_ = MPSGraphCache::getInstance();
string key = getStridedKey(src, src.sizes(), src.strides(), src.storage_offset());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if (cachedGraph) {
@autoreleasepool {
MPSGraphTensor* inputTensor = cachedGraph->inputTensor_;
auto output = at::native::empty_mps(
src.sizes(),
src.scalar_type(),
c10::nullopt,
kMPS,
c10::nullopt,
c10::nullopt);
MPSGraphTensorData* inputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: sourceBuffer
shape: [inputTensor shape]
dataType: [inputTensor dataType]] autorelease];
id<MTLBuffer> resultBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
MPSGraphTensorData* outputTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer: resultBuffer
shape: getMPSShape(src.sizes())
dataType: getMPSDataType(src.scalar_type())] autorelease];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputTensor : inputTensorData
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
cachedGraph->outputTensor_ : outputTensorData
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return resultBuffer;
}
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
cachedGraph->outputTensor_ : outputTensorData
};

runMPSGraph(stream, cachedGraph->graph(), feeds, results);
return resultBuffer;
}
}

// Gathers the strided view `src` into a freshly allocated MPS tensor when a
// cached view graph exists for it.
//   src          - view tensor to gather.
//   sourceBuffer - MTLBuffer holding the data the view refers to.
// Returns the MTLBuffer of the newly allocated result tensor, or nil when no
// graph is cached for this view.
// NOTE(review): the returned buffer comes from a tensor that is local to this
// function; whether it stays valid after return depends on how the MPS
// allocator recycles storage -- confirm against callers.
id<MTLBuffer> gatherViewTensor(const at::Tensor& src, id<MTLBuffer> sourceBuffer) {
  MPSCachedGraph* viewGraph = _getCachedGraph(src);
  if (!viewGraph) {
    // Nothing cached for this view: signal the caller to fall back.
    return nil;
  }

  Tensor gathered = at::native::empty_mps(src.sizes(),
                                          src.scalar_type(),
                                          c10::nullopt,
                                          kMPS,
                                          c10::nullopt,
                                          c10::nullopt);

  _gatherViewTensor(src, sourceBuffer, viewGraph, gathered);
  return __builtin_bit_cast(id<MTLBuffer>, gathered.storage().data());
}

// Variant of gatherViewTensor for callers that already own the destination
// tensor and have already looked up the cached view graph.
//   src            - view tensor to gather.
//   sourceBuffer   - MTLBuffer holding the data the view refers to.
//   output         - caller-allocated tensor that receives the gathered data.
//   mpsCachedGraph - cached graph for this view; must be non-nil.
// Returns the MTLBuffer obtained from `output`'s storage.
id<MTLBuffer> gatherViewTensorWithAllocatedMem(const at::Tensor& src, id<MTLBuffer> sourceBuffer, Tensor& output, MPSCachedGraph* mpsCachedGraph) {
  TORCH_CHECK(mpsCachedGraph != nil);

  _gatherViewTensor(src, sourceBuffer, mpsCachedGraph, output);

  id<MTLBuffer> gatheredBuffer = __builtin_bit_cast(id<MTLBuffer>, output.storage().data());
  return gatheredBuffer;
}

Placeholder::Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& src,
MPSShape *mpsShape, bool check_view)
{
Tensor src_ = src;
TORCH_CHECK(src_.is_mps(), "Placeholder storage has not been allocated on MPS device!");
// extract the pointer to MTLBuffer from the Tensor's storage
id<MTLBuffer> srcBuf = __builtin_bit_cast(id<MTLBuffer>, src.storage().data());
if (check_view && !src.is_contiguous()) {
id<MTLBuffer> gatherTensor = gatherViewTensor(src, srcBuf);
if (gatherTensor) {
srcBuf = gatherTensor;
if (check_view) {
MPSCachedGraph* cachedGraph = _getCachedGraph(src);
if (cachedGraph) {
allocateViewTensor(src);
id<MTLBuffer> gatherTensor = gatherViewTensorWithAllocatedMem(src, srcBuf, _viewOutput, cachedGraph);
if (gatherTensor) {
srcBuf = gatherTensor;
}
} else {
src_ = src.contiguous();
srcBuf = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
}
}

const size_t buf_size = [srcBuf length];

// tensor.numel() could be zero, but tensor is valid as long as the buffer size is non-zero.
Expand Down
26 changes: 13 additions & 13 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,19 @@
typedef MPSGraphTensor* (^BinaryOpBlock)(MPSGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define BinaryOpFn() MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* primary, MPSGraphTensor* secondary)

void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
void binaryOpTensor(const Tensor& self, const Tensor& other, const Tensor& output, std::string op_name, BinaryOpBlock binaryBlock)
{
// it's possible to receive empty tensors here
if (self_t.numel() == 0 || other_t.numel() == 0) {
if (self.numel() == 0 || other.numel() == 0) {
return;
}
MPSStream* mpsStream = getCurrentMPSStream();

const bool is_self_scalar = self_t.dim() == 0;
const bool is_other_scalar = other_t.dim() == 0;
const bool is_self_scalar = self.dim() == 0;
const bool is_other_scalar = other.dim() == 0;

Tensor self = is_self_scalar ? self_t : self_t.contiguous(at::MemoryFormat::Contiguous);
Tensor other = is_other_scalar ? other_t : other_t.contiguous(at::MemoryFormat::Contiguous);

const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other_t : self_t).scalar_type());
const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other_t : self_t).scalar_type());
const MPSDataType self_dtype = getMPSScalarType((is_self_scalar && !is_other_scalar ? other : self).scalar_type());
const MPSDataType other_dtype = getMPSScalarType((!is_other_scalar ? other : self).scalar_type());

MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
Expand All @@ -58,20 +55,23 @@ void binaryOpTensor(const Tensor& self_t, const Tensor& other_t, const Tensor& o
cachedGraph = static_cast<BinaryOpCachedGraph *>(tmpCachedGraph);
}

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
Placeholder selfPlaceholder;
Placeholder otherPlaceholder;

if (is_self_scalar) {
feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), self_dtype);
} else {
Placeholder selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self);
selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self, nullptr, true);
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
}
if (is_other_scalar) {
feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), other_dtype);
} else {
Placeholder otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other);
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other, nullptr, true);
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output, nullptr);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
Expand Down
9 changes: 3 additions & 6 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
// 0 sizes won't result in any change in the shape of the Tensor so we can
// skip it. Also if the memory is contiguous we don't need to do
// gather-scatter operations using graph.
if (size.size() > 0 && (!result.is_contiguous())) {
if (size.size() > 0) {

// If self itself was a view tensor, that means we need to chain the graphs
// else we will create a new entry in the cache
Expand Down Expand Up @@ -287,11 +287,6 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
} else {
dst = dst_;
}
dst._set_conj(dst_.is_conj());
src._set_conj(src_.is_conj());

dst._set_neg(dst_.is_neg());
src._set_neg(src_.is_neg());

auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
id<MTLBuffer> sourceBuffer = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
Expand Down Expand Up @@ -399,6 +394,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
options:options
deallocator:nil];
sourceOffset = uintptr_t(host_src) - uintptr_t(alignedPtr);
if (src_.is_view() || !src_.is_contiguous())
sourceOffset += src_.storage_offset() * src_.itemsize();

dispatch_sync(stream->queue(), ^() {
@autoreleasepool {
Expand Down
65 changes: 63 additions & 2 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,9 @@ def helper(input_shape, batch1_shape, batch2_shape):
output_cpu = torch.baddbmm(M_cpu, batch1_cpu, batch2_cpu, beta=beta, alpha=alpha)
output_mps = torch.baddbmm(M_mps, batch1_mps, batch2_mps, beta=beta, alpha=alpha)

print(output_cpu.shape)
print(output_mps.shape)
self.assertEqual(output_cpu, output_mps)
self.assertEqual(output_cpu.size(), output_mps.size())

helper(input_shape=(3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
helper(input_shape=(10, 3, 5), batch1_shape=(10, 3, 4), batch2_shape=(10, 4, 5))
helper(input_shape=(1, 77, 77), batch1_shape=(8, 77, 64), batch2_shape=(8, 64, 77))
Expand Down Expand Up @@ -1222,6 +1221,68 @@ def test_slice(self):
mps_slice4 = mps_x[1, :].to('cpu')
self.assertEqual(cpu_slice4, mps_slice4)

def test_slice_contiguous_view(self):
# https://github.com/pytorch/pytorch/issues/77750

def helper(operator):
t_mps = torch.tensor([1, 2, 3, 4], device="mps")
t_cpu = torch.tensor([1, 2, 3, 4], device="cpu")

# contiguous view
x_mps = t_mps[2:] # 3, 4
y_mps = t_mps[:2] # 1, 2

x_cpu = t_cpu[2:]
y_cpu = t_cpu[:2]

res_mps = res_cpu = None
if operator == "<=":
res_mps = x_mps <= y_mps
res_cpu = x_cpu <= y_cpu
if operator == "<":
res_mps = x_mps < y_mps
res_cpu = x_cpu < y_cpu
if operator == ">=":
res_mps = x_mps >= y_mps
res_cpu = x_cpu >= y_cpu
if operator == ">":
res_mps = x_mps >= y_mps
res_cpu = x_cpu >= y_cpu
if operator == "==":
res_mps = x_mps == y_mps
res_cpu = x_cpu == y_cpu
if operator == "!=":
res_mps = x_mps != y_mps
res_cpu = x_cpu != y_cpu

self.assertEqual(res_mps, res_cpu)

for op in ["<=", "<", ">=", ">", "==", "!="]:
helper(op)

def test_index_storage_offset(self):
# https://github.com/pytorch/pytorch/issues/78107

a = torch.tensor([8.2670e-01,-1.0293e+00])
b_cpu = a[0]
c_cpu = a[1]

# both 'b' and 'c' are views of 'a'
# 'b' has a storage offset of 0, while 'c' has a storage offset of 1
# when copying from 'cpu' to 'mps', c will have a storage_offset of 1 which needs to be taking into account,
# otherwise it ends with same value as 'b'
b = b_cpu.to('mps')
c = c_cpu.to('mps')

res_mps = b > c
res_cpu = b_cpu > c_cpu
self.assertEqual(res_mps, res_cpu)


res_mps = c > b
res_cpu = c_cpu > b_cpu
self.assertEqual(res_mps, res_cpu)

def test_flatten(self):
values = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]
cpu_x = torch.tensor(values, device='cpu')
Expand Down