27 changes: 27 additions & 0 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -220,6 +220,15 @@ void set_axes_and_shapes(const Tensor& input_t,
                                                 axes:axes
                                                 name:nil];
    } else if (reduction_type == "amax") {
      castOutputTensor = [mpsGraph reductionMaximumWithTensor:inputTensor
                                                         axes:axes
                                                         name:nil];
    } else if (reduction_type == "amin") {
      castOutputTensor = [mpsGraph reductionMinimumWithTensor:inputTensor
                                                         axes:axes
                                                         name:nil];
    }

MPSGraphTensor* outputTensor = nil;

@@ -294,6 +303,24 @@ inline ScalarType get_dtype_from_self(
return src_type;
}

TORCH_IMPL_FUNC(amax_out_mps)
(const Tensor& input_t,
 IntArrayRef dim,
 bool keepdim,
 const Tensor& output_t) {
  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amax", "amax_out_mps");
}

TORCH_IMPL_FUNC(amin_out_mps)
(const Tensor& input_t,
 IntArrayRef dim,
 bool keepdim,
 const Tensor& output_t) {
  reduction_out_mps(input_t, dim, keepdim, c10::nullopt, output_t, "amin", "amin_out_mps");
}

Tensor prod_mps(const Tensor &self, c10::optional<ScalarType> opt_dtype) {

auto num_dims = self.dim();
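Taken together, the new branches map the "amax" and "amin" reduction types onto MPSGraph's reductionMaximumWithTensor:axes:name: and reductionMinimumWithTensor:axes:name: kernels, and the two TORCH_IMPL_FUNC wrappers forward to the shared reduction_out_mps helper. A minimal usage sketch (assumes a macOS build of PyTorch with MPS available; not part of this diff):

import torch

if torch.backends.mps.is_available():
    x_cpu = torch.randn(2, 8, 4, 5)
    x_mps = x_cpu.to('mps')

    # Reduce over dims (0, 1) on both devices; results should match.
    out_mps = torch.amax(x_mps, dim=(0, 1), keepdim=True)
    out_cpu = torch.amax(x_cpu, dim=(0, 1), keepdim=True)
    torch.testing.assert_close(out_mps.cpu(), out_cpu)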
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3134,6 +3134,7 @@
  structured: True
  dispatch:
    CPU, CUDA: amax_out
    MPS: amax_out_mps

# Return: (Tensor output, Tensor indices)
- func: max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
@@ -3290,6 +3291,7 @@
  structured: True
  dispatch:
    CPU, CUDA: amin_out
    MPS: amin_out_mps

# TODO: Add this function to MPS dispatch key so that we avoid declaring it in
# native_functions.yaml
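Since both entries are structured: True, registering only the out-variant kernels (amax_out_mps / amin_out_mps) under the MPS dispatch key should also cover the functional amax/amin variants, which delegate to them. An illustrative sanity check (assumes an MPS-enabled build; not part of this diff):

import torch

# With the new dispatch entries, amax/amin execute on-device instead of
# failing with a not-implemented error for the MPS backend.
if torch.backends.mps.is_available():
    x = torch.randn(3, 4, device='mps')
    assert torch.amax(x, dim=1).device.type == 'mps'
    assert torch.amin(x, dim=0).device.type == 'mps'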
44 changes: 44 additions & 0 deletions test/test_mps.py
@@ -2755,6 +2755,50 @@ def helper(shape):

        helper((4, 5, 6, 7))

    # Test forward and backward amax
    def test_amax(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amax(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amax(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)

    # Test forward and backward amin
    def test_amin(self):
        def helper(shape, dim, keepdim):
            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
            x = cpu_x.detach().clone().to('mps').requires_grad_()

            result = torch.amin(x, dim=dim, keepdim=keepdim)
            result_cpu = torch.amin(cpu_x, dim=dim, keepdim=keepdim)

            cpu_grad = torch.randn(result_cpu.shape)
            grad = cpu_grad.to('mps')

            result_cpu.backward(gradient=cpu_grad)
            result.backward(gradient=grad)

            self.assertEqual(result, result_cpu)
            self.assertEqual(x.grad, cpu_x.grad)

        for dim in ([], [0], [0, 1], [2, 3]):
            for keepdim in [False, True]:
                helper((2, 8, 4, 5), dim, keepdim)
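The backward checks above rely on amax/amin routing the incoming gradient only to the positions that attained the extremum. A tiny CPU-only illustration of that behavior (not part of this diff):

import torch

# The gradient of amax flows only to the element(s) that attained the
# maximum; every other position receives zero.
x = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)
torch.amax(x, dim=0).backward()
print(x.grad)  # tensor([0., 1., 0.])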

    # Test minimum and maximum
    def test_minimum_maximum(self):
        def helper(n, c, h, w):