Commit 70e484b

abhudev authored and kulinseth committed
Make data type of constants flexible (#68)
1 parent 3f548d0 commit 70e484b

2 files changed: 6 additions (+) and 5 deletions (−)

aten/src/ATen/native/mps/operations/LossOps.mm

Lines changed: 3 additions & 3 deletions
@@ -771,12 +771,12 @@ void smooth_l1_loss_impl(
   MPSGraphTensor *mpsGraphOneTensor = [mpsGraph constantWithScalar: 1.0
                                                            dataType: inputTensor.dataType];
   MPSGraphTensor *mpsGraphHalfTensor = [mpsGraph constantWithScalar: 0.5
-                                                            dataType: MPSDataTypeFloat32];
+                                                            dataType: inputTensor.dataType];
   MPSGraphTensor *betaTensor = [mpsGraph constantWithScalar: beta
-                                                    dataType: MPSDataTypeFloat32];
+                                                    dataType: inputTensor.dataType];
   // 0.5 * beta
   MPSGraphTensor *halfTensorMulBetaTensor = [mpsGraph constantWithScalar: beta * 0.5
-                                                                 dataType: MPSDataTypeFloat32];
+                                                                 dataType: inputTensor.dataType];
   // Calculating first part of the equation:
   // ln = 0.5(xn - yn)^2/beta, if |xn - yn| < beta
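The change replaces the hard-coded MPSDataTypeFloat32 on these three constants with inputTensor.dataType, so every scalar constant in the smooth L1 graph inherits the input tensor's dtype instead of forcing float32. A minimal sketch of the user-visible effect, assuming a machine with an MPS device (the shapes and beta value here are illustrative, not from the commit):

import torch
import torch.nn.functional as F

# Before this commit the constants inside the MPS smooth_l1_loss graph were
# pinned to float32, so half-precision inputs did not work; with the
# constants tracking inputTensor.dataType, float16 runs end to end.
if torch.backends.mps.is_available():
    x = torch.randn(8, dtype=torch.float16, device="mps")
    y = torch.randn(8, dtype=torch.float16, device="mps")
    loss = F.smooth_l1_loss(x, y, beta=1.0)
    print(loss.dtype)  # expected: torch.float16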

test/test_mps.py

Lines changed: 3 additions & 2 deletions
@@ -6242,7 +6242,8 @@ class TestConsistency(TestCase):
                                      'torch.uint8'],
         'nn.functional.selu': ['torch.float32'],
         'nn.functional.silu': ['torch.float32'],
-        'nn.functional.smooth_l1_loss': ['torch.float32'],
+        'nn.functional.smooth_l1_loss': ['torch.float32',
+                                         'torch.float16'],
         'nn.functional.softmin': ['torch.float32'],
         'nn.functional.threshold': ['torch.float32',
                                     'torch.int16',
@@ -6441,7 +6442,7 @@ class TestConsistency(TestCase):
         'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
         'nn.functional.nll_loss': [torch.float32],
         'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32],
-        'nn.functional.smooth_l1_loss': [torch.float16], 'std': [torch.float16],
+        'std': [torch.float16],
         'stft': [torch.float32], 'var': [torch.float16],
 
         # These were moved from ALLOWLIST to BLOCK as they are not working
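The test update adds 'torch.float16' to the smooth_l1_loss allowlist and drops it from the block of dtypes that were not working. Roughly what the consistency test exercises for this entry, as a hedged standalone sketch (the real TestConsistency is driven by OpInfo machinery; the float32 CPU reference and the tolerances below are assumptions):

import torch
import torch.nn.functional as F

x = torch.randn(16, dtype=torch.float16)
y = torch.randn(16, dtype=torch.float16)

# Compute a float32 reference on CPU and compare the float16 MPS result
# against it, with tolerances loose enough for half precision.
expected = F.smooth_l1_loss(x.float(), y.float())
actual = F.smooth_l1_loss(x.to("mps"), y.to("mps"))
torch.testing.assert_close(actual.cpu().float(), expected, rtol=1e-3, atol=1e-3)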
