diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 8df366b70de23..73cad4870f852 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -120,19 +120,21 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { switch (scalar_type) { case ScalarType::Double: case ScalarType::Float: - return "MPSDataTypeFloat32"; + return "Float32"; case ScalarType::Half: - return "MPSDataTypeFloat16"; + return "Float16"; case ScalarType::Int: - return "MPSDataTypeInt32"; + return "Int32"; case ScalarType::Long: - return "MPSDataTypeInt64"; + return "Int64"; case ScalarType::Short: - return "MPSDataTypeInt16"; + return "Int16"; + case ScalarType::Char: + return "UInt8"; case ScalarType::Byte: - return "MPSDataTypeInt8"; + return "Int8"; case ScalarType::Bool: - return "MPSDataTypeBool"; + return "Bool"; default: return "Undefined"; } @@ -316,6 +318,9 @@ void printTensorNDArray(const Tensor& t) { case MPSDataTypeInt8: v.i = scalar.to(); break; + case MPSDataTypeUInt8: + v.i = scalar.to(); + break; case MPSDataTypeBool: v.b = scalar.to(); break; diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index f4dad1052cbde..e33d6eaa8565a 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -15,11 +15,9 @@ #include #include #include + namespace at { namespace native { -namespace templates { - -} Tensor& uniform_mps_(Tensor& input, double from, double to, c10::optional gen_) { @@ -717,8 +715,6 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, ScalarType scal c10::optional gen, Tensor& result) { - std::cout<<"Multinomial MPS\n"; - TORCH_CHECK( result.device() == self.device(), "multinomial arguments must have the same device");