Fix int4pack_mm error (#517)
* Fix int4pack_mm error

* fix CI

* Fix CI

* Fix CI

* Fix CI

* Fix CI
yanbing-j authored Jul 29, 2024
1 parent e5b705c commit 8fa11a6
Showing 8 changed files with 61 additions and 26 deletions.
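
Note (editor's sketch, not part of the commit): the change repeated across these files is that torch.ops.aten._convert_weight_to_int4pack in PyTorch 2.5+ expects a uint8 tensor with two int4 values packed per byte (even columns in the high nibble, odd columns in the low nibble), while 2.4 and earlier take one value per int32 element, as the assertions in affine_quantized_tensor.py below state. The helpers here use hypothetical names and only illustrate the pack/unpack pattern the diff applies in each file.

import torch

def pack_int4_rows(t: torch.Tensor) -> torch.Tensor:
    # Pair adjacent columns of int4 values (0..15) into one byte:
    # even column -> high nibble, odd column -> low nibble.
    return (t[:, ::2] << 4 | t[:, 1::2]).to(torch.uint8)

def unpack_int4_rows(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_int4_rows: split each byte back into two int4 values.
    data = packed.to(torch.int32)
    out = torch.zeros((packed.shape[0], packed.shape[1] * 2),
                      dtype=torch.int32, device=packed.device)
    out[:, ::2] = data >> 4
    out[:, 1::2] = data & 0x0F
    return out
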
2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized.py
@@ -12,7 +12,7 @@

class TestAffineQuantized(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_tensor_core_layout_transpose(self):
l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
t = l.weight
16 changes: 8 additions & 8 deletions test/integration/test_integration.py
@@ -631,7 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -642,7 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -737,7 +737,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -748,7 +748,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -823,7 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -838,7 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -1028,7 +1028,7 @@ def test_save_load_int8woqtensors(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
@torch.no_grad()
def test_save_load_int4woqtensors(self, device, dtype):
if dtype != torch.bfloat16:
@@ -1488,7 +1488,7 @@ def test_get_model_size_autoquant(self, device, dtype):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_get_model_size_aqt(self, api, test_device, test_dtype):
if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not supported yet")
2 changes: 1 addition & 1 deletion test/quantization/test_quant_api.py
@@ -525,7 +525,7 @@ def test_quantized_tensor_subclass_8da4w(self):
self.assertTrue(torch.equal(res, ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
# @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
# use 1024 so that we don't need padding
9 changes: 8 additions & 1 deletion test/quantization/test_quant_primitives.py
@@ -29,6 +29,7 @@
from torchao.utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
is_fbcode,
)

@@ -99,6 +100,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.to(torch.int32)
.reshape_as(w)
)
if TORCH_VERSION_AFTER_2_5:
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)

return w_int4x8

@@ -500,7 +503,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
if TORCH_VERSION_AFTER_2_5:
input_uint8 = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input_uint8, scales, zeros, n_bit, groupsize)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
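
A quick round-trip check (editor's illustration, not part of the test suite; values are arbitrary) confirms that the nibble packing applied to input above is inverted by the split used elsewhere in this commit:

import torch

x = torch.randint(0, 16, (4, 8), dtype=torch.int32)
packed = (x[:, ::2] << 4 | x[:, 1::2]).to(torch.uint8)   # shape (4, 4)
unpacked = torch.zeros_like(x)
unpacked[:, ::2] = packed.to(torch.int32) >> 4           # high nibble -> even columns
unpacked[:, 1::2] = packed.to(torch.int32) & 0x0F        # low nibble -> odd columns
assert torch.equal(x, unpacked)
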
21 changes: 16 additions & 5 deletions test/test_ops.py
@@ -95,20 +95,24 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
assert K % (inner_k_tiles * kTileSizeK) == 0 and N % kTileSizeN == 0

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
if TORCH_VERSION_AFTER_2_5:
t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed_w, inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)
assert torch.equal(t, unpacked)

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
@@ -122,6 +126,8 @@ def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils.append("test_aot_dispatch_dynamic")

t = torch.randint(0, 16, dtype=torch.int, size=shape, device="cuda")
if TORCH_VERSION_AFTER_2_5:
t = (t[::, ::2] << 4 | t[::, 1::2]).to(torch.uint8)
packed_w = torch.ops.aten._convert_weight_to_int4pack(t, inner_k_tiles)

opcheck(
@@ -151,7 +157,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
@@ -210,7 +216,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):

# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
@@ -229,6 +235,9 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
unpacked = (unpacked[::, ::2] << 4 | unpacked[::, 1::2]).to(torch.uint8)

dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)
@@ -264,13 +273,15 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
assert diff_op_ao < 1e-1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
# @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
n, k = shape
device = "cuda"

q = torch.randint(0, 16, shape, dtype=torch.int, device=device)
if TORCH_VERSION_AFTER_2_5:
q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8)
packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles)
q_groups = k // group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
11 changes: 7 additions & 4 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -26,6 +26,7 @@
)
from typing import ClassVar
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

aten = torch.ops.aten

@@ -245,7 +246,6 @@ def from_float(

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -570,9 +570,12 @@ def from_plain(
layout_type: LayoutType
):
assert isinstance(layout_type, TensorCoreTiledLayoutType)
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), layout_type.inner_k_tiles)
if TORCH_VERSION_AFTER_2_5:
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
else:
assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
3 changes: 3 additions & 0 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,6 +12,7 @@
from hqq.core.utils import *

import torch.nn.functional as F
from torchao.utils import TORCH_VERSION_AFTER_2_5


class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,6 +199,8 @@ def hqq_quants_to_torch_quants(
.reshape(shape)
.contiguous()
)
if TORCH_VERSION_AFTER_2_5:
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)

# group_dequantize_tensor_from_qparams
# W_r = W_q*scales + min_val
23 changes: 17 additions & 6 deletions torchao/quantization/utils.py
@@ -17,6 +17,7 @@
dequantize_affine,
int_scaled_matmul,
)
from torchao.utils import TORCH_VERSION_AFTER_2_5

__all__ = [
"compute_error",
@@ -349,6 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_max = 2 ** n_bit - 1

int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
if TORCH_VERSION_AFTER_2_5:
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
return int_data

def groupwise_affine_dequantize_tensor_from_qparams(
@@ -359,18 +362,26 @@ def groupwise_affine_dequantize_tensor_from_qparams(
groupsize=128,
):
assert groupsize > 1
# needed for GPTQ single column dequantize
if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int4x8.shape[-1]
assert w_int4x8.shape[-1] % groupsize == 0
assert w_int4x8.dim() == 2
if TORCH_VERSION_AFTER_2_5:
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F
w_int32 = torch.zeros((w_int4x8.shape[0], w_int4x8.shape[1] * 2), dtype=torch.int32, device=w_int4x8.device)
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
w_int32 = w_int4x8

# needed for GPTQ single column dequantize
if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
groupsize = w_int32.shape[-1]
assert w_int32.shape[-1] % groupsize == 0
block_size = (1, groupsize)
input_dtype = torch.int32
quant_min = 0
quant_max = 2**n_bit - 1
return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

return dequantize_affine(w_int32, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype)

def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
scales, zeros = get_groupwise_affine_qparams(w, n_bit, groupsize, dtype)
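
For concreteness, a single byte behaves like this under the high/low-nibble split used in groupwise_affine_dequantize_tensor_from_qparams above (editor's worked example, values chosen arbitrarily):

b = (3 << 4) | 12      # pack 3 (high nibble) and 12 (low nibble) -> 0x3C == 60
assert b >> 4 == 3     # high nibble recovers the even-column value
assert b & 0x0F == 12  # low nibble recovers the odd-column value
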
