
Commit 62d3689

Author: Cui, Yuxin
Fix FX Graph Cache issue in register_da8w4_concat_linear_cpu_pass
Fix the bug where the FX Graph Cache was bypassed when using register_da8w4_concat_linear_cpu_pass, preventing cache hits on subsequent model runs. Implement DA8W4ConcatLinearCPUPass, which inherits from CustomGraphPass so the pass can be serialized and the compiled FX graph saved properly. Add a unit test: when the FX graph is saved, the fxgraph_cache_bypass counter should remain at 0, confirming that the custom pass is no longer rejected by the cache system.

Signed-off-by: Cui, Yuxin <yuxin.cui@intel.com>
1 parent bc2c83e commit 62d3689
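
For context: Inductor's FX graph cache hashes every post-grad custom pass into its cache key, and a bare function assigned to post_grad_custom_post_pass cannot be hashed, so Inductor bypasses the cache entirely. A minimal sketch of the fix pattern this commit uses (the pass name here is illustrative, not from the commit):

import torch
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files

class MyCacheablePass(CustomGraphPass):  # illustrative name
    def __call__(self, graph: torch.fx.Graph) -> None:
        # A real pass would rewrite graph nodes in place here.
        pass

    def uuid(self):
        # Hash this source file so the cache key changes whenever
        # the pass implementation changes.
        return get_hash_for_files((__file__,))

# An instance with a stable uuid() can be hashed into the cache key,
# unlike a bare function.
inductor_config.post_grad_custom_post_pass = MyCacheablePass()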

File tree

3 files changed: +19 −4 lines

test/quantization/test_da8w4_cpu.py (9 additions, 1 deletion)

@@ -8,6 +8,7 @@
 import unittest
 
 import torch
+from torch._dynamo.utils import counters
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -120,7 +121,6 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
-        self.skipTest("Disabled for now")
         N, K = 64, 128
 
         class Mod(torch.nn.Module):
@@ -163,6 +163,11 @@ def forward(self, x):
             # ensure the expected op occurs only once in the code after fusion
             # The trailing "(" is to avoid matching the op in the comment
             assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
+
+            # ensure the custom DA8W4ConcatLinearCPUPass is properly cached as fxgraph
+            enable_fxgraph_cache_bypass = counters["inductor"]["fxgraph_cache_bypass"]
+            assert enable_fxgraph_cache_bypass == 0
+
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
         ):
@@ -172,6 +177,9 @@ def forward(self, x):
             )
             assert torch.allclose(y, y_ref)
 
+            disable_fxgraph_cache_bypass = counters["inductor"]["fxgraph_cache_bypass"]
+            assert disable_fxgraph_cache_bypass == 0
+
 
 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
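
The counter the test reads is a plain torch._dynamo counter, so the same check works in any harness. A minimal sketch with a generic model (not the quantized module from the test above):

import torch
from torch._dynamo.utils import counters

counters.clear()
compiled = torch.compile(torch.nn.Linear(8, 8))
compiled(torch.randn(2, 8))
# 0 means the FX graph cache was never bypassed; an unhashable custom
# pass would have incremented this counter during compilation.
assert counters["inductor"]["fxgraph_cache_bypass"] == 0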

torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py (2 additions, 2 deletions)

@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
 
 
 # Register the concat linear fusion pass
-# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
 
-# register_da8w4_concat_linear_cpu_pass()
+register_da8w4_concat_linear_cpu_pass()

torchao/prototype/inductor/fx_passes/da8w4_concat_linear_fusion_cpu.py (8 additions, 1 deletion)

@@ -7,7 +7,13 @@
 import operator
 
 import torch
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
 
+class DA8W4ConcatLinearCPUPass(CustomGraphPass):
+    def __call__(self, graph: torch.fx.Graph):
+        _concat_linear_dq8w4_cpu(graph)
+    def uuid(self):
+        return get_hash_for_files((__file__,))
 
 # Inductor FX passes for concat linear for DA8W4
 def _is_valid_concat_linear_da8w4_fusion(computation_nodes):
@@ -213,4 +219,5 @@ def ...
 def register_da8w4_concat_linear_cpu_pass():
     from torch._inductor import config as inductor_config
 
-    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
+    da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
+    inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
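
End to end, the registration runs when the layout module above is imported. A hedged sketch of what a cache-friendly compile then looks like (the Linear module is a placeholder, not a DA8W4-quantized model):

import torch
from torchao.prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass

# Install the custom post-grad pass on Inductor's config.
register_da8w4_concat_linear_cpu_pass()

m = torch.nn.Linear(128, 64)  # placeholder module for illustration
compiled = torch.compile(m)
_ = compiled(torch.randn(2, 128))
# A second run in a fresh process can now hit the FX graph cache, since
# DA8W4ConcatLinearCPUPass.uuid() gives the pass a stable hash.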
