 namespace at {
 namespace native {
 
+typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
+#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
+
 enum StdVarType {
   STANDARD_VARIANCE,
   STANDARD_DEVIATION
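
Note on the pair above: NormOpBlock is the type of an Objective-C block that rewrites the graph input before the norm reduction runs, and NormOpFn is a macro that spells out that block's parameter list at call sites. As a minimal sketch (this restates the macro expansion, nothing more), a declaration such as

    NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
      // build and return the tensor to be norm-reduced
    };

preprocesses to the ordinary block literal

    NormOpBlock norm_op_block =
        ^MPSGraphTensor* (mps::MPSBinaryCachedGraph* cachedGraph,
                          MPSGraphTensor* x1Tensor,
                          MPSGraphTensor* x2Tensor) {
      // build and return the tensor to be norm-reduced
    };

which is exactly how _cdist_forward_mps below declares its pairwise-difference block.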
@@ -397,11 +400,16 @@ Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims){
 
 void impl_func_norm_mps(
     const Tensor& input_tensor,
+    const Tensor& other_tensor,
     const OptionalScalarRef& opt_p,
     IntArrayRef dim,
     bool keepdim,
     optional<ScalarType> opt_dtype,
-    const Tensor& output_t) {
+    const Tensor& output_t,
+    bool cdist = false,
+    c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
+    NormOpBlock normOpBlock = nullptr
+    ) {
 
   namespace native_mps = at::native::mps;
   if (input_tensor.numel() == 0)
@@ -411,15 +419,15 @@ void impl_func_norm_mps(
   auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
   auto mps_input_dtype = native_mps::getMPSDataType(in_dtype);
 
-  IntArrayRef input_shape = input_t.sizes();
+  IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();
 
   for (int i = 0; i < dim.size(); i++) {
     auto wrap_dim = maybe_wrap_dim(dim[i], input_shape.size());
     TORCH_CHECK(wrap_dim < input_shape.size(),
                 "norm_out_mps: reduction dim must be in the range of input shape")
   }
 
-  using CachedGraph = native_mps::MPSUnaryCachedGraph;
+  using CachedGraph = native_mps::MPSBinaryCachedGraph;
 
   native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
 
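
The CachedGraph alias switches from MPSUnaryCachedGraph to MPSBinaryCachedGraph so each cache entry can hold a placeholder for the second operand. For orientation, the binary variant in the MPS OperationUtils header has roughly this shape (a sketch; the exact definition is an assumption, though the field names match how the struct is used below):

    struct MPSBinaryCachedGraph : public MPSCachedGraph {
      MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
      MPSGraphTensor* inputTensor_ = nil;
      MPSGraphTensor* otherTensor_ = nil;   // stays nil on the plain norm path
      MPSGraphTensor* outputTensor_ = nil;
    };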
@@ -449,6 +457,12 @@ void impl_func_norm_mps(
                                   num_output_dims,
                                   input_shape,
                                   axes);
+
+  if (cdist) {
+    apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
+    apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
+  }
+
   if (output_t.numel() == 0) {
     return;
   }
@@ -458,7 +472,8 @@ void impl_func_norm_mps(
   @autoreleasepool {
     NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
     string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
-    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + native_mps::getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
+    string op_key = cdist ? native_mps::getTensorsStringKey({input_tensor, other_tensor}) : native_mps::getMPSTypeString(input_t.scalar_type());
+    string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + op_key + ":p" + to_string(p) + ":" + keepdim_info;
 
     auto cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
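
The widened key matters because the cdist graph bakes both operands' ranks and shapes into ranked placeholders (see the next hunk), so graphs built for differently shaped inputs must not be shared. getTensorsStringKey encodes the dtype and shape of both tensors, whereas the plain norm path keys only on the input's scalar type. A rough sketch of the distinction (the key formats here are invented for illustration, not taken from the MPS utilities):

    // cdist path: both operands' dtypes and shapes enter the key
    std::string op_key = cdist ? "Float32[5,4,3]:Float32[5,6,3]"
                               : "Float32";  // plain norm: scalar type only
    std::string key = std::string("norm_out_mps:") + "2" + ":" + op_key + ":p2:keepdim=0";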
@@ -471,8 +486,16 @@ void impl_func_norm_mps(
         MPSGraph* mpsGraph = native_mps::make_mps_graph();
         newCachedGraph = new CachedGraph(mpsGraph);
 
-        MPSGraphTensor* inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
-        MPSGraphTensor* inputTensor = inputTensor_;
+        if (cdist) {
+          newCachedGraph->inputTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
+          newCachedGraph->otherTensor_ = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
+        } else {
+          newCachedGraph->inputTensor_ = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+        }
+
+        MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_)
+                                            : newCachedGraph->inputTensor_;
         if (opt_dtype.has_value()) {
           inputTensor = [mpsGraph castTensor:inputTensor
                                       toType:mps_input_dtype
@@ -534,7 +557,10 @@ void impl_func_norm_mps(
                                         name:nil];
         }
 
-        newCachedGraph->inputTensor_ = inputTensor_;
+        if (cdist) {
+          outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil];
+        }
+
         newCachedGraph->outputTensor_ = outputTensor;
       }
       return newCachedGraph;
@@ -543,6 +569,7 @@ void impl_func_norm_mps(
     }
 
     auto inputPlaceholder = native_mps::Placeholder();
+    auto otherPlaceholder = native_mps::Placeholder();
 
     if (apparent_input_shape)
       inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
@@ -551,10 +578,13 @@ void impl_func_norm_mps(
 
     auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
 
+    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
+    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
 
-    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
-      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-    };
+    if (cdist) {
+      otherPlaceholder = native_mps::Placeholder(cachedGraph->otherTensor_, other_tensor);
+      feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
+    }
 
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
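
feeds becomes an NSMutableDictionary because the second placeholder exists only on the cdist path. The alternative with the previous immutable literal would be to branch up front, e.g. (an illustrative rewrite, not part of the change):

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = cdist
        ? @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
             otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData() }
        : @{ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData() };

Appending into a mutable dictionary keeps the common feed unconditional, which is the design the patch settles on.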
@@ -570,7 +600,7 @@ void impl_func_norm_mps(
     IntArrayRef dim,
     bool keepdim,
     const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, c10::nullopt, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
 }
 
 TORCH_IMPL_FUNC(norm_dtype_out_mps)
@@ -580,7 +610,84 @@ void impl_func_norm_mps(
     bool keepdim,
     ScalarType dtype,
     const Tensor& result) {
-  impl_func_norm_mps(self, opt_p, dim, keepdim, dtype, result);
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
+}
+
+Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+  using namespace mps;
+  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
+  TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+  TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
+  auto device1 = x1.device().type();
+  TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
+  auto device2 = x2.device().type();
+  TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
+  TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+  TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+
+  int64_t c1 = x1.size(-1);
+  int64_t c2 = x2.size(-1);
+
+  auto dim1 = x1.dim();
+  auto dim2 = x2.dim();
+  int64_t mode = compute_mode.value_or(0);
+  TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
+
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+
+  // For batched inputs, collapse all dimensions except the last two into a
+  // single batch dimension whose size is their product; the last two
+  // dimensions stay the same.
+  IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
+  IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+  const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
+  std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
+  std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+
+  std::vector<int64_t> output_shape(expand_batch_portion);
+  output_shape.insert(output_shape.end(), {r1, r2});
+  Tensor result = at::empty(output_shape, x1.options());
+
+  NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+
+    MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
+    MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];
+
+    MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
+    MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];
+
+    NSMutableArray<MPSGraphTensor*>* inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
+    NSMutableArray<MPSGraphTensor*>* otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];
+
+    for (const auto i : c10::irange(tensor2_view[1])) {
+      inputArray[i] = inputBroadcastReshape;
+    }
+
+    for (const auto i : c10::irange(tensor1_view[1])) {
+      otherArray[i] = otherBroadcastReshape;
+    }
+
+    MPSGraphTensor* inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
+    MPSGraphTensor* otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];
+
+    MPSGraphTensor* inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor:inputTensorReshaped
+                                                              secondaryTensor:otherTensorReshaped
+                                                                         name:nil];
+    return inputTensorPNorm;
+  };
+
+  c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
+  impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
+  return result;
 }
 
 Tensor std_var_common_impl_mps(
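
Taken together, _cdist_forward_mps computes result[b][i][j] = ||x1[b][i] - x2[b][j]||_p. norm_op_block tiles the rows of x1 r2 times with interleave:YES and the whole of x2 r1 times with interleave:NO, so the two stacks of shape (batch, r1*r2, c) line up as every (i, j) row pair in row-major order; one subtraction plus the p-norm over dim 2 (the makeArrayRef<int64_t>(2) argument) then produces the flat distance matrix, and the reshape added to impl_func_norm_mps folds it back to (batch, r1, r2). A standalone C++ sketch of just the pairing scheme, on plain arrays (illustration only, not part of the diff):

    #include <cstdio>
    #include <vector>

    int main() {
      const int r1 = 2, r2 = 3, c = 1;
      std::vector<float> x1 = {0.f, 10.f};      // r1 x c
      std::vector<float> x2 = {1.f, 2.f, 3.f};  // r2 x c

      // Row k = i*r2 + j of the tiled stacks pairs x1's row i (each row
      // repeated r2 times in place, the interleaved concat) with x2's
      // row j (x2 repeated whole r1 times, the sequential concat).
      for (int i = 0; i < r1; i++) {
        for (int j = 0; j < r2; j++) {
          float diff = x1[i * c] - x2[j * c];
          std::printf("row %d = pair (%d,%d): diff %g\n", i * r2 + j, i, j, diff);
        }
      }
      return 0;
    }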