diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index f9529e37a896..1062702cf92a 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -131,8 +131,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   @autoreleasepool {
-    string lookup_key = mps::getStridedKey(self, self.sizes(), self.strides(),
-                                           self.storage_offset());
+    string lookup_key = mps::getStridedKey(self, size, stride, storage_offset);
 
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(lookup_key));
     if(!cachedGraph) {
@@ -163,28 +162,6 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
         });
         cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
       }
-    } else {
-      // Else part takes care of the chaining where multiple view operations
-      // were implemented on the same underlying data storage ptr
-      string insert_key = mps::getStridedKey(self, size, stride, storage_offset);
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(insert_key, ^ MPSCachedGraph * () {
-        CachedGraph *newCachedGraph = nil;
-        @autoreleasepool {
-          MPSGraph* mpsGraph = cachedGraph->graph();
-          newCachedGraph = new CachedGraph(mpsGraph);
-          newCachedGraph->inputTensor_ = cachedGraph->inputTensor_;
-          newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size,
-                                                             stride,
-                                                             storage_offset,
-                                                             cachedGraph->outputTensor_,
-                                                             self);
-          newCachedGraph->size_ = size;
-          newCachedGraph->stride_ = stride;
-          newCachedGraph->storage_offset_ = storage_offset;
-        }
-        return newCachedGraph;
-      });
-      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
   }
 }
diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm
index a7708f1a327c..e709455d8406 100644
--- a/aten/src/ATen/native/mps/operations/Repeat.mm
+++ b/aten/src/ATen/native/mps/operations/Repeat.mm
@@ -74,7 +74,6 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
 
   TORCH_CHECK(repeats.size() >= (size_t)self.dim(),
               "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor");
-
   struct CachedGraph : public MPSCachedGraph
   {
     CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
diff --git a/test/test_mps.py b/test/test_mps.py
index ab0a25cc74ec..e2e34370f374 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8,20 +8,15 @@
 import warnings
 import subprocess
 import os
-import pprint
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import itertools
-from collections import defaultdict
 from torch._six import inf
 from torch.nn import Parameter
 from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
-from torch.testing._internal.common_dtype import get_all_dtypes
 import torch.backends.mps
 from torch.distributions import Uniform
-from torch.testing._internal.common_methods_invocations import op_db
-from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
 import numpy as np
 
@@ -4153,6 +4148,32 @@ def test_conv2d_valid_padding(self, device='mps'):
         actual = F.conv2d(x, y, padding='valid')
         self.assertEqual(expect.to('cpu'), actual.to('cpu'))
 
+    def test_gemm_permute_transpose(self):
+        batch_size = 32
+        n = 20
+        hidden = 768
+        num_attention_heads = 12
+        attention_head_size = hidden // num_attention_heads
+
+        def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
+            new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
+            x = x.view(new_x_shape)
+            return x.permute(0, 2, 1, 3)
+
+        def attention2(key, *, workaround=False, device):
+            key = transpose_for_scores(key)
+            res = key.transpose(-1, -2)
+            return res
+
+        A = torch.randn(batch_size, n, hidden)
+        A_mps = A.detach().clone().to("mps")
+
+        r1 = attention2(A, device="cpu")
+        r2 = attention2(A_mps, device="mps")
+
+        r2_cpu = r2.to("cpu")
+        self.assertEqual(r1, r2_cpu)
+
     # def test_conv2d_same_padding(self, device='mps'):
     #     x = torch.rand(1, 1, 10, 11, device=device)
     #     y = torch.rand(1, 1, 4, 5, device=device)
@@ -4366,632 +4387,7 @@ def test_legacy_constructor(self):
         b = a.new(1)
 
 
-MPS_DTYPES = get_all_dtypes()
-for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
-    del MPS_DTYPES[MPS_DTYPES.index(t)]
-
-class TestConsistency(TestCase):
-    # TODO: This is only used while some ops are being added.
-    # This list should contain all ops and dtypes eventually
-    # This can be generated automatically in the `new_mps_allowlist.txt` file
-    # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
-    # You most likely do NOT want to modify this manually
-    ALLOWLIST_OP = {
-        '__radd__': ['torch.bool',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        '__rand__': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        '__rmul__': ['torch.bool',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        '__ror__': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        '__rxor__': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        '_masked.normalize': ['torch.float32'],
-        'abs': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'add': ['torch.bool',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'addcdiv': ['torch.float32'],
-        'addcmul': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'addmv': ['torch.float32'],
-        'addr': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'all': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'any': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'argmax': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'asin': ['torch.float32'],
-        'asinh': ['torch.float32'],
-        'atan': ['torch.float32'],
-        'atan2': ['torch.float32'],
-        'atanh': ['torch.float32'],
-        'atleast_1d': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'atleast_2d': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'atleast_3d': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'baddbmm': ['torch.float32'],
-        'bitwise_and': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bitwise_left_shift': ['torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bitwise_not': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bitwise_or': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bitwise_right_shift': ['torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bitwise_xor': ['torch.bool',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'bmm': ['torch.float32'],
-        'ceil': ['torch.float32'],
-        'chunk': ['torch.float16', 'torch.float32', 'torch.int64'],
-        'clone': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'column_stack': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'conj': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'conj_physical': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'contiguous': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'corrcoef': ['torch.float32'],
-        'deg2rad': ['torch.float32'],
-        'diag': ['torch.float32', 'torch.int32', 'torch.int64'],
-        'diagflat': ['torch.int32', 'torch.int64'],
-        'diff': ['torch.float32'],
-        'dist': ['torch.float32'],
-        'dot': ['torch.float32', 'torch.int32'],
-        'einsum': ['torch.float32'],
-        'erf': ['torch.float32'],
-        'fill': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'flatten': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'floor': ['torch.float32'],
-        'hstack': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'index_select': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'isinf': ['torch.float16', 'torch.float32'],
-        'isnan': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'kron': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'linalg.norm': ['torch.float16',
-            'torch.float32',
-            'torch.float16',
-            'torch.float32'],
-        'linalg.svd': ['torch.float32'],
-        'linalg.vector_norm': ['torch.float16'],
-        'log1p': ['torch.float32'],
-        'log_softmax': ['torch.float32'],
-        'logaddexp': ['torch.float32'],
-        'logaddexp2': ['torch.float32'],
-        'masked_select': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'mm': ['torch.float32'],
-        'mv': ['torch.float32'],
-        'neg': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'nn.functional.adaptive_max_pool1d': ['torch.float32'],
-        'nn.functional.adaptive_max_pool2d': ['torch.float32'],
-        'nn.functional.binary_cross_entropy': ['torch.float32'],
-        'nn.functional.celu': ['torch.float32'],
-        'nn.functional.elu': ['torch.float32'],
-        'nn.functional.embedding': ['torch.float16', 'torch.float32'],
-        'nn.functional.feature_alpha_dropout': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'nn.functional.hardtanh': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64'],
-        'nn.functional.hinge_embedding_loss': ['torch.float32'],
-        'nn.functional.kl_div': ['torch.float32'],
-        'nn.functional.l1_loss': ['torch.float32'],
-        'nn.functional.leaky_relu': ['torch.float32'],
-        'nn.functional.mse_loss': ['torch.float16', 'torch.float32'],
-        'nn.functional.relu': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'nn.functional.relu6': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'nn.functional.selu': ['torch.float32'],
-        'nn.functional.silu': ['torch.float32'],
-        'nn.functional.smooth_l1_loss': ['torch.float32'],
-        'nn.functional.softmin': ['torch.float32'],
-        'nn.functional.threshold': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'nn.functional.upsample_bilinear': ['torch.float32'],
-        'norm': ['torch.float32', 'torch.float16', 'torch.float32'],
-        'positive': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'rad2deg': ['torch.float32'],
-        'ravel': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'real': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'repeat': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'repeat_interleave': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'resize_': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'resize_as_': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'resolve_conj': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'resolve_neg': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'round': ['torch.float32'],
-        'sgn': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'sign': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'sin': ['torch.float32'],
-        'sinh': ['torch.float32'],
-        'softmax': ['torch.float32'],
-        'split': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'sqrt': ['torch.float32'],
-        'square': ['torch.float32',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'squeeze': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'stack': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'sub': ['torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'sum_to_size': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'svd': ['torch.float32'],
-        't': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'tanh': ['torch.float32'],
-        'tensordot': ['torch.float32'],
-        'tile': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'topk': ['torch.float32'],
-        'tril': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'triu': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'true_divide': ['torch.float32'],
-        'trunc': ['torch.float32'],
-        'unsqueeze': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'view': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'view_as': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'vsplit': ['torch.bool',
-            'torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'vstack': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8'],
-        'zero_': ['torch.float16',
-            'torch.float32',
-            'torch.int16',
-            'torch.int32',
-            'torch.int64',
-            'torch.uint8']}
-
-    # These ops that are problematic. So never run them even when
-    # generating the new allowlist.
-    # If the dtype list is None, all dtypes are excluded.
-    # All the entries in this list should be removed
-    BLOCKLIST = {
-        # Functions that hang
-        'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool],
-        # Functions that hard crash
-        'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
-        'nn.functional.nll_loss': [torch.float32],
-        'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32],
-        'nn.functional.smooth_l1_loss': [torch.float16], 'std': [torch.float16],
-        'stft': [torch.float32], 'var': [torch.float16],
-        # Functions that are flaky
-        # These are detected as "ok" by the expect case but actually fail to run sometimes
-        'H': None,
-        'T': None,
-        'as_strided': None,
-        'broadcast_tensors': None,
-        'broadcast': None,
-        'broadcast_to': None,
-        'diagonal': None,
-        'divfloor_rounding': None,
-        'divno_rounding_mode': None,
-        'divtrunc_rounding': None,
-        'dsplit': None,
-        'hsplit': None,
-        'empty': None,
-        'expand_as': None,
-        'expand': None,
-        'ge': None,
-        'ne': None,
-        'le': None,
-        'lt': None,
-        'gt': None,
-        'transpose': None,
-        'splitlist_args': None,
-        'select': None,
-        'reshape': None,
-        'reshape_as': None,
-        'permute': None,
-        'norm': None,
-        'nn.functional.pixel_unshuffle': None,
-        'nn.functional.pixel_shuffle': None,
-        'nn.functional.cross_entropy': None,
-        'nn.functional.one_hot': None,
-        'narrow': None,
-        'movedim': None,
-        'minreduction_with_dim': None,
-        'minreduction_no_dim': None,
-        'minbinary': None,
-        'meshgridvariadic_tensors': None,
-        'meshgridlist_of_tensors': None,
-        'maxreduction_with_dim': None,
-        'maxreduction_no_dim': None,
-        'maxbinary': None,
-        'maximum': None,
-        'minimum': None,
-        'mT': None,
-        'mH': None,
-        'outer': None,
-        'softmaxwith_dtype': None,
-        'rounddecimals_neg_3': None,
-        'rounddecimals_3': None,
-        'rounddecimals_0': None,
-        'normnuc': None,
-        'nn.functional.softminwith_dtype': None,
-        'nn.functional.feature_alpha_dropoutwith_train': None,
-        'log_softmaxdtype': None,
-        'split_with_sizes': None,
-        'trapezoid': None,
-        'eq': None,
-        'mul': None,
-        'cartesian_prod': None,
-        'nonzero': None,
-        'bool': None,
-        'inner': None,
-        'dstack': None,
-        'take_along_dim': None,
-    }
-
-    # Used for accept mode only
-    NEW_ALLOW_LIST = defaultdict(list)
-
-    @ops(op_db, allowed_dtypes=MPS_DTYPES)
-    def test_output_match(self, device, dtype, op):
-        self.assertEqual(device, "cpu")
-        if not torch.backends.mps.is_available():
-            self.skipTest("MPS is not available")
-
-        key = op.name + op.variant_test_name
-        if key in self.BLOCKLIST:
-            if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]:
-                self.skipTest(f"Running test with {op.name} hangs so skipping")
-
-        # Make this an expecttest manually
-        # When this env variable is set, generate a new ALLOWLIST_OP
-        # that reflects the current state of what passes or not
-        if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
-            generate_new_truth = True
-        else:
-            generate_new_truth = False
-
-        if not generate_new_truth:
-            if op.name not in self.ALLOWLIST_OP:
-                self.skipTest(f"{op.name} is not in the allow list for test on MPS")
-            else:
-                if str(dtype) not in self.ALLOWLIST_OP[op.name]:
-                    self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded")
-        try:
-            cpu_samples = op.sample_inputs(device, dtype)
-
-            for cpu_sample in cpu_samples:
-                mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x)
-
-                # TODO: This checks only the function variant. We should also check the method and inplace version
-                # when they exist
-                cpu_args = [cpu_sample.input] + list(cpu_sample.args)
-                cpu_kwargs = cpu_sample.kwargs
-                mps_args = [mps_sample.input] + list(mps_sample.args)
-                mps_kwargs = mps_sample.kwargs
-
-                cpu_out = op(*cpu_args, **cpu_kwargs)
-                mps_out = op(*mps_args, **mps_kwargs)
-                self.assertEqual(cpu_out, mps_out)
-        except Exception as e:
-            if not generate_new_truth:
-                raise e
-        else:
-            if generate_new_truth:
-                self.NEW_ALLOW_LIST[op.name].append(str(dtype))
-
-                # We could write it only once. But I don't know how to detect that the current test is the last one
-                # So each test append to the dict and write it.
-                with open("new_mps_allowlist.txt", "w") as f:
-                    pprint.pprint(self.NEW_ALLOW_LIST, stream=f)
-
-# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
-# This requires mps to be properly registered in the device generic test framework which is not the
-# case right now.
-instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
 
 
 if __name__ == "__main__":
     run_tests()