11// Copyright © 2022 Apple Inc.
22
3- #include < ATen/ATen.h>
4- #include < ATen/Tensor.h>
5- #include < ATen/Utils.h>
6- #include < ATen/native/UnaryOps.h>
7- #include < ATen/Dispatch.h>
83#include < ATen/native/Distributions.h>
94#include < ATen/native/DistributionTemplates.h>
10- #include < ATen/native/TensorIterator.h>
11- #include < ATen/mps/MPSStream.h>
125#include < ATen/native/mps/OperationUtils.h>
13- #include < torch/library.h>
146
157namespace at {
168namespace native {
@@ -198,11 +190,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
198190}
199191
200192Tensor& normal_mps_out (double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
201- TORCH_CHECK (
202- std.min ().ge (0 ).item <bool >(),
203- " normal expects all elements of std >= 0.0" );
204-
205-
206193 Tensor mean_t = empty_mps (
207194 output.sizes (),
208195 output.scalar_type (),
@@ -218,7 +205,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
218205
219206Tensor& normal_mps_out (const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
220207 TORCH_CHECK (!std.is_complex (), " normal expects standard deviation to be non-complex" );
221- TORCH_CHECK (std.numel () == 0 || std.min ().ge (0 ).item <bool >(), " normal expects all elements of std >= 0.0" );
222208 // Check that mean and std have same number of elements
223209 TORCH_CHECK (mean.numel () == std.numel (), " normal_mps_out: mean and std must have same number of elements" )
224210
0 commit comments