
Commit cfa39c8

Support PLAIN_INT32 for AWQ on Intel GPU (#3019)
* Support PLAIN_INT32 for AWQ on Intel GPU
1 parent ae204cc commit cfa39c8
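
For context, a minimal usage sketch of the new path (it assumes a PyTorch build with XPU support; Int4WeightOnlyConfig and quantize_ are imported as in the test files below, and the group size is illustrative):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Pick the int4 packing format per device: "plain_int32" is the new
# Intel GPU (XPU) path added by this commit, "opaque" is the CPU path.
group_size = 128
device = "xpu" if torch.xpu.is_available() else "cpu"
if device == "xpu":
    base_config = Int4WeightOnlyConfig(
        group_size=group_size, int4_packing_format="plain_int32"
    )
else:
    base_config = Int4WeightOnlyConfig(
        group_size=group_size, int4_packing_format="opaque"
    )

linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device)
quantize_(linear, base_config)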

File tree

4 files changed (+66, −3 lines)


test/prototype/test_awq.py

Lines changed: 16 additions & 0 deletions
@@ -51,6 +51,10 @@ def forward(self, x):
     devices.append("cuda")
 
 
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 class TestAWQ(TestCase):
     def test_awq_config(self):
         base_config = Int4WeightOnlyConfig()
@@ -79,6 +83,10 @@ def test_awq_functionality(self, device):
         # baseline quantization
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
@@ -137,6 +145,10 @@ def test_awq_loading(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"
@@ -198,6 +210,10 @@ def test_awq_loading_vllm(self, device):
         # calibrate
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 20 additions & 0 deletions
@@ -19,6 +19,7 @@
     Int4WeightOnlyConfig,
     quantize_,
 )
+from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     torch_version_at_least,
@@ -77,6 +78,25 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
+    def test_activation_prescaling(self):
+        dtype = torch.bfloat16
+        device = "xpu"
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(128))
+        qw = linear.weight
+        assert isinstance(qw, SupportsActivationPreScaling), (
+            "Expected int4 tensor supports activation prescaling"
+        )
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # making sure activation pre scaling is successfully applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+
 
 instantiate_parametrized_tests(Int4PlainInt32Tensor)
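
The new test leans on linearity: the quantized linear multiplies its activation by act_pre_scale before the int4 mm, so with act_pre_scale = 2 the output should track 2 * original, which compute_error (an SQNR measure) asserts exceeds 20 dB. A condensed sketch of that behavior (assuming an XPU device; get_config is the test file's local helper, approximated here with Int4WeightOnlyConfig directly):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

device = "xpu"
x = torch.randn(1, 128, dtype=torch.bfloat16, device=device)
linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device)
ref = linear(x)
quantize_(
    linear,
    Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32"),
)
# Setting act_pre_scale scales the activation inside the quantized linear,
# so the output should scale linearly with it.
linear.weight.act_pre_scale = 2
out = linear(x)  # expected: out ≈ 2 * ref, up to int4 quantization error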

torchao/prototype/awq/example.py

Lines changed: 4 additions & 0 deletions
@@ -254,6 +254,10 @@ def quantize_and_eval(
 
         if device == "cuda":
             base_config = Int4WeightOnlyConfig(group_size=group_size)
+        elif device == "xpu":
+            base_config = Int4WeightOnlyConfig(
+                group_size=group_size, int4_packing_format="plain_int32"
+            )
         elif device == "cpu":
             base_config = Int4WeightOnlyConfig(
                 group_size=group_size, int4_packing_format="opaque"

torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py

Lines changed: 26 additions & 3 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import List
+from typing import List, Optional
 
 import torch
 
@@ -38,10 +38,16 @@ class Int4PlainInt32Tensor(TorchAOBaseTensor):
         block_size: the block size for quantization, representing the granularity.
         shape: shape of the original Tensor
 
+    Optional Tensor Data Attributes:
+        act_pre_scale (Optional[Tensor]): Optional scale for activation Tensor, if present,
+            we'll multiply activation Tensor with act_pre_scale before applying dynamic
+            quantization to activation or running quantized mm op
+
     """
 
     tensor_data_names = ["qdata", "scale", "zero_point"]
     tensor_attribute_names = ["block_size", "shape"]
+    optional_tensor_data_names = ["act_pre_scale"]
 
     def __new__(
         cls,
@@ -50,21 +56,34 @@ def __new__(
         zero_point,
         block_size,
         shape,
+        act_pre_scale: Optional[torch.Tensor] = None,
     ):
         kwargs = {}
         kwargs["device"] = qdata.device
         kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
-    def __init__(self, qdata, scale, zero_point, block_size, shape):
+    def __init__(
+        self,
+        qdata,
+        scale,
+        zero_point,
+        block_size,
+        shape,
+        act_pre_scale: Optional[torch.Tensor] = None,
+    ):
         self.qdata = qdata
         self.scale = scale
         self.zero_point = zero_point
         self.block_size = block_size
+        self.act_pre_scale = act_pre_scale
 
     def _quantization_type(self):
-        return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+        s = f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
+        if self.act_pre_scale is not None:
+            s += f", act_pre_scale.shape={self.act_pre_scale.shape}"
+        return s
 
     @classmethod
     def from_hp(
@@ -122,6 +141,7 @@ def from_hp(
             zero_point.transpose(0, 1).contiguous().to(torch.int8),
             block_size,
             original_shape,
+            act_pre_scale=None,
         )
 
 
@@ -148,6 +168,9 @@ def _(func, types, args, kwargs):
             f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}"
         )
 
+    if weight_tensor.act_pre_scale is not None:
+        input_tensor = input_tensor * weight_tensor.act_pre_scale
+
     act_mat = input_tensor
     packed_weight = weight_tensor.qdata
     scale = weight_tensor.scale
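
To make the dispatch change above concrete, here is a standalone toy (ToyQuantizedWeight and toy_linear are illustrative names, not torchao APIs): it models how a weight wrapper applies an optional activation pre-scale before the matmul, which is the hook the AWQ flow uses to apply its per-channel equalization scales at runtime.

import torch

class ToyQuantizedWeight:
    """Stand-in for Int4PlainInt32Tensor: holds a dequantized weight plus
    the optional act_pre_scale attribute added in this commit."""
    def __init__(self, w_dequant: torch.Tensor, act_pre_scale=None):
        self.w_dequant = w_dequant
        self.act_pre_scale = act_pre_scale

def toy_linear(x: torch.Tensor, weight: ToyQuantizedWeight) -> torch.Tensor:
    # Same hook as the real op: scale the activation first, if a scale is set.
    if weight.act_pre_scale is not None:
        x = x * weight.act_pre_scale
    return x @ weight.w_dequant.t()

w = ToyQuantizedWeight(torch.randn(256, 128))
x = torch.randn(1, 128)
base = toy_linear(x, w)
w.act_pre_scale = torch.tensor(2.0)
assert torch.allclose(toy_linear(x, w), 2.0 * base)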
