diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 4090212d9504f..41c1c8aa8873e 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -128,6 +128,13 @@ struct MPSUnaryCachedGraph : public MPSCachedGraph
   MPSGraphTensor *outputTensor_ = nil;
 };
 
+struct MPSBinaryCachedGraph : public MPSCachedGraph
+{
+  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor *inputTensor_ = nil;
+  MPSGraphTensor *otherTensor_ = nil;
+  MPSGraphTensor *outputTensor_ = nil;
+};
 
 // TODO: Improve the overall design of MPSGraphCache.
 // https://github.com/pytorch/pytorch/issues/77176
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 103a8832c1187..bc53b6d666471 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -13,6 +13,9 @@
 namespace at {
 namespace native {
 
+typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
+#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
+
 enum StdVarType {
   STANDARD_VARIANCE,
   STANDARD_DEVIATION
@@ -397,11 +400,16 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
 
 void impl_func_norm_mps(
     const Tensor& input_tensor,
+    const Tensor& other_tensor,
     const OptionalScalarRef& opt_p,
     IntArrayRef dim,
     bool keepdim,
     optional<ScalarType> opt_dtype,
-    const Tensor& output_t) {
+    const Tensor& output_t,
+    bool cdist = false,
+    c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
+    NormOpBlock normOpBlock = nullptr
+    ) {
 
   namespace native_mps = at::native::mps;
 
   if (input_tensor.numel() == 0)
@@ -411,7 +419,7 @@ void impl_func_norm_mps(
   auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
   auto mps_input_dtype = native_mps::getMPSDataType(in_dtype);
 
-  IntArrayRef input_shape = input_t.sizes();
+  IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();
 
   for(int i = 0; i < dim.size(); i++) {
     auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
@@ -419,7 +427,7 @@ void impl_func_norm_mps(
                 "norm_out_mps: reduction dim must be in the range of input shape")
   }
 
-  using CachedGraph = native_mps::MPSUnaryCachedGraph;
+  using CachedGraph = native_mps::MPSBinaryCachedGraph;
 
   native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
@@ -449,6 +457,12 @@ void impl_func_norm_mps(
                                        num_output_dims, input_shape, axes);
+
+  if (cdist) {
+    apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
+    apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
+  }
+
   if (output_t.numel() == 0) {
     return;
   }
@@ -458,7 +472,8 @@ void impl_func_norm_mps(
   @autoreleasepool {
     NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
-    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
+    string tensor_key = cdist ? native_mps::getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t});
+    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info;
 
     auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
@@ -471,8 +486,15 @@ void impl_func_norm_mps(
           MPSGraph* mpsGraph = native_mps::make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
-          MPSGraphTensor* inputTensor = inputTensor_;
+          if (cdist) {
+            newCachedGraph->inputTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
+            newCachedGraph->otherTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
+          } else {
+            newCachedGraph->inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+          }
+
+          MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_)
+                                              : newCachedGraph->inputTensor_;
 
           if (opt_dtype.has_value()) {
             inputTensor = [mpsGraph castTensor:inputTensor
                                         toType:mps_input_dtype
@@ -534,7 +556,10 @@ void impl_func_norm_mps(
                                              name:nil];
           }
 
-          newCachedGraph->inputTensor_ = inputTensor_;
+          if (cdist) {
+            outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil];
+          }
+
           newCachedGraph->outputTensor_ = outputTensor;
         }
         return newCachedGraph;
@@ -543,6 +568,7 @@ void impl_func_norm_mps(
     }
 
     auto inputPlaceholder = native_mps::Placeholder();
+    auto otherPlaceholder = native_mps::Placeholder();
 
     if(apparent_input_shape)
       inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
@@ -551,10 +577,13 @@ void impl_func_norm_mps(
 
     auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
 
+    NSMutableDictionary *feeds = [NSMutableDictionary dictionary];
+    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
 
-    NSDictionary *feeds = @{
-      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-    };
+    if (cdist) {
+      otherPlaceholder = native_mps::Placeholder(cachedGraph->otherTensor_, other_tensor);
+      feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
+    }
 
     NSDictionary *results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
@@ -570,7 +599,7 @@ void impl_func_norm_mps(
     IntArrayRef dim,
     bool keepdim,
     const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, c10::nullopt, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
 }
 
 TORCH_IMPL_FUNC(norm_dtype_out_mps)
@@ -580,7 +609,84 @@ void impl_func_norm_mps(
     bool keepdim,
     ScalarType dtype,
     const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, dtype, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
+}
+
+Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+  using namespace mps;
+  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
X1: ", x1.size(-1), " X2: ", x2.size(-1)); + TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); + auto device1 = x1.device().type(); + TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); + auto device2 = x2.device().type(); + TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); + TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2); + TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")"); + + int64_t c1 = x1.size(-1); + int64_t c2 = x2.size(-1); + + auto dim1 = x1.dim(); + auto dim2 = x2.dim(); + int64_t mode = compute_mode.value_or(0); + TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode); + + int64_t r1 = x1.size(-2); + int64_t r2 = x2.size(-2); + + //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them. + //The last two dimensions will stay the same + IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); + IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); + std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); + std::vector tensor1_expand_size(expand_batch_portion); + tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1}); + std::vector tensor2_expand_size(expand_batch_portion); + tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); + + const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion); + std::vector tensor1_view{expand_batch_product, r1, c1}; + std::vector tensor2_view{expand_batch_product, r2, c2}; + + std::vector output_shape(expand_batch_portion); + output_shape.insert(output_shape.end(), {r1, r2}); + Tensor result = at::empty(output_shape, x1.options()); + + NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) { + MPSGraph* mpsGraph = cachedGraph->graph(); + + MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil]; + MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil]; + + MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil]; + MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil]; + + NSMutableArray *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]]; + NSMutableArray *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]]; + + for (const auto i : c10::irange(tensor2_view[1])) { + inputArray[i] = inputBroadcastReshape; + } + + for (const auto i : c10::irange(tensor1_view[1])) { + otherArray[i] = otherBroadcastReshape; + } + + MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil]; + MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil]; + + + MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped + secondaryTensor: otherTensorReshaped + name: nil]; + return inputTensorPNorm; + }; + + c10::optional inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size())); + impl_func_norm_mps(x1, x2, OptionalScalarRef(p), 
+  impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
+  return result;
 }
 
 Tensor std_var_common_impl_mps(
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index fdc4c88cbcca6..49361d1c19a00 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3901,6 +3901,7 @@
 - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
   dispatch:
     CPU, CUDA: _cdist_forward
+    MPS: _cdist_forward_mps
   autogen: _cdist_forward.out
 
 - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
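The interleave/tile construction in `norm_op_block` above is the core of the port. A minimal PyTorch sketch of the same arithmetic, for reviewers (the name `cdist_via_interleave` is hypothetical, purely for illustration): repeat each x1 row r2 times, tile x2 r1 times, subtract, reduce with a p-norm, and unflatten.

```python
import torch

def cdist_via_interleave(x1: torch.Tensor, x2: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    # Mirrors norm_op_block after batch flattening: x1 is (B, r1, c), x2 is (B, r2, c).
    B, r1, c = x1.shape
    r2 = x2.shape[-2]
    # concatTensors(..., interleave:YES) over r2 copies of x1 == repeat each row r2 times
    lhs = x1.repeat_interleave(r2, dim=1)   # (B, r1*r2, c)
    # concatTensors(..., interleave:NO) over r1 copies of x2 == tile the whole block r1 times
    rhs = x2.repeat(1, r1, 1)               # (B, r1*r2, c)
    diff = lhs - rhs                        # row i*r2 + j holds x1[:, i] - x2[:, j]
    # p-norm over the feature dim, then unflatten to the (B, r1, r2) distance matrix
    return diff.norm(p=p, dim=-1).reshape(B, r1, r2)

# Sanity check against the reference implementation.
x1, x2 = torch.randn(2, 4, 3), torch.randn(2, 5, 3)
assert torch.allclose(cdist_via_interleave(x1, x2), torch.cdist(x1, x2), atol=1e-6)
```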
diff --git a/test/test_mps.py b/test/test_mps.py
index 8a616dee95701..1f70febd622e3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -348,6 +348,151 @@ def helper(dtype):
             self.assertEqual(res2, res2_cpu)
         [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]
 
+    def test_cdist_large(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(1000, 10, device=device)
+            y = torch.randn(1000, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_large_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 1000, 10, device=device)
+            y = torch.randn(4, 3, 1000, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(5, 7, device=device).mT
+            y = torch.randn(5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 2, 7, 5, device=device)
+            y = torch.randn(7, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(4, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_euclidean_large(self, device="mps"):
+        def _test_euclidean_large_cdist(sizex, sizey=None):
+            if sizey is None:
+                sizey = sizex
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            y = torch.randn(sizey, device=device, dtype=torch.float)
+            eps = 1e-6
+            # to avoid extremum
+            x = x - (((x - y) < eps).float() * 2 * eps)
+            x.requires_grad = True
+            y.requires_grad = True
+            dist = torch.cdist(x, y, p=2)
+            # Do a backward pass to check that it is valid for large
+            # matrices
+            loss = dist.sum()
+            loss.backward()
+
+        _test_euclidean_large_cdist((2000, 5))
+
+    def test_cdist_same_inputs(self, device="mps"):
+        # Test to detect issues in cdist gradient calculation
+        # When the distances are 0
+        sizex = (1, 27, 32)
+        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+            y = x.clone()
+            eps = 1e-6
+            x.requires_grad = True
+            d = torch.cdist(x, y)
+            d.backward(dist_grad)
+            # Check that the backward pass does not contain invalid
+            # values such as nan or inf
+            assert torch.isfinite(x.grad).all()
+
+    def _brute_cdist(self, x, y, p=2):
+        r1 = x.shape[-2]
+        r2 = y.shape[-2]
+        if r1 == 0 or r2 == 0:
+            return torch.empty(r1, r2, device=x.device)
+        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
+    def test_cdist_norm(self, device="mps"):
+        for r1 in [3, 4, 5, 6]:
+            for m in [2, 3, 4, 10]:
+                for r2 in [4, 6, 7, 8]:
+                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(r1, m, device=device)
+                        y = torch.randn(r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
+    def test_cdist_norm_batch(self, device="mps"):
+        for r1 in [3, 4, 5, 6]:
+            for m in [2, 3, 4, 10]:
+                for r2 in [4, 6, 7, 8]:
+                    for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(2, 3, 6, r1, m, device=device)
+                        y = torch.randn(2, 3, 6, r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
     def test_cross(self):
         a = torch.randn(4, 3, device="mps")
         b = torch.randn(4, 3, device="mps")
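For a quick smoke test outside the suite, a minimal sketch (assuming a local MPS-enabled build; tolerances chosen loosely for float32):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.randn(4, 3, 100, 10)
    y = torch.randn(4, 3, 80, 10)
    expected = torch.cdist(x, y, p=2)                    # CPU reference
    actual = torch.cdist(x.to("mps"), y.to("mps"), p=2)  # new MPS path
    torch.testing.assert_close(expected, actual.cpu(), rtol=1e-4, atol=1e-4)
```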