Skip to content

Commit 35103bd

Browse files
committed
Exclude long dtype from reduction ops (min/max) (#138)
* Exclude long dtype from reduction ops (min/max) * Remove tab indentation
1 parent 94a6d72 commit 35103bd

File tree

2 files changed

+51
-17
lines changed

2 files changed

+51
-17
lines changed

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@
3434

3535
using namespace mps;
3636

37+
NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
38+
int64_t ndim = t.dim();
39+
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
40+
for (const auto i: c10::irange(ndim)) {
41+
axes[i] = [NSNumber numberWithInteger:i];
42+
}
43+
return axes;
44+
}
45+
3746
void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
3847
NSMutableArray<NSNumber*> * &apparent_in_shape,
3948
int64_t num_reduce_dims,
@@ -1091,19 +1100,13 @@ Tensor std_mps(
10911100
Tensor min_max_mps(const Tensor& input_t,
10921101
MPSReductionType reduction_type,
10931102
const std::string& func_name) {
1103+
TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
10941104
using CachedGraph = MPSUnaryCachedGraph;
10951105

10961106
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
1097-
IntArrayRef input_shape = input_t.sizes();
1098-
1099-
// Flatten the input tensor to reduce it to one value
1100-
NSMutableArray<NSNumber*> *apparent_input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
1101-
int64_t num_in_elements = c10::multiply_integers(input_shape);
1102-
apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements];
1103-
1107+
num_in_elements *= input_shape[i];
11041108
Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
1105-
1106-
if (output_t.numel() == 0 || num_in_elements == 0) {
1109+
if (output_t.numel() == 0 || input_t.numel() == 0) {
11071110
return output_t;
11081111
}
11091112

@@ -1118,17 +1121,29 @@ Tensor min_max_mps(const Tensor& input_t,
11181121
MPSGraph* mpsGraph = make_mps_graph();
11191122
newCachedGraph = new CachedGraph(mpsGraph);
11201123

1121-
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
1124+
MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
11221125

11231126
MPSGraphTensor* outputTensor = nil;
1127+
MPSGraphTensor* castInputTensor = nil;
1128+
1129+
if(input_t.scalar_type() != ScalarType::Float &&
1130+
input_t.scalar_type() != ScalarType::Int &&
1131+
input_t.scalar_type() != ScalarType::Half) {
1132+
castInputTensor = [mpsGraph castTensor:inputTensor
1133+
toType:MPSDataTypeInt32
1134+
name:@"castInputTensor"];
1135+
} else {
1136+
castInputTensor = inputTensor;
1137+
}
11241138

1139+
NSArray<NSNumber*>* axes = getTensorAxes(input_t);
11251140
if (reduction_type == MPSReductionType::MAX) {
1126-
outputTensor = [mpsGraph reductionMaximumWithTensor: inputTensor
1127-
axes: @[@0]
1141+
outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
1142+
axes:axes
11281143
name: nil];
11291144
} else if (reduction_type == MPSReductionType::MIN) {
1130-
outputTensor = [mpsGraph reductionMinimumWithTensor: inputTensor
1131-
axes: @[@0]
1145+
outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
1146+
axes:axes
11321147
name: nil];
11331148
}
11341149

@@ -1139,7 +1154,7 @@ Tensor min_max_mps(const Tensor& input_t,
11391154
});
11401155
}
11411156

1142-
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
1157+
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
11431158
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);
11441159

11451160
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
@@ -1175,6 +1190,7 @@ void min_max_out_mps(const Tensor& input_t,
11751190
const Tensor& indices_t,
11761191
MPSReductionType reduction_type,
11771192
const std::string& func_name) {
1193+
TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "min/max not supported for Long dtype on MPS");
11781194

11791195
if (output_t.numel() == 0) {
11801196
return;
@@ -1222,7 +1238,7 @@ void min_max_out_mps(const Tensor& input_t,
12221238
MPSGraph* mpsGraph = make_mps_graph();
12231239
newCachedGraph = new CachedGraph(mpsGraph);
12241240

1225-
MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
1241+
MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
12261242
MPSGraphTensor* outputTensor = nil;
12271243
if (reduction_type == MPSReductionType::MAX) {
12281244
outputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
@@ -1240,7 +1256,7 @@ void min_max_out_mps(const Tensor& input_t,
12401256
input_t.scalar_type() != ScalarType::Int &&
12411257
input_t.scalar_type() != ScalarType::Half) {
12421258
castInputTensor = [mpsGraph castTensor:inputTensor
1243-
toType:MPSDataTypeFloat32
1259+
toType:MPSDataTypeInt32
12441260
name:@"castInputTensor"];
12451261
} else {
12461262
castInputTensor = inputTensor;

test/test_mps.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,6 +1880,24 @@ def helper(x, other):
18801880
helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(True))
18811881
helper(self._wrap_tensor((1, 0, 1, 0)), self._wrap_tensor(False))
18821882

1883+
def test_min_max(self):
1884+
def helper(dtype):
1885+
for _ in range(10):
1886+
if dtype == torch.float32 or dtype == torch.float16:
1887+
x = torch.randn((30, 15), device='mps', dtype=dtype)
1888+
else:
1889+
x = torch.randint(0, 100, (30, 15), device="mps", dtype=dtype)
1890+
x_cpu = x.to("cpu")
1891+
1892+
y = x.max()
1893+
y_cpu = x_cpu.max()
1894+
self.assertEqual(y, y_cpu)
1895+
1896+
z = x.min()
1897+
z_cpu = x_cpu.min()
1898+
self.assertEqual(z, z_cpu)
1899+
1900+
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
18831901

18841902
class TestSmoothL1Loss(TestCase):
18851903

0 commit comments

Comments
 (0)