diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 53785d3470ce2..b524a04aad58d 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -16,9 +16,8 @@ Tensor _mps_linear( auto weight = (weight_arg.dim() == 1) ? weight_arg.view({1, weight_arg.size(0)}) : weight_arg; - TORCH_CHECK(input.scalar_type() == ScalarType::Double - || input.scalar_type() == ScalarType::Float - || input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); + TORCH_CHECK(input.scalar_type() == ScalarType::Float || + input.scalar_type() == ScalarType::Half, "MPS device does not support linear for non-float inputs"); const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt)); bool is_bias_defined = bias.defined(); @@ -145,8 +144,9 @@ Tensor _mps_linear_backward_input( { TORCH_CHECK(grad_output.is_mps(), "mps_linear_backward: grad_output needs to be mps layout"); - TORCH_CHECK(weight.device().is_mps() && weight.scalar_type() == kFloat, - "mps_linear_backward: weight needs to be a dense tensor"); + TORCH_CHECK(weight.device().is_mps() && + (weight.scalar_type() == kFloat || (weight.scalar_type() == kHalf)), + "mps_linear_backward: unsupported weights data type: ", weight.scalar_type()); TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double || grad_output.scalar_type() == ScalarType::Float @@ -228,9 +228,8 @@ Tensor _mps_linear_backward_input( TORCH_CHECK(grad_output.is_mps() && input.is_mps(), "_mps_linear_backward: grad_output and input needs to be mps layout"); - TORCH_CHECK(grad_output.scalar_type() == ScalarType::Double - || grad_output.scalar_type() == ScalarType::Float - || grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); + TORCH_CHECK(grad_output.scalar_type() == ScalarType::Float || + grad_output.scalar_type() == ScalarType::Half, "MPS device does not support linear backward for non-float inputs"); struct CachedGraph : public MPSCachedGraph {