Commit 8028e5d

Add support for cdist
1 parent 893b902 commit 8028e5d

File tree

4 files changed: 284 additions, 12 deletions


aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 7 additions & 0 deletions
@@ -128,6 +128,13 @@ struct MPSUnaryCachedGraph : public MPSCachedGraph
   MPSGraphTensor *outputTensor_ = nil;
 };
 
+struct MPSBinaryCachedGraph : public MPSCachedGraph
+{
+  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor *inputTensor_ = nil;
+  MPSGraphTensor *otherTensor_ = nil;
+  MPSGraphTensor *outputTensor_ = nil;
+};
 
 // TODO: Improve the overall design of MPSGraphCache.
 // https://github.com/pytorch/pytorch/issues/77176

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 119 additions & 12 deletions
@@ -13,6 +13,9 @@
 namespace at {
 namespace native {
 
+typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
+#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
+
 enum StdVarType {
   STANDARD_VARIANCE,
   STANDARD_DEVIATION
@@ -397,11 +400,16 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
 
 void impl_func_norm_mps(
     const Tensor& input_tensor,
+    const Tensor& other_tensor,
     const OptionalScalarRef& opt_p,
     IntArrayRef dim,
     bool keepdim,
     optional<ScalarType> opt_dtype,
-    const Tensor& output_t) {
+    const Tensor& output_t,
+    bool cdist = false,
+    c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
+    NormOpBlock normOpBlock = nullptr
+) {
 
   namespace native_mps = at::native::mps;
   if (input_tensor.numel() == 0)
@@ -411,15 +419,15 @@ void impl_func_norm_mps(
   auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
   auto mps_input_dtype = native_mps::getMPSDataType(in_dtype);
 
-  IntArrayRef input_shape = input_t.sizes();
+  IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();
 
   for(int i = 0; i < dim.size(); i++) {
     auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
     TORCH_CHECK(wrap_dim < input_shape.size(),
         "norm_out_mps: reduction dim must be in the range of input shape")
   }
 
-  using CachedGraph = native_mps::MPSUnaryCachedGraph;
+  using CachedGraph = native_mps::MPSBinaryCachedGraph;
 
   native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
 
@@ -449,6 +457,12 @@ void impl_func_norm_mps(
                        num_output_dims,
                        input_shape,
                        axes);
+
+  if (cdist) {
+    apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
+    apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
+  }
+
   if (output_t.numel() == 0) {
     return;
   }
@@ -458,7 +472,8 @@ void impl_func_norm_mps(
   @autoreleasepool {
     NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
-    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
+    string op_key = cdist ? native_mps::getTensorsStringKey({input_tensor, other_tensor}) : native_mps::getMPSTypeString(input_t.scalar_type());
+    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + op_key + ":p" + to_string(p) + ":" + keepdim_info;
 
     auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
@@ -471,8 +486,16 @@ void impl_func_norm_mps(
         MPSGraph* mpsGraph = native_mps::make_mps_graph();
         newCachedGraph = new CachedGraph(mpsGraph);
 
-        MPSGraphTensor* inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
-        MPSGraphTensor* inputTensor = inputTensor_;
+        if (cdist) {
+          newCachedGraph->inputTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
+          newCachedGraph->otherTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
+        }
+        else {
+          newCachedGraph->inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+        }
+
+        MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) :
+                                              newCachedGraph->inputTensor_;
         if (opt_dtype.has_value()) {
           inputTensor = [mpsGraph castTensor:inputTensor
                                       toType:mps_input_dtype
@@ -534,7 +557,10 @@ void impl_func_norm_mps(
                                              name:nil];
         }
 
-        newCachedGraph->inputTensor_ = inputTensor_;
+        if (cdist) {
+          outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil];
+        }
+
         newCachedGraph->outputTensor_ = outputTensor;
       }
       return newCachedGraph;
@@ -543,6 +569,7 @@ void impl_func_norm_mps(
   }
 
   auto inputPlaceholder = native_mps::Placeholder();
+  auto otherPlaceholder = native_mps::Placeholder();
 
   if(apparent_input_shape)
     inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
@@ -551,10 +578,13 @@ void impl_func_norm_mps(
 
   auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
 
+  NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
+  feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
 
-  NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
-    inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-  };
+  if (cdist) {
+    otherPlaceholder = native_mps::Placeholder(cachedGraph->otherTensor_, other_tensor);
+    feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
+  }
 
   NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
     outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
@@ -570,7 +600,7 @@ void impl_func_norm_mps(
  IntArrayRef dim,
  bool keepdim,
  const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, c10::nullopt, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
 }
 
 TORCH_IMPL_FUNC(norm_dtype_out_mps)
@@ -580,7 +610,84 @@ void impl_func_norm_mps(
  bool keepdim,
  ScalarType dtype,
  const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, dtype, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
+}
+
+Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+  using namespace mps;
+  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
+  TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+  TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
+  auto device1 = x1.device().type();
+  TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
+  auto device2 = x2.device().type();
+  TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
+  TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+  TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+
+  int64_t c1 = x1.size(-1);
+  int64_t c2 = x2.size(-1);
+
+  auto dim1 = x1.dim();
+  auto dim2 = x2.dim();
+  int64_t mode = compute_mode.value_or(0);
+  TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
+
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+
+  // For the batch computation we flatten all batch dimensions (everything except the last two)
+  // into a single dimension whose size is their product; the last two dimensions stay the same.
+  IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
+  IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+  const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
+  std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
+  std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+
+  std::vector<int64_t> output_shape(expand_batch_portion);
+  output_shape.insert(output_shape.end(), {r1, r2});
+  Tensor result = at::empty(output_shape, x1.options());
+
+  NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+
+    MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
+    MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];
+
+    MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
+    MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];
+
+    NSMutableArray<MPSGraphTensor*> *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
+    NSMutableArray<MPSGraphTensor*> *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];
+
+    for (const auto i : c10::irange(tensor2_view[1])) {
+      inputArray[i] = inputBroadcastReshape;
+    }
+
+    for (const auto i : c10::irange(tensor1_view[1])) {
+      otherArray[i] = otherBroadcastReshape;
+    }
+
+    MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
+    MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];
+
+
+    MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor:inputTensorReshaped
+                                                              secondaryTensor:otherTensorReshaped
+                                                                         name:nil];
+    return inputTensorPNorm;
+  };
+
+  c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
+  impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
+  return result;
 }
 
 Tensor std_var_common_impl_mps(
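For orientation, the graph assembled by norm_op_block together with the p-norm reduction in impl_func_norm_mps is equivalent to the eager-mode sketch below. The helper cdist_reference is hypothetical and not part of this commit; it only illustrates the broadcast, flatten-batch, pairwise-difference, and p-norm sequence, and checks the result against torch.cdist on CPU.

# Reference sketch of the shape bookkeeping done by _cdist_forward_mps (illustration only).
import torch

def cdist_reference(x1: torch.Tensor, x2: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    r1, c = x1.shape[-2], x1.shape[-1]
    r2 = x2.shape[-2]
    # Broadcast the batch dimensions to a common shape, then flatten them into a
    # single leading dimension (the kernel's expand_batch_product).
    batch = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
    x1b = x1.expand(*batch, r1, c).reshape(-1, r1, c)
    x2b = x2.expand(*batch, r2, c).reshape(-1, r2, c)
    # Form every (x1 row, x2 row) pair; the MPS graph builds these pairs with an
    # interleaved concat of x1 copies and a plain concat of x2 copies.
    lhs = x1b.repeat_interleave(r2, dim=1)   # (B, r1*r2, c)
    rhs = x2b.repeat(1, r1, 1)               # (B, r1*r2, c)
    # Pairwise difference, then a p-norm over the feature dimension
    # (impl_func_norm_mps with dim = {2}), reshaped back to (*batch, r1, r2).
    dist = (lhs - rhs).norm(p=p, dim=-1)
    return dist.reshape(*batch, r1, r2)

x1 = torch.randn(3, 5, 4)
x2 = torch.randn(3, 7, 4)
assert torch.allclose(cdist_reference(x1, x2), torch.cdist(x1, x2), atol=1e-5)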

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
@@ -3901,6 +3901,7 @@
 - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
   dispatch:
     CPU, CUDA: _cdist_forward
+    MPS: _cdist_forward_mps
   autogen: _cdist_forward.out
 
 - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
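With the dispatch entry above in place, torch.cdist on tensors that live on the "mps" device should route to _cdist_forward_mps. A minimal usage sketch, assuming a macOS build where torch.backends.mps.is_available() returns True:

import torch

if torch.backends.mps.is_available():
    x1 = torch.randn(2, 5, 3, device="mps")
    x2 = torch.randn(2, 7, 3, device="mps")
    d = torch.cdist(x1, x2, p=2.0)   # dispatched to the new MPS kernel
    print(d.shape)                   # torch.Size([2, 5, 7])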
