Skip to content

Commit e4c3074

Browse files
authored
Shorten MPSDataType strings and clean up debug message for Multinomial (#50)
1 parent cc15b85 commit e4c3074

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,21 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
120120
switch (scalar_type) {
121121
case ScalarType::Double:
122122
case ScalarType::Float:
123-
return "MPSDataTypeFloat32";
123+
return "Float32";
124124
case ScalarType::Half:
125-
return "MPSDataTypeFloat16";
125+
return "Float16";
126126
case ScalarType::Int:
127-
return "MPSDataTypeInt32";
127+
return "Int32";
128128
case ScalarType::Long:
129-
return "MPSDataTypeInt64";
129+
return "Int64";
130130
case ScalarType::Short:
131-
return "MPSDataTypeInt16";
131+
return "Int16";
132+
case ScalarType::Char:
133+
return "UInt8";
132134
case ScalarType::Byte:
133-
return "MPSDataTypeInt8";
135+
return "Int8";
134136
case ScalarType::Bool:
135-
return "MPSDataTypeBool";
137+
return "Bool";
136138
default:
137139
return "Undefined";
138140
}
@@ -316,6 +318,9 @@ void printTensorNDArray(const Tensor& t) {
316318
case MPSDataTypeInt8:
317319
v.i = scalar.to<int8_t>();
318320
break;
321+
case MPSDataTypeUInt8:
322+
v.i = scalar.to<uint8_t>();
323+
break;
319324
case MPSDataTypeBool:
320325
v.b = scalar.to<bool>();
321326
break;

aten/src/ATen/native/mps/operations/Distributions.mm

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,9 @@
1515
#include <ATen/NativeFunctions.h>
1616
#include <ATen/AccumulateType.h>
1717
#include <torch/library.h>
18+
1819
namespace at {
1920
namespace native {
20-
namespace templates {
21-
22-
}
2321

2422
Tensor& uniform_mps_(Tensor& input, double from, double to, c10::optional<Generator> gen_)
2523
{
@@ -717,8 +715,6 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, ScalarType scal
717715
c10::optional<Generator> gen,
718716
Tensor& result) {
719717

720-
std::cout<<"Multinomial MPS\n";
721-
722718
TORCH_CHECK(
723719
result.device() == self.device(),
724720
"multinomial arguments must have the same device");

0 commit comments

Comments
 (0)