7 changes: 7 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
@@ -128,6 +128,13 @@ struct MPSUnaryCachedGraph : public MPSCachedGraph
MPSGraphTensor *outputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph
{
MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor_ = nil;
MPSGraphTensor *otherTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
};

// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
130 changes: 118 additions & 12 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -13,6 +13,9 @@
namespace at {
namespace native {

typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
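Note: the NormOpBlock/NormOpFn pair lets a caller splice a preprocessing step in front of the shared p-norm reduction; the cdist path below uses it to turn the two operands into a tensor of pairwise differences. A minimal PyTorch sketch of the equivalence being exploited (plain torch ops, not the MPS kernel itself):

    import torch

    def cdist_reference(x1, x2, p=2.0):
        # x1: (..., r1, c), x2: (..., r2, c) -> (..., r1, r2)
        diff = x1[..., :, None, :] - x2[..., None, :, :]  # broadcast pairwise differences
        return torch.linalg.vector_norm(diff, ord=p, dim=-1)

    x1 = torch.randn(4, 3, 7)
    x2 = torch.randn(4, 5, 7)
    assert torch.allclose(cdist_reference(x1, x2), torch.cdist(x1, x2), atol=1e-5)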

enum StdVarType {
STANDARD_VARIANCE,
STANDARD_DEVIATION
@@ -397,11 +400,16 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){

void impl_func_norm_mps(
const Tensor& input_tensor,
const Tensor& other_tensor,
const OptionalScalarRef& opt_p,
IntArrayRef dim,
bool keepdim,
optional<ScalarType> opt_dtype,
const Tensor& output_t) {
const Tensor& output_t,
bool cdist = false,
c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
NormOpBlock normOpBlock = nullptr
) {

namespace native_mps = at::native::mps;
if (input_tensor.numel() == 0)
@@ -411,15 +419,15 @@ void impl_func_norm_mps(
auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
auto mps_input_dtype = native_mps::getMPSDataType(in_dtype);

IntArrayRef input_shape = input_t.sizes();
IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();

for(int i = 0; i < dim.size(); i++) {
auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
TORCH_CHECK(wrap_dim < input_shape.size(),
"norm_out_mps: reduction dim must be in the range of input shape");
}

using CachedGraph = native_mps::MPSUnaryCachedGraph;
using CachedGraph = native_mps::MPSBinaryCachedGraph;

native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

@@ -449,6 +457,12 @@
num_output_dims,
input_shape,
axes);

if (cdist) {
apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
}

if (output_t.numel() == 0) {
return;
}
@@ -458,7 +472,8 @@
@autoreleasepool {
NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
string tensor_key = cdist ? native_mps::getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t});
string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info;

auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);

@@ -471,8 +486,15 @@
MPSGraph* mpsGraph = native_mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphTensor* inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
MPSGraphTensor* inputTensor = inputTensor_;
if (cdist) {
newCachedGraph->inputTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
newCachedGraph->otherTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
} else {
newCachedGraph->inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
}

MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) :
newCachedGraph->inputTensor_;
if (opt_dtype.has_value()) {
inputTensor = [mpsGraph castTensor:inputTensor
toType:mps_input_dtype
@@ -534,7 +556,10 @@ void impl_func_norm_mps(
name:nil];
}

newCachedGraph->inputTensor_ = inputTensor_;
if (cdist) {
outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil];
}

newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
@@ -543,6 +568,7 @@ void impl_func_norm_mps(
}

auto inputPlaceholder = native_mps::Placeholder();
auto otherPlaceholder = native_mps::Placeholder();

if(apparent_input_shape)
inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
@@ -551,10 +577,13 @@

auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);

NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
};
if (cdist) {
otherPlaceholder = native_mps::Placeholder(cachedGraph->otherTensor_, other_tensor);
feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
}

NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
@@ -570,7 +599,7 @@
IntArrayRef dim,
bool keepdim,
const Tensor& result) {
impl_func_norm_mps(self, opt_p, dim, keepdim, c10::nullopt, result);
impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
}

TORCH_IMPL_FUNC(norm_dtype_out_mps)
@@ -580,7 +609,84 @@ void impl_func_norm_mps(
bool keepdim,
ScalarType dtype,
const Tensor& result) {
impl_func_norm_mps(self, opt_p, dim, keepdim, dtype, result);
impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
}

Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
using namespace mps;
TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
auto device1 = x1.device().type();
TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
auto device2 = x2.device().type();
TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");

int64_t c1 = x1.size(-1);
int64_t c2 = x2.size(-1);

auto dim1 = x1.dim();
auto dim2 = x2.dim();
int64_t mode = compute_mode.value_or(0);
TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);

int64_t r1 = x1.size(-2);
int64_t r2 = x2.size(-2);

//For batch calculation we expand all dimensions (except the last two) into one whose size equals their product.
//The last two dimensions stay the same.
IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});

const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};

std::vector<int64_t> output_shape(expand_batch_portion);
output_shape.insert(output_shape.end(), {r1, r2});
Tensor result = at::empty(output_shape, x1.options());
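The expand-and-collapse above mirrors the existing CPU/CUDA cdist path: broadcast the leading batch dimensions together, then flatten them into a single batch axis so the graph only ever sees 3-D inputs. A small PyTorch sketch of the shape algebra (the concrete sizes are illustrative assumptions):

    import torch

    x1 = torch.randn(2, 1, 6, 4)                                  # (..., r1=6, c=4)
    x2 = torch.randn(1, 3, 5, 4)                                  # (..., r2=5, c=4)
    batch = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])  # (2, 3)
    t1 = x1.expand(*batch, 6, 4).reshape(-1, 6, 4)                # (6, 6, 4)
    t2 = x2.expand(*batch, 5, 4).reshape(-1, 5, 4)                # (6, 5, 4)
    output_shape = tuple(batch) + (6, 5)                          # (2, 3, 6, 5)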

NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
MPSGraph* mpsGraph = cachedGraph->graph();

MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];

MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];

NSMutableArray<MPSGraphTensor*> *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
NSMutableArray<MPSGraphTensor*> *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];

for (const auto i : c10::irange(tensor2_view[1])) {
inputArray[i] = inputBroadcastReshape;
}

for (const auto i : c10::irange(tensor1_view[1])) {
otherArray[i] = otherBroadcastReshape;
}

MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];

MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor:inputTensorReshaped
                                                          secondaryTensor:otherTensorReshaped
                                                                     name:nil];
return inputTensorPNorm;
};
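The two concats realize every (x1 row, x2 row) pair with a single elementwise subtraction: the copies of the input are concatenated with interleave:YES, so each of its r1 rows is repeated r2 times consecutively, while the copies of the other operand use interleave:NO, tiling the whole r2-row block r1 times. Row k = i*r2 + j of the two results then holds x1[..., i, :] and x2[..., j, :] respectively. A hedged PyTorch sketch of the same construction (my reading of the interleave semantics, checked against torch.cdist):

    import torch

    B, r1, r2, c = 6, 4, 5, 7
    t1 = torch.randn(B, r1, c)
    t2 = torch.randn(B, r2, c)
    lhs = t1.repeat_interleave(r2, dim=1)                       # rows: a a a a a  b b b b b ...
    rhs = t2.repeat(1, r1, 1)                                   # rows: 0 1 2 3 4  0 1 2 3 4 ...
    pnorm = torch.linalg.vector_norm(lhs - rhs, ord=2, dim=-1)  # (B, r1*r2)
    dist = pnorm.reshape(B, r1, r2)
    assert torch.allclose(dist, torch.cdist(t1, t2), atol=1e-5)

The p-norm over the last axis and the reshape back to (batch..., r1, r2) happen later inside impl_func_norm_mps, via the dim={2} reduction and the cdist reshape branch.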

c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
return result;
}

Tensor std_var_common_impl_mps(
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3901,6 +3901,7 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
dispatch:
CPU, CUDA: _cdist_forward
MPS: _cdist_forward_mps
autogen: _cdist_forward.out

- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
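With this dispatch entry in place, torch.cdist on MPS tensors routes to _cdist_forward_mps. A usage sketch (assumes an MPS-capable macOS build; the float32 tolerance is a loose assumption):

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(128, 16, device="mps")
        y = torch.randn(64, 16, device="mps")
        d = torch.cdist(x, y, p=2)                   # (128, 64), computed on the GPU
        ref = torch.cdist(x.cpu(), y.cpu(), p=2)
        assert torch.allclose(d.cpu(), ref, atol=1e-4)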
145 changes: 145 additions & 0 deletions test/test_mps.py
@@ -348,6 +348,151 @@ def helper(dtype):
self.assertEqual(res2, res2_cpu)
[helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]

def test_cdist_large(self, device="mps"):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(1000, 10, device=device)
y = torch.randn(1000, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual)

def test_cdist_large_batch(self, device="mps"):
for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 1000, 10, device=device)
y = torch.randn(4, 3, 1000, 10, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual)

def test_cdist_non_contiguous(self, device="mps"):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(5, 7, device=device).mT
y = torch.randn(5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)

x = torch.randn(7, 5, device=device)
y = torch.randn(5, 3, device=device).t()
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)

x = torch.randn(5, 7, device=device).t()
y = torch.randn(3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertEqual(expected, actual)

def test_cdist_non_contiguous_batch(self, device="mps"):
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
x = torch.randn(4, 3, 2, 5, 7, device=device).mT
y = torch.randn(4, 3, 2, 5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)

x = torch.randn(7, 2, 7, 5, device=device)
y = torch.randn(7, 2, 5, 3, device=device).mT
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertTrue(x.is_contiguous())
self.assertFalse(y.is_contiguous())
self.assertEqual(expected, actual)

x = torch.randn(4, 5, 7, device=device).mT
y = torch.randn(4, 3, 5, device=device)
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertFalse(x.is_contiguous())
self.assertTrue(y.is_contiguous())
self.assertEqual(expected, actual)

def test_cdist_euclidean_large(self, device="mps"):
def _test_euclidean_large_cdist(sizex, sizey=None):
if sizey is None:
sizey = sizex
x = torch.randn(sizex, device=device, dtype=torch.float)
y = torch.randn(sizey, device=device, dtype=torch.float)
eps = 1e-6
# push x away from y so no pairwise difference is near zero (extremum of the norm)
x = x - (((x - y) < eps).float() * 2 * eps)
x.requires_grad = True
y.requires_grad = True
dist = torch.cdist(x, y, p=2)
# Do a backward pass to check that it is valid for large
# matrices
loss = dist.sum()
loss.backward()

_test_euclidean_large_cdist((2000, 5))

def test_cdist_same_inputs(self, device="mps"):
# Test to detect issues in cdist gradient calculation
# When the distances are 0
sizex = (1, 27, 32)
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
x = torch.randn(sizex, device=device, dtype=torch.float)
dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
y = x.clone()
eps = 1e-6
x.requires_grad = True
d = torch.cdist(x, y)
d.backward(dist_grad)
# Check that the backward pass does not contain invalid
# values such as nan or inf
assert torch.isfinite(x.grad).all()


def _brute_cdist(self, x, y, p=2):
r1 = x.shape[-2]
r2 = y.shape[-2]
if r1 == 0 or r2 == 0:
return torch.empty(r1, r2, device=x.device)
return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)

def test_cdist_norm(self, device="mps"):
for r1 in [3, 4, 5, 6]:
for m in [2, 3, 4, 10]:
for r2 in [4, 6, 7, 8]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
x = torch.randn(r1, m, device=device)
y = torch.randn(r2, m, device=device)
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual, rtol=0, atol=0.02)
else:
actual = torch.cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)

def test_cdist_norm_batch(self, device="mps"):
for r1 in [3, 4, 5, 6]:
for m in [2, 3, 4, 10]:
for r2 in [4, 6, 7, 8]:
for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
x = torch.randn(2, 3, 6, r1, m, device=device)
y = torch.randn(2, 3, 6, r2, m, device=device)
if p == 2:
for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
actual = torch.cdist(x, y, p=2, compute_mode=cm)
expected = self._brute_cdist(x, y, p=2)
self.assertEqual(expected, actual, rtol=0, atol=0.02)
else:
actual = torch.cdist(x, y, p=p)
expected = self._brute_cdist(x, y, p=p)
self.assertEqual(expected, actual)

def test_cross(self):
a = torch.randn(4, 3, device="mps")
b = torch.randn(4, 3, device="mps")
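The new tests can be run in isolation with a command like "python test/test_mps.py -k cdist" (assuming a macOS machine with an MPS device; -k is standard unittest filtering).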