
Commit 4ce1278

Add amax and amin with tests
1 parent 3cec28f commit 4ce1278
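
This commit adds native MPS (Apple-GPU) kernels for torch.amax and torch.amin, the dimension-wise max/min reductions that return values only. A minimal usage sketch of what it enables, assuming a build where torch.backends.mps.is_available() returns True (the shapes are illustrative):

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(2, 8, 4, 5, device="mps")
        # Reduce over one or more dims; unlike torch.max(dim=...), no indices are returned.
        print(torch.amax(x, dim=(0, 1)).shape)           # torch.Size([4, 5])
        print(torch.amin(x, dim=2, keepdim=True).shape)  # torch.Size([2, 8, 1, 5])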

File tree

3 files changed, +73 -0 lines changed


aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 27 additions & 0 deletions
@@ -220,6 +220,15 @@ void set_axes_and_shapes(const Tensor& input_t,
                                                    axes:axes
                                                    name:nil];
   }
+  else if(reduction_type == "amax") {
+    castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
+                                                       axes:axes
+                                                       name:nil];
+  } else if(reduction_type == "amin") {
+    castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
+                                                       axes:axes
+                                                       name:nil];
+  }
 
   MPSGraphTensor* outputTensor = nil;
 
@@ -294,6 +303,24 @@ inline ScalarType get_dtype_from_self(
   return src_type;
 }
 
+TORCH_IMPL_FUNC(amax_out_mps)
+(const Tensor& input_t,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& output_t) {
+
+  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
+}
+
+TORCH_IMPL_FUNC(amin_out_mps)
+(const Tensor& input_t,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& output_t) {
+
+  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
+}
+
 Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {
 
   auto num_dims = self.dim();
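
Inside reduction_out_mps (the shared helper both new kernels call), the reduction_type string picks the MPSGraph node: reductionMaximumWithTensor:axes:name: for "amax" and reductionMinimumWithTensor:axes:name: for "amin". Semantically these are plain multi-dim max/min reductions over the requested axes, equivalent to chaining single-dim max/min and keeping only the values; a small CPU-side sketch of that equivalence (shapes chosen to mirror the tests):

    import torch

    x = torch.randn(2, 8, 4, 5)
    # amax over dims (0, 1) == max over dim 0, then max over the (new) dim 0, values only
    assert torch.equal(torch.amax(x, dim=(0, 1)), x.max(dim=0).values.max(dim=0).values)
    # likewise for amin over dims (2, 3)
    assert torch.equal(torch.amin(x, dim=(2, 3)), x.min(dim=3).values.min(dim=2).values)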

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -3134,6 +3134,7 @@
   structured: True
   dispatch:
     CPU, CUDA: amax_out
+    MPS: amax_out_mps
 
 # Return: (Tensor output, Tensor indices)
 - func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@@ -3290,6 +3291,7 @@
   structured: True
   dispatch:
     CPU, CUDA: amin_out
+    MPS: amin_out_mps
 
 # TODO: Add this function to MPS dispatch key so that we avoid declaring it in
 # native_functions.yaml
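
These two dispatch entries bind the MPS dispatch key to the structured out= kernels implemented above, so both the functional and the preallocated-output forms of amax/amin route to amax_out_mps/amin_out_mps on 'mps' tensors. A hypothetical quick check of the out= path, assuming MPS is available (shapes are illustrative):

    import torch

    if torch.backends.mps.is_available():
        x = torch.randn(2, 8, 4, 5, device="mps")
        out = torch.empty(8, 4, 5, device="mps")
        # Preallocated-output form; with the dispatch entries above this hits amax_out_mps.
        torch.amax(x, dim=0, out=out)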

test/test_mps.py

Lines changed: 44 additions & 0 deletions
@@ -2755,6 +2755,50 @@ def helper(shape):
 
         helper((4, 5, 6, 7))
 
+    # Test forward amax
+    def test_amax(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amax(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
+    # Test forward amin
+    def test_amin(self):
+        def helper(shape, dim, keepdim):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+            x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+            result = torch.amin(x, dim=dim, keepdim=keepdim)
+            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)
+
+            cpu_grad = torch.randn(result_cpu.shape)
+            grad = cpu_grad.to('mps')
+
+            result_cpu.backward(gradient=cpu_grad)
+            result.backward(gradient=grad)
+
+            self.assertEqual(result, result_cpu)
+            self.assertEqual(x.grad, cpu_x.grad)
+
+        for dim in ([], [0], [0, 1], [2, 3]):
+            for keepdim in [False, True]:
+                helper((2, 8, 4, 5), dim, keepdim)
+
     # Test minimum and maximum
     def test_minimum_maximum(self):
         def helper(n, c, h, w):
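
The tests exercise both the forward result and the backward pass, comparing MPS against CPU. The autograd behaviour being compared includes amax/amin's tie handling: the incoming gradient is split evenly among all positions that attain the extremum (whereas torch.max(dim=...) routes it to a single index). A small CPU sketch of that behaviour:

    import torch

    x = torch.tensor([1., 3., 3., 2.], requires_grad=True)
    torch.amax(x, dim=0).backward()
    print(x.grad)  # tensor([0.0000, 0.5000, 0.5000, 0.0000]) -- gradient split across the tied maxima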
