diff --git a/test/quantization/test_da8w4_cpu.py b/test/quantization/test_da8w4_cpu.py
index d4f68c4333..80094beb2d 100644
--- a/test/quantization/test_da8w4_cpu.py
+++ b/test/quantization/test_da8w4_cpu.py
@@ -8,6 +8,7 @@
 import unittest
 
 import torch
+from torch._dynamo.utils import counters
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -120,7 +121,6 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
-        self.skipTest("Disabled for now")
         N, K = 64, 128
 
         class Mod(torch.nn.Module):
@@ -163,6 +163,15 @@ def forward(self, x):
             # ensure the expected op occurs only once in the code after fusion
             # The trailing "(" is to avoid matching the op in the comment
             assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
+
+            # Ensure that when concat linear is enabled, fxgraph cache works
+            # without being bypassed (fxgraph_cache_bypass = 0), indicating that
+            # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
+            # interface and uuid() function, allowing fxgraph to be saved and hit
+            # on subsequent runs (fxgraph_cache_hit > 0).
+            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+            assert fx_cache_bypass_count == 0
+
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
         ):
@@ -172,6 +181,10 @@ def forward(self, x):
             )
             assert torch.allclose(y, y_ref)
 
+            # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
+            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+            assert fx_cache_bypass_count == 0
+
 
 
 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
diff --git a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
index 8d0cfaddeb..c0f2fcdfe5 100644
--- a/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
+++ b/torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py
@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
 
 
 # Register the concat linear fusion pass
-# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
 
-# register_da8w4_concat_linear_cpu_pass()
+register_da8w4_concat_linear_cpu_pass()
diff --git a/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
index 12b1a4696b..8e39826f4c 100644
--- a/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
+++ b/torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py
@@ -7,6 +7,15 @@
 import operator
 
 import torch
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+
+
+class DA8W4ConcatLinearCPUPass(CustomGraphPass):
+    def __call__(self, graph: torch.fx.Graph):
+        _concat_linear_dq8w4_cpu(graph)
+
+    def uuid(self):
+        return get_hash_for_files((__file__,))
 
 
 # Inductor FX passes for concat linear for DA8W4
@@ -213,4 +222,5 @@ def ...
 def register_da8w4_concat_linear_cpu_pass():
     from torch._inductor import config as inductor_config
 
-    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
+    da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
+    inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
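
Context for the pattern this diff adopts: Inductor's fx graph cache cannot hash a bare function assigned to post_grad_custom_post_pass, so it bypasses caching entirely (incrementing counters["inductor"]["fxgraph_cache_bypass"]). Wrapping the pass in a CustomGraphPass subclass whose uuid() hashes the pass's source file lets the cache key on the pass implementation instead. Below is a minimal standalone sketch of that contract; the class name MyGraphPass and its no-op body are hypothetical, not torchao code.

import torch
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyGraphPass(CustomGraphPass):
    """Hypothetical pass; a real pass would rewrite nodes in `graph` in place."""

    def __call__(self, graph: torch.fx.Graph) -> None:
        # Mutate the post-grad FX graph here (e.g., fuse parallel linears).
        pass

    def uuid(self):
        # Hash this source file so cached fx graphs are invalidated
        # whenever the pass implementation changes.
        return get_hash_for_files((__file__,))


inductor_config.post_grad_custom_post_pass = MyGraphPass()

With a pass registered this way, counters["inductor"]["fxgraph_cache_bypass"] should remain 0 across compilations, which is exactly what the new assertions in test_8da4w_concat_linear_cpu verify for DA8W4ConcatLinearCPUPass.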