Add support for randperm (#160)
* Add mps support for randperm

* Add testcase for randperm

* Address PR comments

* Fix randperm string key for graph caching

* Address remaining PR comments

* Fix warning message
DenisVieriu97 authored Nov 2, 2022
1 parent d60a826 commit 7dfb388
Showing 4 changed files with 76 additions and 1 deletion.
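For context, a minimal usage sketch of what this commit enables (assuming an MPS-capable PyTorch build on Apple silicon; on macOS versions before 13.0 the op warns once and falls back to the CPU):

```python
import torch

# Generate a random permutation of 0..9 directly on the MPS device.
perm = torch.randperm(10, device="mps")

print(perm.device)  # mps:0
print(perm.dtype)   # torch.long -- the default dtype for randperm
assert sorted(perm.tolist()) == list(range(10))
```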
40 changes: 40 additions & 0 deletions aten/src/ATen/native/mps/operations/Distributions.mm
@@ -4,6 +4,7 @@
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/mps/MPSGeneratorImpl.h>
#include <ATen/native/TensorFactories.h>

namespace at {
namespace native {
@@ -341,6 +342,45 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
"exponential_mps_:" + std::to_string(lambda), random_op_block);
}

Tensor& randperm_out_mps(int64_t n, c10::optional<Generator> generator, Tensor& result) {
if (!MPSDevice::getInstance()->macOS_13_0()) {
TORCH_WARN_ONCE("MPS: randperm op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");

result = result.to("cpu");
result = at::randperm_out(result, n).to("mps");
return result;
}

  TORCH_CHECK(n >= 0, "n must be non-negative, got ", n);
TORCH_CHECK(!generator.has_value() ||
(generator.has_value() && result.device() == generator->device()),
"Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
check_supported_max_int_with_precision(n, result);

result.resize_({n});
if (n == 0) {
return result;
}

mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* argsortTensor = [mpsGraph argSortWithTensor:randomTensor
axis:0
name:nil];
if (result.scalar_type() != kInt) {
argsortTensor = [mpsGraph castTensor:argsortTensor
toType:mps::getMPSDataType(result.scalar_type())
name:@"castOutput"];
}
return argsortTensor;
};

return mps::random_mps_impl<int64_t>(result, 0.0, 1.0, c10::nullopt, c10::nullopt,
MPSGraphRandomDistributionUniform, generator,
"ranperm_out_mps:" + mps::getTensorsStringKey({result}), random_op_block);
}

Tensor& multinomial_with_replacement_mps_kernel(
const Tensor& self,
const int64_t n_sample,
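The `randperm_out_mps` kernel above derives the permutation by filling a tensor with uniform random values and argsorting it, then casting the sorted indices to the output dtype. A minimal Python sketch of that argsort-of-uniform-keys technique (illustrative only; `randperm_via_argsort` is a hypothetical name, not part of the kernel):

```python
import torch

def randperm_via_argsort(n: int, generator=None) -> torch.Tensor:
    # Sample n i.i.d. uniform keys; ties are vanishingly unlikely for floats.
    keys = torch.rand(n, generator=generator)
    # The indices that would sort the keys form a uniformly random
    # permutation of 0..n-1, since every ordering is equally likely.
    return torch.argsort(keys)

perm = randperm_via_argsort(5)
assert sorted(perm.tolist()) == list(range(5))
```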
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/Indexing.mm
@@ -214,7 +214,7 @@ void index_put_kernel_mps(TensorIterator& iter, IntArrayRef index_size, IntArray
static
Tensor nonzero_fallback(const Tensor& self) {
TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performace implications.");
"Falling back on CPU. This may have performance implications.");

return at::nonzero(self.to("cpu")).clone().to("mps");
}
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4188,6 +4188,7 @@
dispatch:
CPU: randperm_out_cpu
CUDA: randperm_out_cuda
MPS: randperm_out_mps

- func: range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
dispatch:
34 changes: 34 additions & 0 deletions test/test_mps.py
@@ -520,6 +520,40 @@ def test_uniform(self):
low.grad.zero_()
high.grad.zero_()

def test_randperm(self, device="mps"):
rng_device = None
for n in (5, 100, 50000, 100000):
for dtype in (torch.long, torch.half, torch.float):
if n > 2049 and dtype == torch.half: # Large n for torch.half will raise an exception, do not test here.
continue
if n > 256 and dtype == torch.bfloat16:
continue
with torch.random.fork_rng(devices=rng_device):
res1 = torch.randperm(n, dtype=dtype, device=device)
res2 = torch.empty(0, dtype=dtype, device=device)
torch.randperm(n, out=res2, dtype=dtype, device=device)
self.assertEqual(res1.cpu().sort().values.long(), torch.arange(n, device=device))

# Default type is long
for n in (100, 10000):
self.assertEqual(torch.randperm(n, device=device).dtype, torch.long)

# randperm of 0 elements is an empty tensor
res1 = torch.randperm(0)
res2 = torch.tensor(5, dtype=dtype, device=device)
torch.randperm(0, out=res2)
self.assertEqual(res1.numel(), 0)
self.assertEqual(res2.numel(), 0)

# Test non-contiguous tensors
for n in (4, 5, 6, 10, 20):
non_contiguous_tensor = torch.zeros((2, 3), dtype=torch.long, device=device).t()
self.assertFalse(non_contiguous_tensor.is_contiguous())
with torch.random.fork_rng(devices=rng_device):
res = torch.randperm(n, dtype=torch.long, device=device)
torch.randperm(n, out=non_contiguous_tensor)
self.assertEqual(res.cpu().sort().values.long(), torch.arange(n, device=device))

# Test forward maxpool2d
def test_max_pool2d(self):
def helper(shape, ks, padding=0, dilation=1, ceil_mode=False, return_indices=False, test_ties=False):
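The test wraps its `randperm` calls in `torch.random.fork_rng`, which snapshots RNG state on entry and restores it on exit, so the test does not perturb global seeding. A small sketch of that behavior (forking only the CPU generator via `devices=[]`):

```python
import torch

# fork_rng saves the RNG state on entry and restores it on exit, so the
# randperm inside the block does not advance the outer generator's state.
with torch.random.fork_rng(devices=[]):
    a = torch.randperm(10)
b = torch.randperm(10)  # replays the same RNG state consumed by `a`
assert torch.equal(a, b)
```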
