Skip to content

Commit 1944a59

Browse files
DenisVieriu97kulinseth
authored andcommitted
Move passing tests to ALLOWLIST_OP (#168)
* Add passing tests to ALLOWLIST_OP * Remove tab indentation
1 parent 7006d0c commit 1944a59

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

test/test_mps.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8260,7 +8260,7 @@ class TestConsistency(TestCase):
82608260
'bitwise_or': ['b8', 'i16', 'i32', 'i64', 'u8'],
82618261
'bitwise_right_shift': ['i16', 'i32', 'i64', 'u8'],
82628262
'bitwise_xor': ['b8', 'i16', 'i32', 'i64', 'u8'],
8263-
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8263+
'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
82648264
'bmm': ['f32'],
82658265
'broadcast_shapes': ['f32'],
82668266
'ceil': ['f32', 'int32', 'int64', 'f16'],
@@ -8283,10 +8283,10 @@ class TestConsistency(TestCase):
82838283
'cumsum': ['f16', 'f32', 'int16', 'int32'],
82848284
'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
82858285
'diag': ['f32', 'i32'],
8286-
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8286+
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
82878287
'diagflat': ['f32', 'i32'],
8288-
'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8289-
'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8288+
'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8289+
'diff': ['f16', 'f32', 'i16', 'i32', 'i64'],
82908290
'dist': ['f32'],
82918291
'dot': ['f32', 'i16', 'i32', 'i64', 'u8'],
82928292
'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8325,6 +8325,7 @@ class TestConsistency(TestCase):
83258325
'linalg.vector_norm': ['f16', 'f32'],
83268326
'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
83278327
'log': ['b8', 'f32', 'i16', 'i32', 'u8'],
8328+
'log': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
83288329
'log10': ['b8', 'f32', 'i16', 'i32', 'u8'],
83298330
'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'],
83308331
'log2': ['b8', 'f32', 'i16', 'i32', 'u8'],
@@ -8374,14 +8375,15 @@ class TestConsistency(TestCase):
83748375
'nn.functional.layer_norm': ['f32'],
83758376
'nn.functional.leaky_relu': ['f32'],
83768377
'nn.functional.linear': ['f32'],
8377-
'nn.functional.local_response_norm': ['f32'],
8378+
'nn.functional.local_response_norm': ['f32', 'i64'],
83788379
'nn.functional.margin_ranking_loss': ['f32', 'i16', 'i32'],
83798380
'nn.functional.max_pool1d': ['f32'],
83808381
'nn.functional.max_pool2d': ['f32'],
83818382
'max_pool2d_with_indices_backward': ['f32'],
83828383
'nn.functional.mse_loss': ['f16', 'f32'],
83838384
'nn.functional.normalize': ['f32'],
8384-
'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8385+
'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8386+
'nn.functional.padcircular': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
83858387
'nn.functional.pairwise_distance': ['f16',
83868388
'f32',
83878389
'i16',
@@ -8419,16 +8421,14 @@ class TestConsistency(TestCase):
84198421
'rot90': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84208422
'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
84218423
'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
8422-
'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8423-
'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8424-
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8424+
'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84258425
'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84268426
'short': ['i16'],
84278427
'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
84288428
'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
84298429
'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],
84308430
'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
8431-
'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
8431+
'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84328432
'softmax': ['f32'],
84338433
'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
84348434
'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8478,6 +8478,8 @@ class TestConsistency(TestCase):
84788478
'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84798479
'masked.prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
84808480
'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
8481+
'nn.functional.nll_loss': ['f32'],
8482+
'byte': ['b8', 'i16', 'i32', 'i64', 'u8'],
84818483
}
84828484

84838485

@@ -8591,7 +8593,7 @@ class TestConsistency(TestCase):
85918593
'nn.functional.kl_div': ['f32'],
85928594
'nn.functional.l1_loss': ['f16', 'f32'],
85938595
'nn.functional.leaky_relu': ['f32'],
8594-
'nn.functional.local_response_norm': ['f32'],
8596+
'nn.functional.local_response_norm': ['f32', 'i64'],
85958597
'nn.functional.margin_ranking_loss': ['f32'],
85968598
'nn.functional.max_pool1d': ['f32'],
85978599
'nn.functional.max_pool2d': ['f32'],
@@ -8674,15 +8676,14 @@ class TestConsistency(TestCase):
86748676
'stft': [torch.float32], 'var': [torch.float16],
86758677
# + forward when requires_grad=True or running backward
86768678
'nn.functional.embedding': [torch.float32, torch.float16],
8679+
'index_select': [torch.float16],
86778680
'__rpow__': [torch.int64],
86788681

86798682
'as_strided_scatter': [torch.uint8],
86808683
'atan2': [torch.int64],
86818684
'bfloat16': None,
86828685
'block_diag': [torch.uint8],
8683-
'byte': None,
86848686
'chalf': None,
8685-
'diag_embed': [torch.uint8],
86868687
'diagonal_scatter': [torch.uint8],
86878688
'index_add': None,
86888689
'linalg.inv': [torch.float32],
@@ -8693,12 +8694,9 @@ class TestConsistency(TestCase):
86938694
'nn.functional.conv_transpose2d': [torch.int64],
86948695
'nn.functional.conv_transpose3d': [torch.int64, torch.float32],
86958696
'nn.functional.huber_loss': [torch.float16],
8696-
'nn.functional.local_response_norm': [torch.int64],
8697-
'nn.functional.padcircular': [torch.uint8],
8697+
'nn.functional.softplus': [torch.float32],
86988698
'pow': [torch.int64],
8699-
'select_scatter': [torch.uint8],
87008699
'sigmoid': [torch.int64],
8701-
'slice_scatter': [torch.uint8],
87028700
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
87038701

87048702
# failure in average pooling when both ceilMode and includeZeroPadToAverage are True

0 commit comments

Comments
 (0)