3434
3535using namespace mps ;
3636
37+ NSArray <NSNumber *>* getTensorAxes (const Tensor& t) {
38+ int64_t ndim = t.dim ();
39+ auto axes = [NSMutableArray <NSNumber *> arrayWithCapacity:ndim];
40+ for (const auto i: c10::irange (ndim)) {
41+ axes[i] = [NSNumber numberWithInteger: i];
42+ }
43+ return axes;
44+ }
45+
3746void set_apparent_shapes (NSMutableArray <NSNumber *> * &apparent_out_shape,
3847 NSMutableArray <NSNumber *> * &apparent_in_shape,
3948 int64_t num_reduce_dims,
@@ -1091,19 +1100,13 @@ Tensor std_mps(
10911100Tensor min_max_mps (const Tensor& input_t ,
10921101 MPSReductionType reduction_type,
10931102 const std::string& func_name) {
1103+ TORCH_INTERNAL_ASSERT (input_t .scalar_type () != ScalarType::Long, " min/max not supported for Long dtype on MPS" );
10941104 using CachedGraph = MPSUnaryCachedGraph;
10951105
10961106 MPSGraphCache* cache_ = MPSGraphCache::getInstance ();
1097- IntArrayRef input_shape = input_t .sizes ();
1098-
1099- // Flatten the input tensor to reduce it to one value
1100- NSMutableArray <NSNumber *> *apparent_input_shape = [NSMutableArray <NSNumber *> arrayWithCapacity:1 ];
1101- int64_t num_in_elements = c10::multiply_integers (input_shape);
1102- apparent_input_shape[0 ] = [NSNumber numberWithInt: num_in_elements];
1103-
1107+ num_in_elements *= input_shape[i];
11041108 Tensor output_t = at::native::empty_mps ({}, input_t .scalar_type (), c10::nullopt , kMPS , c10::nullopt , c10::nullopt );
1105-
1106- if (output_t .numel () == 0 || num_in_elements == 0 ) {
1109+ if (output_t .numel () == 0 || input_t .numel () == 0 ) {
11071110 return output_t ;
11081111 }
11091112
@@ -1118,17 +1121,29 @@ Tensor min_max_mps(const Tensor& input_t,
11181121 MPSGraph* mpsGraph = make_mps_graph ();
11191122 newCachedGraph = new CachedGraph (mpsGraph);
11201123
1121- MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (input_t .scalar_type ()));
1124+ MPSGraphTensor* inputTensor = native_mps:: mpsGraphUnrankedPlaceHolder (mpsGraph, native_mps:: getMPSDataType (input_t .scalar_type ()));
11221125
11231126 MPSGraphTensor* outputTensor = nil ;
1127+ MPSGraphTensor* castInputTensor = nil ;
1128+
1129+ if (input_t .scalar_type () != ScalarType::Float &&
1130+ input_t .scalar_type () != ScalarType::Int &&
1131+ input_t .scalar_type () != ScalarType::Half) {
1132+ castInputTensor = [mpsGraph castTensor: inputTensor
1133+ toType: MPSDataTypeInt32
1134+ name: @" castInputTensor" ];
1135+ } else {
1136+ castInputTensor = inputTensor;
1137+ }
11241138
1139+ NSArray <NSNumber *>* axes = getTensorAxes (input_t );
11251140 if (reduction_type == MPSReductionType::MAX) {
1126- outputTensor = [mpsGraph reductionMaximumWithTensor: inputTensor
1127- axes: @[@ 0 ]
1141+ outputTensor = [mpsGraph reductionMaximumWithTensor: castInputTensor
1142+ axes: axes
11281143 name: nil ];
11291144 } else if (reduction_type == MPSReductionType::MIN) {
1130- outputTensor = [mpsGraph reductionMinimumWithTensor: inputTensor
1131- axes: @[@ 0 ]
1145+ outputTensor = [mpsGraph reductionMinimumWithTensor: castInputTensor
1146+ axes: axes
11321147 name: nil ];
11331148 }
11341149
@@ -1139,7 +1154,7 @@ Tensor min_max_mps(const Tensor& input_t,
11391154 });
11401155 }
11411156
1142- auto inputPlaceholder = Placeholder (cachedGraph->inputTensor_ , input_t , apparent_input_shape );
1157+ auto inputPlaceholder = Placeholder (cachedGraph->inputTensor_ , input_t );
11431158 auto outputPlaceholder = Placeholder (cachedGraph->outputTensor_ , output_t , @[@1 ]);
11441159
11451160 NSDictionary <MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -1175,6 +1190,7 @@ void min_max_out_mps(const Tensor& input_t,
11751190 const Tensor& indices_t ,
11761191 MPSReductionType reduction_type,
11771192 const std::string& func_name) {
1193+ TORCH_INTERNAL_ASSERT (input_t .scalar_type () != ScalarType::Long, " min/max not supported for Long dtype on MPS" );
11781194
11791195 if (output_t .numel () == 0 ) {
11801196 return ;
@@ -1222,7 +1238,7 @@ void min_max_out_mps(const Tensor& input_t,
12221238 MPSGraph* mpsGraph = make_mps_graph ();
12231239 newCachedGraph = new CachedGraph (mpsGraph);
12241240
1225- MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (input_t .scalar_type ()));
1241+ MPSGraphTensor* inputTensor = native_mps:: mpsGraphUnrankedPlaceHolder (mpsGraph, native_mps:: getMPSDataType (input_t .scalar_type ()));
12261242 MPSGraphTensor* outputTensor = nil ;
12271243 if (reduction_type == MPSReductionType::MAX) {
12281244 outputTensor = [mpsGraph reductionMaximumWithTensor: inputTensor
@@ -1240,7 +1256,7 @@ void min_max_out_mps(const Tensor& input_t,
12401256 input_t .scalar_type () != ScalarType::Int &&
12411257 input_t .scalar_type () != ScalarType::Half) {
12421258 castInputTensor = [mpsGraph castTensor: inputTensor
1243- toType: MPSDataTypeFloat32
1259+ toType: MPSDataTypeInt32
12441260 name: @" castInputTensor" ];
12451261 } else {
12461262 castInputTensor = inputTensor;
0 commit comments