diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index 7455e27a1701e..a0c061062174b 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -23,14 +23,14 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { } // non-empty tensor - if (self.device().type() != at::kCUDA) { + if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS) { //for cuda, rely on device assert thrown by scatter TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative."); } if (num_classes == -1) { num_classes = self.max().item().toLong() + 1; } else { - if (self.device().type() != at::kCUDA) { + if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS) { //rely on device asserts from scatter to avoid sync here TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); } else { diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index a4b73bd75fb03..21162ef74e21f 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -1,16 +1,8 @@ // Copyright © 2022 Apple Inc. -#include -#include -#include -#include -#include #include #include -#include -#include #include -#include namespace at { namespace native { @@ -198,11 +190,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& output) { - TORCH_CHECK( - std.min().ge(0).item(), - "normal expects all elements of std >= 0.0"); - - Tensor mean_t = empty_mps( output.sizes(), output.scalar_type(), @@ -218,7 +205,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional gen, Tensor& output) { TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex"); - TORCH_CHECK(std.numel() == 0 || std.min().ge(0).item(), "normal expects all elements of std >= 0.0"); // Check that mean and std have same number of elements TORCH_CHECK(mean.numel() == std.numel(), "normal_mps_out: mean and std must have same number of elements")