Skip to content

Commit cf803ac

Browse files
committed
Add passing tests to ALLOWLIST_OP
1 parent 6777991 commit cf803ac

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

test/test_mps.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7822,7 +7822,7 @@ class TestConsistency(TestCase):
78227822
'bitwise_or': ['b8', 'i16', 'i32', 'i64', 'u8'],
78237823
'bitwise_right_shift': ['i16', 'i32', 'i64', 'u8'],
78247824
'bitwise_xor': ['b8', 'i16', 'i32', 'i64', 'u8'],
7825-
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7825+
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78267826
'bmm': ['f32'],
78277827
'broadcast_shapes': ['f32'],
78287828
'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -7845,9 +7845,9 @@ class TestConsistency(TestCase):
78457845
'cumsum': ['f16', 'f32', 'int16', 'int32'],
78467846
'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78477847
'diag': ['f32', 'i32'],
7848-
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7848+
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78497849
'diagflat': ['f32', 'i32'],
7850-
'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7850+
'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78517851
'diff': ['f16', 'f32', 'i16', 'i32', 'i64'],
78527852
'dist': ['f32'],
78537853
'div': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -7886,6 +7886,7 @@ class TestConsistency(TestCase):
78867886
'linalg.vector_norm': ['f16', 'f32'],
78877887
'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78887888
'log': ['b8', 'f32', 'i16', 'i32', 'u8'],
7889+
'log': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
78897890
'log10': ['b8', 'f32', 'i16', 'i32', 'u8'],
78907891
'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'],
78917892
'log2': ['b8', 'f32', 'i16', 'i32', 'u8'],
@@ -7921,6 +7922,7 @@ class TestConsistency(TestCase):
79217922
'nn.functional.cosine_similarity': ['f32'],
79227923
'nn.functional.elu': ['f32'],
79237924
'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7925+
'nn.functional.embedding': ['f16', 'f32'],
79247926
'nn.functional.gaussian_nll_loss': ['f32'],
79257927
'nn.functional.glu': ['f32'],
79267928
'nn.functional.group_norm': ['f32'],
@@ -7933,11 +7935,12 @@ class TestConsistency(TestCase):
79337935
'nn.functional.layer_norm': ['f32'],
79347936
'nn.functional.leaky_relu': ['f32'],
79357937
'nn.functional.linear': ['f32'],
7936-
'nn.functional.local_response_norm': ['f32'],
7938+
'nn.functional.local_response_norm': ['f32', 'i64'],
79377939
'nn.functional.margin_ranking_loss': ['f32', 'i16', 'i32'],
79387940
'nn.functional.mse_loss': ['f16', 'f32'],
79397941
'nn.functional.normalize': ['f32'],
7940-
'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7942+
'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
7943+
'nn.functional.padcircular': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
79417944
'nn.functional.pairwise_distance': ['f16',
79427945
'f32',
79437946
'i16',
@@ -7982,14 +7985,14 @@ class TestConsistency(TestCase):
79827985
'rot90': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
79837986
'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
79847987
'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
7985-
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7988+
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
79867989
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
79877990
'short': ['i16'],
79887991
'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
79897992
'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
79907993
'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],
79917994
'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
7992-
'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
7995+
'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
79937996
'softmax': ['f32'],
79947997
'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
79957998
'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8027,6 +8030,7 @@ class TestConsistency(TestCase):
80278030
'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
80288031
'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
80298032
'nn.functional.nll_loss': ['f32'],
8033+
'byte': ['b8', 'i16', 'i32', 'i64', 'u8'],
80308034
}
80318035

80328036

@@ -8140,7 +8144,7 @@ class TestConsistency(TestCase):
81408144
'nn.functional.kl_div': ['f32'],
81418145
'nn.functional.l1_loss': ['f16', 'f32'],
81428146
'nn.functional.leaky_relu': ['f32'],
8143-
'nn.functional.local_response_norm': ['f32'],
8147+
'nn.functional.local_response_norm': ['f32', 'i64'],
81448148
'nn.functional.margin_ranking_loss': ['f32'],
81458149
'nn.functional.mse_loss': ['f32'],
81468150
'nn.functional.pad': ['f16', 'f32', 'i16', 'i32', 'i64'],
@@ -8219,17 +8223,14 @@ class TestConsistency(TestCase):
82198223
'stft': [torch.float32], 'var': [torch.float16],
82208224
# + forward when requires_grad=True or running backward
82218225
'index_select': [torch.float16],
8222-
'nn.functional.embedding': [torch.float32, torch.float16],
82238226
'__rpow__': [torch.int64],
82248227
'masked.std': [torch.int32],
82258228
'masked.var': [torch.int32],
82268229
'as_strided_scatter': [torch.uint8],
82278230
'atan2': [torch.int64],
82288231
'bfloat16': None,
82298232
'block_diag': [torch.uint8],
8230-
'byte': None,
82318233
'chalf': None,
8232-
'diag_embed': [torch.uint8],
82338234
'diagonal_scatter': [torch.uint8],
82348235
'index_add': None,
82358236
'long': None,
@@ -8241,13 +8242,9 @@ class TestConsistency(TestCase):
82418242
'nn.functional.conv_transpose2d': [torch.int64],
82428243
'nn.functional.conv_transpose3d': [torch.int64, torch.float32],
82438244
'nn.functional.huber_loss': [torch.float16],
8244-
'nn.functional.local_response_norm': [torch.int64],
8245-
'nn.functional.padcircular': [torch.uint8],
82468245
'nn.functional.softplus': [torch.float32],
82478246
'pow': [torch.int64],
8248-
'select_scatter': [torch.uint8],
82498247
'sigmoid': [torch.int64],
8250-
'slice_scatter': [torch.uint8],
82518248
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
82528249
'unique': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
82538250
'nonzero': [torch.uint8, torch.float16],

0 commit comments

Comments
 (0)