@@ -164,9 +164,13 @@ def forward(self, x):
             # The trailing "(" is to avoid matching the op in the comment
             assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
 
-            # ensure the custom DA8W4ConcatLinearCPUPass is not bypassed when saving as fxgraph
-            enable_fxgraph_cache_bypass = counters["inductor"]["fxgraph_cache_bypass"]
-            assert enable_fxgraph_cache_bypass == 0
+            # Ensure that when concat linear is enabled, the fxgraph cache works
+            # without being bypassed (fxgraph_cache_bypass == 0), indicating that
+            # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
+            # interface and its uuid() method, allowing the fxgraph to be saved
+            # and hit on subsequent runs (fxgraph_cache_hit > 0).
+            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+            assert fx_cache_bypass_count == 0
 
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
@@ -177,8 +181,9 @@ def forward(self, x):
             )
             assert torch.allclose(y, y_ref)
 
-            disable_fxgraph_cache_bypass = counters["inductor"]["fxgraph_cache_bypass"]
-            assert disable_fxgraph_cache_bypass == 0
+            # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
+            fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+            assert fx_cache_bypass_count == 0
 
 
 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
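
For context on what the new comment is asserting: Inductor only keeps the fxgraph cache usable for a custom graph pass when the pass implements the CustomGraphPass interface and reports a stable uuid() that can be folded into the cache key; otherwise the cache is bypassed and the counter checked above increments. Below is a minimal, hypothetical sketch (not the actual DA8W4ConcatLinearCPUPass), assuming the torch._inductor.custom_graph_pass.CustomGraphPass and get_hash_for_files APIs available in recent PyTorch releases; MyConcatLinearPass and its body are illustrative only.

# Hypothetical sketch: a custom Inductor graph pass whose uuid() lets the
# fxgraph cache stay enabled (fxgraph_cache_bypass remains 0).
from typing import Any, Optional

import torch
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyConcatLinearPass(CustomGraphPass):  # hypothetical pass, not torchao's
    def __call__(self, graph: torch.fx.Graph) -> None:
        # Mutate `graph` in place here, e.g. fuse parallel linear nodes that
        # share the same input into one concatenated linear.
        pass

    def uuid(self) -> Optional[Any]:
        # The uuid participates in the fxgraph cache key. Hashing this source
        # file invalidates previously cached graphs whenever the pass
        # implementation changes; returning None would instead make Inductor
        # bypass the fxgraph cache entirely.
        return get_hash_for_files((__file__,))


# Register as a post-grad pass; with uuid() implemented,
# counters["inductor"]["fxgraph_cache_bypass"] stays at 0 and a second
# compilation can hit the cache (fxgraph_cache_hit > 0).
torch._inductor.config.post_grad_custom_post_pass = MyConcatLinearPass()

Hashing the pass's own source file is a common choice for uuid() because it keeps cached graphs valid across runs while still invalidating them when the pass logic changes.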