diff --git a/test/test_mps.py b/test/test_mps.py
index 97a015a9611ef..388f613610b04 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -27,7 +27,7 @@ from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
-from functools import partial
+from functools import partial, reduce
 from torch.testing._internal.common_methods_invocations import (
     op_db,
@@ -8356,56 +8356,63 @@ class TestConsistency(TestCase):
     # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
     # You most likely do NOT want to modify this manually
     ALLOWLIST_OP = {
+        'H': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'T': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__getitem__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__radd__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__rand__': ['b8', 'i16', 'i32', 'i64', 'u8'],
         '__rdiv__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        '__rmatmul__': ['f32'],
+        '__rmatmul__': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        '__rmod__': ['f16', 'f32'],
         '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'],
-        '__rpow__': ['f16'],
+        '__rpow__': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        '__rsub__': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'],
-        'masked.argmax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.argmin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.log_softmax': ['f32'],
-        'masked.logaddexp': ['f32'],
-        'masked.logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.norm': ['f16', 'f32'],
-        'masked.normalize': ['f16', 'f32'],
-        'masked.softmax': ['f32'],
-        'masked.softmin': ['f32'],
-        'abs': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'acos': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'acosh': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        '_native_batch_norm_legit': ['f32'],
+        '_softmax_backward_data': ['f32'],
+        'abs': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'acos': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'acosh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'addbmm': ['f32'],
+        'addbmm': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'addcdiv': ['f32'],
         'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
-        'addmm': ['f32'],
-        'addmv': ['f32'],
-        'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'addmm': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'addmv': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'addr': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'allclose': ['f16', 'f32'],
+        'amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'aminmax': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'angle': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'arange': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'amax': ['f32'],
-        'amix': ['f32'],
-        'mean': ['f32'],
-        'meshgrid': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'meshgridvariadic_tensors': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'meshgridlist_of_tensors': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'sum': ['f32'],
-        'asin': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'asinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'atan': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'atan2': ['f32', 'i64'],
-        'atanh': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        'argsort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'argwhere': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'as_strided': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'as_strided_scatter': ['b8',
+                               'f16',
+                               'f32',
+                               'i16',
+                               'i32',
+                               'i64',
+                               'u8'],
+        'asin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'asinh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'atan': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'atan2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'atanh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'atleast_1d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'atleast_2d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'atleast_3d': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'baddbmm': ['f32'],
+        'baddbmm': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'bernoulli': ['f32'],
+        'bfloat16': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'bincount': ['i16', 'i32', 'i64', 'u8'],
         'bitwise_and': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'bitwise_left_shift': ['i16', 'i32', 'i64', 'u8'],
         'bitwise_not': ['b8', 'i16', 'i32', 'i64', 'u8'],
@@ -8413,13 +8420,30 @@ class TestConsistency(TestCase):
         'bitwise_right_shift': ['i16', 'i32', 'i64', 'u8'],
         'bitwise_xor': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'block_diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'bmm': ['f32'],
+        'bmm': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'bool': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'broadcast_shapes': ['f32'],
+        'broadcast_tensors': ['b8',
+                              'f16',
+                              'f32',
+                              'i16',
+                              'i32',
+                              'i64',
+                              'u8'],
         'broadcast_to': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'bucketize': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'byte': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cartesian_prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'cat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'ceil': ['f32', 'int32', 'int64', 'f16'],
+        'cdist': ['f32'],
+        'cdouble': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'ceil': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'cfloat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'chalf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'char': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cholesky': ['f32'],
+        'cholesky_inverse': ['f32'],
+        'cholesky_solve': ['f32'],
         'chunk': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'clamp': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'clamp_max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -8427,243 +8451,643 @@ class TestConsistency(TestCase):
         'clone': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'column_stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'combinations': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'complex': ['f16', 'f32'],
         'conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'conj_physical': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'constant_pad_nd': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'contiguous': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'corrcoef': ['f32'],
-        'cos': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
-        'cosh': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
-        'cov': ['f32'],
-        'cumsum': ['f16', 'f32', 'int16', 'int32'],
+        'copysign': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'corrcoef': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'cos': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cosh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'count_nonzero': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cov': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'cross': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'cummax': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cummin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'cumsum': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'diag': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'diagflat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'diagonal_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'diagonal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'diagonal_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'diagonal_scatter': ['b8',
+                             'f16',
+                             'f32',
+                             'i16',
+                             'i32',
+                             'i64',
+                             'u8'],
         'diff': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'dist': ['f32'],
-        'div': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'divfloor_rounding': ['f16', 'f32', 'u8'],
-        'divtrunc_rounding': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'digamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'dist': ['f16', 'f32'],
+        'div': ['f16', 'f32', 'u8', 'b8', 'i16', 'i32', 'i64'],
         'dot': ['f32', 'i16', 'i32', 'i64', 'u8'],
-        'einsum': ['f32'],
-        'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'double': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'dsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'dstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'einsum': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'empty': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'empty_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'eq': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'erf': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'exp': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'exp2': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
+        'equal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'erf': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'erfc': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'erfinv': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'exp': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'exp2': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'expand': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'expand_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'expm1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'eye': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.fft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.fft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.fftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.fftshift': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.hfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.hfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.hfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ifft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ifft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ifftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ifftshift': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ihfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ihfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.ihfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.irfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.irfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.irfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.rfft': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.rfft2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fft.rfftn': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'flatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'flip': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'fliplr': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'flipud': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'flip': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fliplr': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'flipud': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
-        'floor_divide': ['f32', 'f16'],
+        'float_power': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'floor': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'floor_divide': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fmax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fmin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'fmod': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'frac': ['f16', 'f32'],
+        'frexp': ['f16', 'f32'],
+        'full': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'full_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'gcd': ['i16', 'i32', 'i64', 'u8'],
         'ge': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'gradient': ['f16', 'f32', 'i16'],
-        'outer': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'geqrf': ['f32'],
+        'gradient': ['f16', 'f32', 'i16', 'i32', 'i64'],
+        'grid_sampler_2d': ['f32'],
         'gt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'half': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'heaviside': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'histc': ['f32'],
+        'histogram': ['f32'],
+        'histogramdd': ['f32'],
+        'hsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'hstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'hypot': ['f32'],
+        'i0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'igamma': ['f16', 'f32'],
+        'igammac': ['f16', 'f32'],
+        'index_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'index_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'index_put': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'index_reduce': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'index_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'inner': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'int': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isclose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isfinite': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'isin': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'isinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isnan': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'isneginf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'isposinf': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'isreal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'kron': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'kthvalue': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'lcm': ['i16', 'i32', 'i64', 'u8'],
+        'ldexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'le': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'linalg.matrix_norm': ['f16'],
-        'linalg.multi_dot': ['f32'],
+        'lerp': ['f32'],
+        'lgamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.cholesky': ['f32'],
+        'linalg.cholesky_ex': ['f32'],
+        'linalg.cond': ['f32'],
+        'linalg.cross': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.det': ['f32'],
+        'linalg.eig': ['f32'],
+        'linalg.eigh': ['f32'],
+        'linalg.eigvals': ['f32'],
+        'linalg.eigvalsh': ['f32'],
+        'linalg.householder_product': ['f32'],
+        'linalg.inv': ['f32'],
+        'linalg.inv_ex': ['f32'],
+        'linalg.ldl_factor': ['f32'],
+        'linalg.ldl_factor_ex': ['f32'],
+        'linalg.ldl_solve': ['f32'],
+        'linalg.lstsq': ['f32'],
+        'linalg.lu': ['f32'],
+        'linalg.lu_factor': ['f32'],
+        'linalg.lu_factor_ex': ['f32'],
+        'linalg.lu_solve': ['f32'],
+        'linalg.matrix_norm': ['f16', 'f32'],
+        'linalg.matrix_power': ['f32'],
+        'linalg.matrix_rank': ['f32'],
+        'linalg.multi_dot': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.norm': ['f16', 'f32'],
+        'linalg.pinv': ['f32'],
+        'linalg.qr': ['f32'],
+        'linalg.slogdet': ['f32'],
+        'linalg.solve': ['f32'],
+        'linalg.solve_ex': ['f32'],
+        'linalg.solve_triangular': ['f32'],
         'linalg.svd': ['f32'],
+        'linalg.svdvals': ['f32'],
+        'linalg.tensorinv': ['f32'],
+        'linalg.tensorsolve': ['f32'],
+        'linalg.vander': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.vecdot': ['f32'],
         'linalg.vector_norm': ['f16', 'f32'],
         'linspace': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'log': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'log': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'log10': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'log1p': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'log2': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'log_softmax': ['f32'],
-        'logaddexp': ['f16', 'f32'],
-        'logaddexp2': ['f16', 'f32'],
+        'log': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'log10': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'log1p': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'log2': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'log_softmax': ['f32', 'b8', 'f16', 'i16', 'i32', 'i64', 'u8'],
+        'logaddexp': ['f32'],
+        'logaddexp2': ['f32'],
+        'logcumsumexp': ['f32'],
+        'logdet': ['f32'],
         'logical_and': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logical_not': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logical_or': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logical_xor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'logit': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'logspace': ['f32', 'i16', 'i32', 'i64', 'u8'],
-        'logsumexp': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'logsumexp': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'long': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'lt': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'lu': ['f32'],
+        'lu_solve': ['f32'],
+        'lu_unpack': ['f32'],
+        'mH': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'mT': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.amax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.amin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.cumprod': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.cumsum': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.log_softmax': ['f32'],
+        'masked.logaddexp': ['f32'],
+        'masked.logsumexp': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.median': ['f32'],
+        'masked.norm': ['f16', 'f32'],
+        'masked.normalize': ['f16', 'f32'],
+        'masked.prod': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.softmax': ['f32'],
+        'masked.softmin': ['f32'],
+        'masked.std': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.var': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'matmul': ['f32'],
-        'maximum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'matmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'matrix_exp': ['f32'],
         'max': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'maxreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'maxreduction_no_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'maxbinary': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'minimum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'max_pool2d_with_indices_backward': ['f32'],
+        'maximum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'mean': ['f16', 'f32'],
+        'median': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'meshgrid': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'min': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'minreduction_with_dim': ['f16', 'f32', 'i32'],
-        'minreduction_no_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'minbinary': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
-        'mm': ['f32'],
-        'mv': ['f32'],
+        'minimum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'mm': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'mode': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'movedim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'msort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'mul': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'multinomial': ['f32'],
+        'mv': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'mvlgamma': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'nan_to_num': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'nanmean': ['f16', 'f32'],
+        'nanmedian': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'nanquantile': ['f32'],
+        'nansum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'narrow': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'narrow_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'native_batch_norm': ['f32'],
+        'native_dropout_backward': ['b8',
+                                    'f16',
+                                    'f32',
+                                    'i16',
+                                    'i32',
+                                    'i64',
+                                    'u8'],
         'native_layer_norm': ['f32'],
         'ne': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'neg': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'new_empty': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'new_empty_strided': ['b8',
+                              'f16',
+                              'f32',
+                              'i16',
+                              'i32',
+                              'i64',
+                              'u8'],
+        'new_full': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'new_ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'new_zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'nextafter': ['f32'],
+        'nn.functional._scaled_dot_product_attention': ['f32'],
+        'nn.functional.adaptive_avg_pool1d': ['f32'],
+        'nn.functional.adaptive_avg_pool2d': ['f32'],
+        'nn.functional.adaptive_avg_pool3d': ['f16', 'f32'],
         'nn.functional.adaptive_max_pool1d': ['f32'],
         'nn.functional.adaptive_max_pool2d': ['f32'],
-        'nn.functional.bilinear': ['f32'],
+        'nn.functional.adaptive_max_pool3d': ['f32'],
+        'nn.functional.alpha_dropout': ['f32'],
+        'nn.functional.avg_pool1d': ['f32', 'i64'],
+        'nn.functional.avg_pool2d': ['f32', 'i64'],
+        'nn.functional.avg_pool3d': ['f32', 'i64'],
+        'nn.functional.batch_norm': ['f32'],
+        'nn.functional.bilinear': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.binary_cross_entropy': ['f32'],
         'nn.functional.binary_cross_entropy_with_logits': ['f32'],
         'nn.functional.celu': ['f32'],
         'nn.functional.conv1d': ['f32'],
         'nn.functional.conv2d': ['f32'],
         'nn.functional.conv_transpose1d': ['f32'],
-        'nn.functional.cosine_embedding_loss': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.conv_transpose2d': ['f32'],
+        'nn.functional.cosine_embedding_loss': ['b8',
+                                                'f32',
+                                                'i16',
+                                                'i32',
+                                                'i64',
+                                                'u8'],
         'nn.functional.cosine_similarity': ['f32'],
         'nn.functional.cross_entropy': ['f32'],
+        'nn.functional.ctc_loss': ['f32'],
+        'nn.functional.dropout': ['f32'],
+        'nn.functional.dropout2d': ['f32'],
+        'nn.functional.dropout3d': ['f32'],
         'nn.functional.elu': ['f32'],
-        'nn.functional.feature_alpha_dropout': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.embedding': ['f16', 'f32'],
+        'nn.functional.embedding_bag': ['f16', 'f32'],
+        'nn.functional.feature_alpha_dropout': ['b8',
+                                                'f16',
+                                                'f32',
+                                                'i16',
+                                                'i32',
+                                                'i64',
+                                                'u8'],
+        'nn.functional.fractional_max_pool2d': ['f32'],
+        'nn.functional.fractional_max_pool3d': ['f32'],
         'nn.functional.gaussian_nll_loss': ['f32'],
+        'nn.functional.gelu': ['f32'],
         'nn.functional.glu': ['f32'],
+        'nn.functional.grid_sample': ['f32'],
         'nn.functional.group_norm': ['f32'],
+        'nn.functional.hardshrink': ['f32'],
+        'nn.functional.hardsigmoid': ['f32'],
+        'nn.functional.hardswish': ['f32'],
         'nn.functional.hardtanh': ['f32', 'i16', 'i32', 'i64'],
         'nn.functional.hinge_embedding_loss': ['f32'],
         'nn.functional.huber_loss': ['f16', 'f32'],
         'nn.functional.instance_norm': ['f32'],
-        'nn.functional.kl_div': ['f32', 'i16', 'i32', 'i64'],
+        'nn.functional.interpolate': ['f32', 'u8'],
+        'nn.functional.kl_div': ['f32'],
         'nn.functional.l1_loss': ['f16', 'f32'],
         'nn.functional.layer_norm': ['f32'],
         'nn.functional.leaky_relu': ['f32'],
-        'nn.functional.linear': ['f32'],
+        'nn.functional.linear': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.local_response_norm': ['f32', 'i64'],
-        'nn.functional.margin_ranking_loss': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.logsigmoid': ['f32'],
+        'nn.functional.margin_ranking_loss': ['f32',
+                                              'i16',
+                                              'i32',
+                                              'i64',
+                                              'u8'],
         'nn.functional.max_pool1d': ['f32'],
         'nn.functional.max_pool2d': ['f32'],
+        'nn.functional.max_pool3d': ['f32'],
+        'nn.functional.max_unpool1d': ['f32'],
+        'nn.functional.max_unpool2d': ['f32'],
+        'nn.functional.max_unpool3d': ['f32'],
+        'nn.functional.mish': ['f32'],
         'nn.functional.mse_loss': ['f16', 'f32'],
+        'nn.functional.multi_margin_loss': ['f32'],
+        'nn.functional.multilabel_margin_loss': ['f32'],
+        'nn.functional.multilabel_soft_margin_loss': ['f32'],
+        'nn.functional.nll_loss': ['f32'],
         'nn.functional.normalize': ['f32'],
         'nn.functional.one_hot': ['i64'],
-        'nn.functional.pad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.padcircular': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.padconstant': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.padreflect': ['f32'],
-        'nn.functional.padreplicate': ['f32'],
-        'nn.functional.pairwise_distance': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.pixel_unshuffle': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.pixel_shuffle': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.poisson_nll_loss': ['f32', 'i16', 'i32', 'u8'],
+        'nn.functional.pad': ['b8',
+                              'f16',
+                              'f32',
+                              'i16',
+                              'i32',
+                              'i64',
+                              'u8'],
+        'nn.functional.pairwise_distance': ['f16',
+                                            'f32',
+                                            'i16',
+                                            'i32',
+                                            'i64',
+                                            'u8'],
+        'nn.functional.pdist': ['f32'],
+        'nn.functional.pixel_shuffle': ['b8',
+                                        'f16',
+                                        'f32',
+                                        'i16',
+                                        'i32',
+                                        'i64',
+                                        'u8'],
+        'nn.functional.pixel_unshuffle': ['b8',
+                                          'f16',
+                                          'f32',
+                                          'i16',
+                                          'i32',
+                                          'i64',
+                                          'u8'],
+        'nn.functional.poisson_nll_loss': ['f32',
+                                           'i16',
+                                           'i32',
+                                           'i64',
+                                           'u8'],
         'nn.functional.prelu': ['f32'],
         'nn.functional.relu': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.relu6': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.rrelu': ['f32'],
         'nn.functional.selu': ['f32'],
         'nn.functional.silu': ['f32'],
         'nn.functional.smooth_l1_loss': ['f16', 'f32'],
         'nn.functional.soft_margin_loss': ['f32'],
-        'nn.functional.softmin': ['f32'],
-        'nn.functional.softplus': ['f32'],
-        'nn.functional.softsign': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'u8'],
+        'nn.functional.softmin': ['f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.softshrink': ['f32'],
+        'nn.functional.softsign': ['f16',
+                                   'f32',
+                                   'i16',
+                                   'i32',
+                                   'i64',
+                                   'u8'],
+        'nn.functional.tanhshrink': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'nn.functional.threshold': ['f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.triplet_margin_loss': ['f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.triplet_margin_with_distance_loss': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'nn.functional.triplet_margin_loss': ['f32',
+                                              'i16',
+                                              'i32',
+                                              'i64',
+                                              'u8'],
+        'nn.functional.triplet_margin_with_distance_loss': ['f32',
+                                                            'i16',
+                                                            'i32',
+                                                            'i64',
+                                                            'u8'],
+        'nn.functional.unfold': ['f16', 'f32'],
         'nn.functional.upsample_bilinear': ['f32'],
-        'nn.functional.upsample_nearest': ['f32'],
+        'nn.functional.upsample_nearest': ['f32', 'u8'],
+        'nonzero': ['b8', 'f32', 'i16', 'i32', 'i64'],
         'norm': ['f32', 'f16'],
+        'normal': ['f16', 'f32'],
+        'ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'ones_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'ormqr': ['f32'],
+        'outer': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'pca_lowrank': ['f32'],
+        'permute': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'pinverse': ['f32'],
+        'polar': ['f32'],
+        'polygamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'pow': ['f16'],
+        'pow': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'prod': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'put': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'qr': ['f32'],
+        'quantile': ['f32'],
         'rad2deg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'rand_like': ['f16', 'f32'],
+        'randint': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'randint_like': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'randn': ['f16', 'f32'],
+        'randn_like': ['f16', 'f32'],
+        'ravel': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8'],
+        'reciprocal': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'remainder': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'renorm': ['f16', 'f32'],
         'repeat': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'repeat_interleave': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'resize_': ['b8', 'i16', 'i32', 'i64', 'u8'],
-        'resize_as_': ['b8', 'i16', 'i32', 'i64', 'u8'],
+        'repeat_interleave': ['b8',
+                              'f16',
+                              'f32',
+                              'i16',
+                              'i32',
+                              'i64',
+                              'u8'],
+        'reshape': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'reshape_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'resize_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'resize_as_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'resolve_conj': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'resolve_neg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'reshape_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'roll': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'rot90': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'round': ['f32', 'f16', 'i16', 'i32', 'i64'],
-        'rsqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        'round': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'rsqrt': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'rsub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'scalar_tensor': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'scatter_add': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'scatter_reduce': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'searchsorted': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'segment_reduce': ['f16', 'f32'],
+        'select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'select_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'sgn': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'short': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'sigmoid': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'u8', 'i64'],
-        'sin': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'sinh': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        'sigmoid': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sign': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'signal.windows.bartlett': ['f16', 'f32'],
+        'signal.windows.blackman': ['f16', 'f32'],
+        'signal.windows.cosine': ['f16', 'f32'],
+        'signal.windows.exponential': ['f16', 'f32'],
+        'signal.windows.gaussian': ['f16', 'f32'],
+        'signal.windows.general_cosine': ['f16', 'f32'],
+        'signal.windows.general_hamming': ['f16', 'f32'],
+        'signal.windows.hamming': ['f16', 'f32'],
+        'signal.windows.hann': ['f16', 'f32'],
+        'signal.windows.kaiser': ['f16', 'f32'],
+        'signbit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sin': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sinc': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sinh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'slice': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'slice_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'softmax': ['f32'],
+        'softmax': ['f32', 'b8', 'f16', 'i16', 'i32', 'i64', 'u8'],
+        'sort': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.airy_ai': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.bessel_j0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.bessel_j1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.bessel_y0': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.bessel_y1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.chebyshev_polynomial_t': ['b8',
+                                           'f32',
+                                           'i16',
+                                           'i32',
+                                           'i64',
+                                           'u8'],
+        'special.chebyshev_polynomial_u': ['b8',
+                                           'f32',
+                                           'i16',
+                                           'i32',
+                                           'i64',
+                                           'u8'],
+        'special.entr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.erfcx': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.hermite_polynomial_h': ['b8',
+                                         'f32',
+                                         'i16',
+                                         'i32',
+                                         'i64',
+                                         'u8'],
+        'special.hermite_polynomial_he': ['b8',
+                                          'f32',
+                                          'i16',
+                                          'i32',
+                                          'i64',
+                                          'u8'],
+        'special.i0e': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.i1': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.i1e': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.laguerre_polynomial_l': ['b8',
+                                          'f32',
+                                          'i16',
+                                          'i32',
+                                          'i64',
+                                          'u8'],
+        'special.log_ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.modified_bessel_i0': ['b8',
+                                       'f32',
+                                       'i16',
+                                       'i32',
+                                       'i64',
+                                       'u8'],
+        'special.modified_bessel_i1': ['b8',
+                                       'f32',
+                                       'i16',
+                                       'i32',
+                                       'i64',
+                                       'u8'],
+        'special.modified_bessel_k0': ['b8',
+                                       'f32',
+                                       'i16',
+                                       'i32',
+                                       'i64',
+                                       'u8'],
+        'special.modified_bessel_k1': ['b8',
+                                       'f32',
+                                       'i16',
+                                       'i32',
+                                       'i64',
+                                       'u8'],
         'special.ndtr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'splitlist_args': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.ndtri': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.polygamma': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.scaled_modified_bessel_k0': ['b8',
+                                              'f32',
+                                              'i16',
+                                              'i32',
+                                              'i64',
+                                              'u8'],
+        'special.scaled_modified_bessel_k1': ['b8',
+                                              'f32',
+                                              'i16',
+                                              'i32',
+                                              'i64',
+                                              'u8'],
+        'special.spherical_bessel_j0': ['b8',
+                                        'f32',
+                                        'i16',
+                                        'i32',
+                                        'i64',
+                                        'u8'],
+        'special.xlog1py': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'special.zeta': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'sqrt': ['b8', 'f32', 'i16', 'i32', 'u8'],
+        'split_with_sizes': ['b8',
+                             'f16',
+                             'f32',
+                             'i16',
+                             'i32',
+                             'i64',
+                             'u8'],
+        'sqrt': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'square': ['f16', 'f32'],
         'squeeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'stack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'std': ['f16', 'f32'],
+        'std_mean': ['f16', 'f32'],
+        'stft': ['f32'],
         'sub': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'sum_to_size': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'svd': ['f32'],
+        'svd_lowrank': ['f32'],
+        'symeig': ['f32'],
         't': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'tan': ['b8', 'i16', 'i32', 'u8'],
-        'tanh': ['b8', 'f32', 'i16', 'i32', 'u8'],
-        'tensordot': ['f32'],
+        'take': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'take_along_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'tan': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'tanh': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'tensor_split': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'tensordot': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'tile': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'topk': ['f32'],
+        'to_sparse': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'topk': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'trace': ['f32', 'i16', 'i32', 'i64', 'u8'],
+        'transpose': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'trapezoid': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'trapz': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'triangular_solve': ['f32'],
         'tril': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'tril_indices': ['i32', 'i64'],
         'triu': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'triu_indices': ['i32', 'i64'],
         'true_divide': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'trunc': ['f32'],
+        'trunc': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'unfold': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'unfold_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'uniform': ['f16', 'f32'],
+        'unique_consecutive': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'var': ['f16', 'f32'],
+        'var_mean': ['f16', 'f32'],
+        'vdot': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'view_as_complex': ['f16', 'f32'],
+        'view_copy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'where': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nonzero': ['f32', 'i16', 'i32', 'i64'],
-        'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'nn.functional.nll_loss': ['f32'],
-        'std': ['f16','f32'],
-        'var': ['f16','f32'],
-        'amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'mean': ['f16', 'f32'],
-        'count_nonzero': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.std': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'masked.var': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'minreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'maxreduction_with_dim': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'xlogy': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'zeros': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'zeros_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8']
     }
-
     ALLOWLIST_OP_GRAD = {
         '__radd__': ['f16', 'f32'],
         '__rdiv__': ['f16', 'f32'],
@@ -8843,104 +9267,393 @@ class TestConsistency(TestCase):
         # Functions that hard crash
         'index_add': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'nn.functional.softplus': [torch.float32],
-        'nonzero': [torch.uint8, torch.float16],
+        'nonzero': [torch.bool, torch.uint8, torch.float16],
+        'median': [torch.float32, torch.int16, torch.int32, torch.uint8],
+        'sgn': [torch.bool],
+        'linalg.inv': [torch.float32],
+        'linalg.inv_ex': [torch.float32],
+        'linalg.matrix_power': [torch.float32],
+        'nn.functional.interpolate': [torch.float32],
+        'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nn.functional.interpolatearea': [torch.float32],
+        'resize_as_': [torch.float16, torch.float32],
+        'topk': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],
         # Functions with correctness issues
-        '__rpow__': None,
         'nn.functional.avg_pool1d': [torch.float32, torch.int64],
         'nn.functional.avg_pool2d': [torch.float32, torch.int64],
         'unique': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
-        'index_put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'divfloor_rounding': [torch.int16, torch.int32, torch.int64],
         'divtrunc_rounding': [torch.float16],
         'norm': [torch.float16],
-
-        'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],
-        # Functions that are flaky
-        # These are detected as "ok" by the expect case but actually fail to run sometimes
-        'softmaxwith_dtype': None,
-        'rounddecimals_neg_3': None,
-        'rounddecimals_3': None,
-        'rounddecimals_0': None,
-        'normnuc': None,
-        'nn.functional.softminwith_dtype': None,
-        'nn.functional.feature_alpha_dropoutwith_train': None,
-        'log_softmaxwith_dtype': None,
-        'split_with_sizes': None,
-        'trapezoid': None,
-        'inner': None,
-        'take_along_dim': None,
-
-        # New block list ops that need investigation
-        'nn.functional.conv_transpose2d': ['torch.float32'],
-        'nn.functional.interpolate': ['torch.float32'],
-        'topk': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-
-        # failures due to lack of op implementation on MPS backend
-        'linalg.eig': ['torch.float32'],
-        'linalg.eigvals': ['torch.float32'],
-        'fft.fft': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.ifft': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.ihfft2': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.ihfft': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.ihfftn': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.rfft2': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.rfft': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'fft.rfftn': ['torch.bool', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'put': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'stft': ['torch.float32'],
-        'nn.functional.conv_transpose3d': [torch.int64, torch.float32],
+        'nn.functional.feature_alpha_dropoutwith_train': [torch.float32],
+        'cumulative_trapezoid': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        '_native_batch_norm_legit': [torch.float32],
+        'addr': [torch.float16],
+        'as_stridedpartial_views': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'uniform': [torch.float16, torch.float32],
+        'trace': [torch.int64],
+        'tan': [torch.float32],
+        'normalnumber_mean': [torch.float16, torch.float32],
+        'nn.functional.gelu': [torch.float32],
+        'nn.functional.conv_transpose2d': [torch.float32, torch.int64],
+        'new_empty_strided': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'native_batch_norm': [torch.float32],
+        'multinomial': [torch.float32],
+        'masked.softmin': [torch.float32],
+        'masked.softmax': [torch.float32],
+        'masked.log_softmax': [torch.float32],
+        'floor_divide': [torch.int16, torch.int32, torch.int64],
+        'dist': [torch.float16],
         # failure due to issue: atan2() may generate NAN in output with
-        'atan2': ['torch.bool', 'torch.int16', 'torch.int32', 'torch.uint8'],
+        'atan2': [torch.bool, torch.int16, torch.int32, torch.uint8],
         # Unsupported Border padding mode
-        'grid_sampler_2d': ['f16', 'f32', 'i16'],
-        'nn.functional.grid_sample': ['f32'],
+        'grid_sampler_2d': [torch.float32],
+        'nn.functional.grid_sample': [torch.float32],
         # failures due to issue #103039644: Wrong results from avgPooling2DWithSourceTensor()
         # when both ceilMode and includeZeroPadToAverage are True
-        'nn.functional.avg_pool1d': ['torch.float32', 'torch.int64'],
-        'nn.functional.avg_pool2d': ['torch.float32', 'torch.int64'],
-        'nn.functional.adaptive_avg_pool1d': ['torch.float32'],
-        'nn.functional.adaptive_avg_pool2d': ['torch.float32'],
+        'nn.functional.avg_pool1d': [torch.float32, torch.int64],
+        'nn.functional.avg_pool2d': [torch.float32, torch.int64],
+        'nn.functional.adaptive_avg_pool1d': [torch.float32],
+        'nn.functional.adaptive_avg_pool2d': [torch.float32],
+    }
-        # failures due to issue #102048039: powerWithPrimaryTensor() with integer input may return wrong results
-        'pow': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        '__rpow__': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+    UNIMPLEMENTED_OPS = {
+        # Failures due to lack of op implementation on MPS backend
+        'linalg.eig': [torch.float32],
+        'linalg.eigvals': [torch.float32],
+        'fft.fft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ifft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.rfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.rfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.rfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'stft': [torch.float32],
+        'nn.functional.conv_transpose3d': [torch.int64, torch.float32],
+        '__rmod__': [torch.float16, torch.float32],
+        '__rsub__': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'aminmax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'angle': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'argsort': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'bucketize': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cholesky': [torch.float32],
+        'cholesky_inverse': [torch.float32],
+        'cholesky_solve': [torch.float32],
+        'copysign': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cummax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cummin': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cumprod': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'digamma': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'erfc': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'erfinv': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fmax': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fmin': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fmod': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'frexp': [torch.float16, torch.float32],
+        'gcd': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'geqrf': [torch.float32],
+        'heaviside': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'histc': [torch.float32],
+        'histogram': [torch.float32],
+        'histogramdd': [torch.float32],
+        'hypot': [torch.float32],
+        'i0': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'igamma': [torch.float16, torch.float32],
+        'igammac': [torch.float16, torch.float32],
+        'index_copy': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'index_fill': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'index_reduce': [torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'isin': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'isneginf': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'isposinf': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'kthvalue': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'lcm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'ldexp': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'lerp': [torch.float32],
+        'lgamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'linalg.cholesky': [torch.float32],
+        'linalg.cholesky_ex': [torch.float32],
+        'linalg.cond': [torch.float32],
+        'linalg.detsingular': [torch.float32],
+        'linalg.det': [torch.float32],
+        'linalg.eigh': [torch.float32],
+        'linalg.eigvalsh': [torch.float32],
+        'linalg.householder_product': [torch.float32],
+        'linalg.ldl_factor': [torch.float32],
+        'linalg.ldl_factor_ex': [torch.float32],
+        'linalg.ldl_solve': [torch.float32],
+        'linalg.lstsq': [torch.float32],
+        'linalg.lstsqgrad_oriented': [torch.float32],
+        'linalg.lu': [torch.float32],
+        'linalg.lu_factor': [torch.float32],
+        'linalg.lu_factor_ex': [torch.float32],
+        'linalg.lu_solve': [torch.float32],
+        'linalg.matrix_norm': [torch.float32],
+        'linalg.norm': [torch.float32],
+        'linalg.normsubgradients_at_zero': [torch.float32],
+        'linalg.qr': [torch.float32],
+        'linalg.slogdet': [torch.float32],
+        'linalg.solve': [torch.float32],
+        'linalg.solve_ex': [torch.float32],
+        'linalg.svdvals': [torch.float32],
+        'linalg.tensorsolve': [torch.float32],
+        'linalg.vander': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'linalg.vecdot': [torch.float32],
+        'logcumsumexp': [torch.float32],
+        'logdet': [torch.float32],
+        'logit': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'lu': [torch.float32],
+        'lu_solve': [torch.float32],
+        'lu_unpack': [torch.float32],
+        'masked.cumprod': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'masked.median': [torch.float32],
+        'masked_scatter': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'matrix_exp': [torch.float32],
+        'mode': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'msort': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mvlgamma': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mvlgammamvlgamma_p_1': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mvlgammamvlgamma_p_3': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mvlgammamvlgamma_p_5': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nanquantile': [torch.float32],
+        'nanmean': [torch.float32, torch.float16],
+        'nanmedian': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nansum': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'native_dropout_backward': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nextafter': [torch.float32],
+        'normnuc': [torch.float32],
+        'nn.functional._scaled_dot_product_attention': [torch.float32],
+        'nn.functional.fractional_max_pool2d': [torch.float32],
+        'nn.functional.fractional_max_pool3d': [torch.float32],
+        'nn.functional.adaptive_avg_pool3d': [torch.float16, torch.float32],
+        'nn.functional.adaptive_max_pool3d': [torch.float32],
+        'nn.functional.interpolatebicubic': [torch.float32],
+        'nn.functional.interpolatelinear': [torch.float32],
+        'nn.functional.interpolatetrilinear': [torch.float32],
+        'nn.functional.max_unpool1dgrad': [torch.float32],
+        'nn.functional.max_unpool2dgrad': [torch.float32],
+        'nn.functional.max_unpool3dgrad': [torch.float32],
+        'nn.functional.avg_pool3d': [torch.float32, torch.int64],
+        'nn.functional.ctc_loss': [torch.float32],
+        'nn.functional.embedding_bag': [torch.float16, torch.float32],
+        'nn.functional.max_pool2d': [torch.float32],
+        'nn.functional.max_pool3d': [torch.float32],
+        'nn.functional.hardshrink': [torch.float32],
+        'nn.functional.hardsigmoid': [torch.float32],
+        'nn.functional.logsigmoid': [torch.float32],
+        'nn.functional.max_unpool1d': [torch.float32],
+        'nn.functional.max_unpool2d': [torch.float32],
+        'nn.functional.max_unpool3d': [torch.float32],
+        'nn.functional.mish': [torch.float32],
+        'nn.functional.multi_margin_loss': [torch.float32],
+        'nn.functional.multilabel_margin_loss': [torch.float32],
+        'nn.functional.multilabel_soft_margin_loss': [torch.float32],
+        'nn.functional.pdist': [torch.float32],
+        'nn.functional.rrelu': [torch.float32],
+        'nn.functional.softshrink': [torch.float32],
+        'nn.functional.unfold': [torch.float16, torch.float32],
+        'nn.functional.norm': [torch.float32],
+        'ormqr': [torch.float32],
+        'pca_lowrank': [torch.float32],
+        'pinverse': [torch.float32],
+        'polar': [torch.float32],
+        'polygamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'polygammapolygamma_n_0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'polygammapolygamma_n_1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'polygammapolygamma_n_2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'polygammapolygamma_n_3': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'polygammapolygamma_n_4': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'qr': [torch.float32],
+        'quantile': [torch.float32],
+        'remainder': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'renorm': [torch.float16, torch.float32],
+        'roll': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'rounddecimals_0': [torch.float32],
+        'rounddecimals_3': [torch.float32],
+        'rounddecimals_neg_3': [torch.float32],
+        'rsub': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reduceamax': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reduceamin': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reducemin': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reducemean': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reduceprod': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'scatter_reducesum': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'searchsorted': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'segment_reduce': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'segment_reduceoffsets': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'segment_reducelengths': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'sinc': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'sort': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.airy_ai': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.bessel_j0': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.bessel_j1': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.bessel_y0': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.bessel_y1': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.chebyshev_polynomial_t': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.chebyshev_polynomial_u': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.entr': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.erfcx': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.hermite_polynomial_h': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.hermite_polynomial_he': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.i0e': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.i1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.i1e': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.laguerre_polynomial_l': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.log_ndtr': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.modified_bessel_i0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.modified_bessel_i1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.modified_bessel_k0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.modified_bessel_k1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.ndtri': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.polygamma': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.polygammaspecial_polygamma_n_0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.scaled_modified_bessel_k0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.scaled_modified_bessel_k1': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.spherical_bessel_j0': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.xlog1py': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'special.zeta': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'std_mean': [torch.float16, torch.float32],
+        'std_meanunbiased': [torch.float16, torch.float32],
+        'svd_lowrank': [torch.float32],
+        'symeig': [torch.float32],
+        'take': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'to_sparse': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'var_mean': [torch.float16, torch.float32],
+        'var_meanunbiased': [torch.float16, torch.float32],
+        'vdot': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'view_as_complex': [torch.float16, torch.float32],
+        'xlogy': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+    }
-        # failures due to unsupported data types on MPS backend
-        'matmul': ['torch.uint8'],  # MPS device does not support mm for non-float inputs
-        'bfloat16': None,
-        'chalf': None,
+    EXPECTED_FAILURES = {
+        # Failures due to unsupported data types on MPS backend
+        'matmul': [torch.int16, torch.int32, torch.int64, torch.uint8],  # MPS device does not support mm for non-float inputs
+        'bfloat16': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'chalf': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'nn.functional.conv1d': [torch.int64],
         'nn.functional.conv2d': [torch.int64],
         'nn.functional.conv_transpose1d': [torch.int64],
-        'nn.functional.conv_transpose2d': [torch.int64],
+        'log_softmaxwith_dtype': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'softmaxwith_dtype': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        '__rmatmul__': [torch.int16, torch.int32, torch.uint8],
+        'addmmdecomposed': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'addbmm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'addmm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'addmv': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'baddbmm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'bmm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cdouble': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'cfloat': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'complex': [torch.float16, torch.float32],
+        'double': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'einsum': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.fft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.fft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.fftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.fftshift': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.hfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.hfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.hfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ifft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ifft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ifftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ifftshift': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.ihfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.irfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.irfft2': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.irfftn': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'fft.rfft': [torch.bool, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'float_power': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'full_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'inner': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'linalg.matrix_rank': [torch.float32],
+        'linalg.matrix_rankhermitian': [torch.float32],
+        'linalg.multi_dot': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'linalg.pinv': [torch.float32],
+        'linalg.pinvhermitian': [torch.float32],
+        'log_softmax': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mm': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'mv': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'new_full': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'new_ones': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'new_zeros': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nn.functional.batch_norm': [torch.float32],
+        'nn.functional.bilinear': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nn.functional.linear': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nn.functional.softmin': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'nn.functional.softminwith_dtype': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'ones_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'signal.windows.blackman': [torch.float16],
+        'signal.windows.cosine': [torch.float16],
+        'signal.windows.exponential': [torch.float16],
+        'signal.windows.gaussian': [torch.float16],
+        'signal.windows.general_cosine': [torch.float16],
+        'signal.windows.general_hamming': [torch.float16],
+        'signal.windows.hamming': [torch.float16],
+        'signal.windows.hann': [torch.float16],
+        'signal.windows.kaiser': [torch.float16],
+        'stft': [torch.float32],
+        'tensordot': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'zeros_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'bincount': [torch.int16, torch.int32, torch.int64, torch.uint8],
+
+        # failures due to issue #102048039: powerWithPrimaryTensor() with integer input may return wrong results
+        'pow': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        '__rpow__': [torch.int16, torch.int32],
+    }
+
+    UNDEFINED_BEHAVIOUR = {
         # failures due to random output that they generate using
         # Philox engine causing mismatch with CPU results
-        'rand_like': ['torch.float16', 'torch.float32'],
-        'randint_like': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'randn_like': ['torch.float16', 'torch.float32'],
-        'bernoulli': ['torch.float32'],
-        'normal': ['torch.float16', 'torch.float32', 'torch.float16', 'torch.float32'],
-        'nn.functional.dropout': ['torch.float32'],
-        'nn.functional.dropout2d': ['torch.float32'],
-        'nn.functional.dropout3d': ['torch.float32'],
+        'rand_like': [torch.float16, torch.float32],
+        'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'randn_like': [torch.float16, torch.float32],
+        'bernoulli': [torch.float32],
+        'normal': [torch.float16, torch.float32],
+        'nn.functional.alpha_dropout': [torch.float32],
+        'nn.functional.dropout': [torch.float32],
+        'nn.functional.dropout2d': [torch.float32],
+        'nn.functional.dropout3d': [torch.float32],
         # these fill tensors with uninitialized data, causing mismatch with CPU
-        'new_empty': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'empty_like': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
-        'empty': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+        'new_empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'empty_like': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'empty': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         # problem 103190467, as_strided_scatter has non-deterministic behavior when the update indices are not unique
-        'as_strided_scatter': ['torch.bool', 'torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+        'as_strided_scatter': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        # the test case uses duplicate indices - undefined behaviour
+        'index_put': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+    }
+
+    FAST_MATH_PRECISION_ISSUES = {
         # failures due to precision issues
-        'masked.var': ['f16'],
-        'nn.functional.gelu': ['torch.float32'],
-        'pow': ['torch.float32'],
-        'tan': ['torch.float32'],
-        '__rpow__': ['torch.float32'],
+        'tan': [torch.float32],
+        'pow_': [torch.float32],
+        'masked_softmin': [torch.float32],
+        'masked_softmax': [torch.float32],
+        'masked_log_softmax': [torch.float32],
+        'cdist': [torch.float32],
+        '__rpow__': [torch.float32],
     }
 
     FP16_LOW_PRECISION_LIST = {
@@ -8950,10 +9663,25 @@ class TestConsistency(TestCase):
         "true_divide"
     }
 
+    MPS_SKIP_LIST = reduce(lambda x, y: dict(x, **y), (FAST_MATH_PRECISION_ISSUES, BLOCKLIST, UNDEFINED_BEHAVIOUR, EXPECTED_FAILURES, UNIMPLEMENTED_OPS))
+
     # Used for accept mode only
     NEW_ALLOW_LIST = defaultdict(list)
     NEW_ALLOW_LIST_GRAD = defaultdict(list)
 
+    def get_error_message(self, key, op_name):
+        if key in self.FAST_MATH_PRECISION_ISSUES:
+            return f"Running test with {op_name} fails due to precision issues (fast math) so skipping"
+        elif key in self.BLOCKLIST:
+            return f"Running test with {op_name} fails so skipping"
+        elif key in self.UNDEFINED_BEHAVIOUR:
+            return f"Running test with {op_name} fails due to undefined behaviour / random output so skipping"
+        elif key in self.EXPECTED_FAILURES:
+            return f"Running test with {op_name} is expected to fail due to an unsupported MPS data type so skipping"
+        elif key in self.UNIMPLEMENTED_OPS:
+            return f"Running test with {op_name} is expected to fail due to a missing op implementation so skipping"
+        return f"Running test with {op_name} hangs so skipping"
+
     @ops(op_db, allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
         self.assertEqual(device, "cpu")
@@ -8961,9 +9689,9 @@ def test_output_match(self, device, dtype, op):
             self.skipTest("MPS is not available")
 
         key = op.name + op.variant_test_name
-        if key in self.BLOCKLIST:
-            if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]:
-                self.skipTest(f"Running test with {op.name} hangs so skipping")
+        if key in self.MPS_SKIP_LIST:
+            if self.MPS_SKIP_LIST[key] is None or dtype in self.MPS_SKIP_LIST[key]:
+                self.skipTest(self.get_error_message(key, op.name))
 
         # Make this an expecttest manually
         # When this env variable is set, generate a new ALLOWLIST_OP
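For reference, the MPS_SKIP_LIST merge above relies on dict(x, **y) copying x and then overlaying y, so if the same op key were to appear in more than one of the merged tables, the table listed later in the tuple would win. A minimal sketch of that fold, with made-up op names and dtype lists purely for illustration:

    from functools import reduce
    import torch

    tables = (
        {'pow': [torch.float32]},             # earlier table
        {'pow': [torch.int16, torch.int32]},  # later table wins on a key collision
    )
    merged = reduce(lambda x, y: dict(x, **y), tables)
    assert merged == {'pow': [torch.int16, torch.int32]}

This is why keeping each op key in exactly one table matters: a collision silently replaces the earlier dtype list rather than combining the two.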