diff --git a/test/inductor/test_decompose_mem_bound_mm.py b/test/inductor/test_decompose_mem_bound_mm.py
index 68997635f3cfb9..2ebc11844fa51b 100644
--- a/test/inductor/test_decompose_mem_bound_mm.py
+++ b/test/inductor/test_decompose_mem_bound_mm.py
@@ -5,6 +5,7 @@
 import torch
 import torch._inductor
 from torch._dynamo.utils import counters
+from torch._inductor.fx_passes.decompose_mem_bound_mm import check_device
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
@@ -117,6 +118,29 @@ def test_decompose_bmm(self, b, m, n, k, should_decompose):
         )
         counters.clear()
 
+    @parametrize(
+        "b,m,k,n,should_decompose",
+        [(1, 2, 2, 2, True), (2, 2, 2, 2, False)],
+    )
+    def test_decompose_bmm_cpu(self, b, m, n, k, should_decompose):
+        torch._logging.set_logs(inductor=logging.DEBUG)
+        mat1 = torch.randn(b, m, k)
+        mat2 = torch.randn(b, k, n)
+
+        counters.clear()
+
+        module = MyModule2()
+        traced = torch.compile(module)
+        input = [mat1, mat2]
+        self.compare_pred(module, traced, input)
+
+        expected_val = 1 if should_decompose else 0
+        self.assertEqual(
+            counters["inductor"]["decompose_bmm"],
+            expected_val,
+        )
+        counters.clear()
+
     @parametrize(
         "m,k,n, should_decompose",
         [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
@@ -247,6 +271,28 @@ def test_decompose_mm(self, m, n, k, has_bias, should_decompose):
         )
         counters.clear()
 
+    @parametrize(
+        "m,k,n, should_decompose",
+        [(1, 64, 16, True), (2, 64, 16, False), (1, 64, 32, False)],
+    )
+    def test_decompose_mm_cpu(self, m, n, k, should_decompose):
+        torch._logging.set_logs(inductor=logging.DEBUG)
+        mat1 = torch.randn(m, k)
+        mat2 = torch.randn(k, n)
+        counters.clear()
+
+        module = MyModule3()
+        traced = torch.compile(module)
+        input = [mat1, mat2]
+        self.compare_pred(module, traced, input)
+
+        expected_val = 1 if should_decompose else 0
+        self.assertEqual(
+            counters["inductor"]["decompose_mm"],
+            expected_val,
+        )
+        counters.clear()
+
     @parametrize(
         "m,k,n, should_decompose",
         [(20480, 5, 2, True), (20480, 32, 2, False), (2048, 2, 2, False)],
@@ -347,6 +393,29 @@ def foo(x, y):
         # two kernels generated
         FileCheck().check_count(".run(", 2, exactly=True).run(code[0])
 
+    def test_check_device(self):
+        m = 5
+        k = 5
+        n = 2
+        torch._logging.set_logs(inductor=logging.DEBUG)
+
+        input1 = torch.randn(m, k, device=GPU_TYPE)
+        input2 = torch.randn(k, n, device=GPU_TYPE)
+        self.assertTrue(check_device(input1, input2))
+        self.assertFalse(check_device(input1, input2, device="cpu"))
+
+        input1 = torch.randn(m, k)
+        input2 = torch.randn(k, n)
+        self.assertTrue(check_device(input1, input2, device="cpu"))
+        self.assertFalse(check_device(input1, input2))
+
+        input1 = torch.randn(m, k, device=GPU_TYPE)
+        input2 = torch.randn(k, n)
+        self.assertFalse(check_device(input1, input2, device=GPU_TYPE))
+        self.assertFalse(check_device(input1, input2, device="cpu"))
+
+        self.assertFalse(check_device(input1, input2, device="mtia"))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
index d6777987c08662..a38d48f50a6847 100644
--- a/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
+++ b/torch/_inductor/fx_passes/decompose_mem_bound_mm.py
@@ -29,8 +29,9 @@
 ].get("max_other_dimention_decomposition", MAX_OTHER_DIMENSION_DECOMPOSITION)
 
 
-def check_device(a: Tensor, b: Tensor) -> bool:
-    return a.is_cuda and b.is_cuda
+def check_device(a: Tensor, b: Tensor, device: str = "cuda") -> bool:
+    """Return True iff ``a`` and ``b`` both live on device type ``device``."""
+    return (a.device.type == b.device.type) and (b.device.type == device)
 
 
 def realize_inputs(inputs: List[torch.fx.Node]):
@@ -45,11 +46,9 @@ def should_decompose_bmm(mat1, mat2) -> bool:
         mat2 = mat2.meta["val"]
     else:
         return False
-    if not check_device(mat1, mat2):
+    if len(mat1.shape) != 3 or len(mat2.shape) != 3:
         return False
-    else:
-        if len(mat1.shape) != 3 or len(mat2.shape) != 3:
-            return False
+    if check_device(mat1, mat2, device="cuda"):
         if mat1.shape[0] < min_first_dimension_decomposition:
             return False
         # 2 of m, n, k must be <= MAX_OTHER_DIMENSION_DECOMPOSITION
@@ -57,7 +56,11 @@ def should_decompose_bmm(mat1, mat2) -> bool:
             mat1.shape[2] < max_other_dimention_decomposition
         ) + (mat2.shape[2] < max_other_dimention_decomposition) < 2:
             return False
-    return True
+        return True
+    elif check_device(mat1, mat2, device="cpu"):
+        if mat1.shape[0] == 1 and mat2.shape[0] == 1:
+            return True
+    return False
 
 
 def should_decompose_mm(mat1, mat2) -> bool:
@@ -66,11 +69,16 @@ def should_decompose_mm(mat1, mat2) -> bool:
         mat2 = mat2.meta["val"]
     else:
         return False
+    if len(mat1.shape) != 2 or len(mat2.shape) != 2:
+        return False
     return (
-        check_device(mat1, mat2)
-        and len(mat1.shape) == 2
-        and len(mat2.shape) == 2
+        check_device(mat1, mat2, device="cuda")
         and mat1.shape[0] >= min_first_dimension_decomposition
         and mat2.shape[0] < max_other_dimention_decomposition
         and mat2.shape[1] < max_other_dimention_decomposition
+    ) or (
+        check_device(mat1, mat2, device="cpu")
+        and mat1.shape[0] == 1
+        and mat2.shape[0] <= 64
+        and mat2.shape[1] <= 16
     )