test/quantization/test_da8w4_cpu.py: 14 additions, 1 deletion

@@ -8,6 +8,7 @@
 import unittest
 
 import torch
+from torch._dynamo.utils import counters
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -120,7 +121,6 @@ def test_8da4w_cpu(self, dtype, x_dim, bias, bs, sym_quant_a):
     @common_utils.parametrize("x_dim", [2, 3])
     @common_utils.parametrize("bias", [True, False])
     def test_8da4w_concat_linear_cpu(self, x_dim, bias):
-        self.skipTest("Disabled for now")
         N, K = 64, 128
 
         class Mod(torch.nn.Module):
@@ -163,6 +163,15 @@ def forward(self, x):
         # ensure the expected op occurs only once in the code after fusion
         # The trailing "(" is to avoid matching the op in the comment
         assert code[0].count("torch.ops.torchao.da8w4_linear_cpu.default(") == 1
+
+        # Ensure that when concat linear is enabled, the fxgraph cache works
+        # without being bypassed (fxgraph_cache_bypass = 0), indicating that
+        # DA8W4ConcatLinearCPUPass properly implements the CustomGraphPass
+        # interface and uuid() function, allowing the fxgraph to be saved and
+        # hit on subsequent runs (fxgraph_cache_hit > 0).
+        fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+        assert fx_cache_bypass_count == 0
+
         with torch._inductor.config.patch(
             {"freezing": True, "cpp.enable_concat_linear": False}
         ):
@@ -172,6 +181,10 @@ def forward(self, x):
         )
         assert torch.allclose(y, y_ref)
 
+        # Ensure that the fxgraph cache is also not bypassed when concat linear is disabled
+        fx_cache_bypass_count = counters["inductor"]["fxgraph_cache_bypass"]
+        assert fx_cache_bypass_count == 0
+
 
 common_utils.instantiate_parametrized_tests(TestDa8w4Cpu)
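For context, a minimal sketch of the counter check the new assertions rely on: Inductor bumps counters["inductor"]["fxgraph_cache_bypass"] whenever it cannot build a cache key for a compilation (for example, when an unhashable custom graph pass is installed). The toy model and shapes below are illustrative, not taken from the test itself.

# Minimal sketch (illustrative model/shapes): verify that a compiled run
# does not bypass Inductor's FX graph cache.
import torch
from torch._dynamo.utils import counters

model = torch.nn.Linear(128, 64).eval()
x = torch.randn(4, 128)

counters.clear()
with torch.no_grad(), torch._inductor.config.patch({"freezing": True}):
    torch.compile(model)(x)

# An unhashable custom pass would force a bypass and make this count > 0.
assert counters["inductor"]["fxgraph_cache_bypass"] == 0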
torchao/dtypes/uintx/dyn_int8_act_int4_wei_cpu_layout.py: 2 additions, 2 deletions

@@ -314,6 +314,6 @@ def _linear_int8_act_int4_weight_cpu_impl(input_tensor, weight_tensor, bias):
 
 
 # Register the concat linear fusion pass
-# from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
+from ...prototype.inductor.fx_passes import register_da8w4_concat_linear_cpu_pass
 
-# register_da8w4_concat_linear_cpu_pass()
+register_da8w4_concat_linear_cpu_pass()
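Uncommenting these lines makes registration a side effect of importing the layout module, so users get the fusion pass without any extra call. A hedged usage sketch of that side effect (the assert is purely illustrative):

# Importing the layout module (directly, or indirectly via torchao
# quantization APIs) installs the concat-linear pass as Inductor's
# post-grad custom post pass.
import torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout  # noqa: F401
from torch._inductor import config as inductor_config

assert inductor_config.post_grad_custom_post_pass is not None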
torchao/prototype/inductor/fx_passes/… (file name truncated in the rendered diff)

@@ -7,6 +7,15 @@
 import operator
 
 import torch
+from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files
+
+
+class DA8W4ConcatLinearCPUPass(CustomGraphPass):
+    def __call__(self, graph: torch.fx.Graph):
+        _concat_linear_dq8w4_cpu(graph)
+
+    def uuid(self):
+        return get_hash_for_files((__file__,))
 
 
 # Inductor FX passes for concat linear for DA8W4
@@ -213,4 +222,5 @@ def ...
 def register_da8w4_concat_linear_cpu_pass():
     from torch._inductor import config as inductor_config
 
-    inductor_config.post_grad_custom_post_pass = _concat_linear_dq8w4_cpu
+    da8w4_concat_linear_cpu_pass = DA8W4ConcatLinearCPUPass()
+    inductor_config.post_grad_custom_post_pass = da8w4_concat_linear_cpu_pass
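This is the change that makes the test's cache assertions pass: a plain function assigned to post_grad_custom_post_pass cannot be hashed into the FX graph cache key, so Inductor bypasses the cache, whereas a CustomGraphPass subclass supplies uuid() for the key. A minimal sketch of the same pattern with a hypothetical no-op pass (MyPass is not part of this PR):

import torch
from torch._inductor import config as inductor_config
from torch._inductor.custom_graph_pass import CustomGraphPass, get_hash_for_files


class MyPass(CustomGraphPass):  # hypothetical no-op pass, for illustration
    def __call__(self, graph: torch.fx.Graph) -> None:
        # A real pass would mutate `graph` in place here, e.g. fuse
        # parallel linear nodes as DA8W4ConcatLinearCPUPass does.
        pass

    def uuid(self) -> bytes:
        # Contributes to the cache key; the hash changes whenever this
        # source file changes, invalidating stale fxgraph cache entries.
        return get_hash_for_files((__file__,))


inductor_config.post_grad_custom_post_pass = MyPass()

Hashing the pass's own source file is a pragmatic design choice: edits to the pass automatically invalidate cached graphs compiled with its previous behavior.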