File tree Expand file tree Collapse file tree 2 files changed +6
-5
lines changed
aten/src/ATen/native/mps/operations Expand file tree Collapse file tree 2 files changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -771,12 +771,12 @@ void smooth_l1_loss_impl(
771771 MPSGraphTensor *mpsGraphOneTensor = [mpsGraph constantWithScalar: 1.0
772772 dataType: inputTensor.dataType];
773773 MPSGraphTensor *mpsGraphHalfTensor = [mpsGraph constantWithScalar: 0.5
774- dataType: MPSDataTypeFloat32 ];
774+ dataType: inputTensor.dataType ];
775775 MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta
776- dataType: MPSDataTypeFloat32 ];
776+ dataType: inputTensor.dataType ];
777777 // 0.5 * beta
778778 MPSGraphTensor *halfTensorMulBetaTensor = [mpsGraph constantWithScalar: beta * 0.5
779- dataType: MPSDataTypeFloat32 ];
779+ dataType: inputTensor.dataType ];
780780 // Calculating first part of the equation:
781781 // ln = 0.5(xn - yn)^2/beta, if |xn - yn| < beta
782782
Original file line number Diff line number Diff line change @@ -6242,7 +6242,8 @@ class TestConsistency(TestCase):
62426242 'torch.uint8' ],
62436243 'nn.functional.selu' : ['torch.float32' ],
62446244 'nn.functional.silu' : ['torch.float32' ],
6245- 'nn.functional.smooth_l1_loss' : ['torch.float32' ],
6245+ 'nn.functional.smooth_l1_loss' : ['torch.float32' ,
6246+ 'torch.float16' ],
62466247 'nn.functional.softmin' : ['torch.float32' ],
62476248 'nn.functional.threshold' : ['torch.float32' ,
62486249 'torch.int16' ,
@@ -6441,7 +6442,7 @@ class TestConsistency(TestCase):
64416442 'nn.functional.kl_div' : [torch .int16 , torch .int32 , torch .int64 ],
64426443 'nn.functional.nll_loss' : [torch .float32 ],
64436444 'nn.functional.padreflect' : [torch .float32 ], 'nn.functional.padreplicate' : [torch .float32 ],
6444- 'nn.functional.smooth_l1_loss' : [ torch . float16 ], ' std' : [torch .float16 ],
6445+ 'std' : [torch .float16 ],
64456446 'stft' : [torch .float32 ], 'var' : [torch .float16 ],
64466447
64476448 # These were moved from ALLOWLIST to BLOCK as they are not working
You can’t perform that action at this time.
0 commit comments