Skip to content

Commit ac3df3c

Browse files
razarmehrkulinseth
authored andcommitted
Remove checks that incur unnecessary syncs on GPU with tensor.item() (#74)
1 parent 060f772 commit ac3df3c

File tree

2 files changed

+2
-16
lines changed

2 files changed

+2
-16
lines changed

aten/src/ATen/native/Onehot.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) {
2323
}
2424

2525
// non-empty tensor
26-
if (self.device().type() != at::kCUDA) {
26+
if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS) {
2727
//for cuda, rely on device assert thrown by scatter
2828
TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
2929
}
3030
if (num_classes == -1) {
3131
num_classes = self.max().item().toLong() + 1;
3232
} else {
33-
if (self.device().type() != at::kCUDA) {
33+
if (self.device().type() != at::kCUDA && self.device().type() != at::kMPS) {
3434
//rely on device asserts from scatter to avoid sync here
3535
TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
3636
} else {

aten/src/ATen/native/mps/operations/Distributions.mm

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
11
// Copyright © 2022 Apple Inc.
22

3-
#include <ATen/ATen.h>
4-
#include <ATen/Tensor.h>
5-
#include <ATen/Utils.h>
6-
#include <ATen/native/UnaryOps.h>
7-
#include <ATen/Dispatch.h>
83
#include <ATen/native/Distributions.h>
94
#include <ATen/native/DistributionTemplates.h>
10-
#include <ATen/native/TensorIterator.h>
11-
#include <ATen/mps/MPSStream.h>
125
#include <ATen/native/mps/OperationUtils.h>
13-
#include <torch/library.h>
146

157
namespace at {
168
namespace native {
@@ -198,11 +190,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
198190
}
199191

200192
Tensor& normal_mps_out(double mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
201-
TORCH_CHECK(
202-
std.min().ge(0).item<bool>(),
203-
"normal expects all elements of std >= 0.0");
204-
205-
206193
Tensor mean_t = empty_mps(
207194
output.sizes(),
208195
output.scalar_type(),
@@ -218,7 +205,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, c10::optional<Generator
218205

219206
Tensor& normal_mps_out(const Tensor& mean, const Tensor& std, c10::optional<Generator> gen, Tensor& output) {
220207
TORCH_CHECK(!std.is_complex(), "normal expects standard deviation to be non-complex");
221-
TORCH_CHECK(std.numel() == 0 || std.min().ge(0).item<bool>(), "normal expects all elements of std >= 0.0");
222208
// Check that mean and std have same number of elements
223209
TORCH_CHECK(mean.numel() == std.numel(), "normal_mps_out: mean and std must have same number of elements")
224210

0 commit comments

Comments
 (0)