Commit: Fix CI
yanbing-j committed Jul 19, 2024
1 parent 0b0a3a8 commit 49b47a2
Showing 4 changed files with 15 additions and 5 deletions.
test/quantization/test_quant_primitives.py (3 additions, 1 deletion)

@@ -28,6 +28,7 @@
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
     TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
     is_fbcode,
 )

@@ -98,7 +99,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
         .to(torch.int32)
         .reshape_as(w)
     )
-    w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8

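The gated line above packs each pair of 4-bit values along the last dimension into a single byte, and the same pattern recurs in the three files below. A minimal standalone sketch of what it computes (not part of the commit; shape and values are invented for illustration):

import torch

w_int4 = torch.tensor([[1, 2, 3, 4]], dtype=torch.int32)  # values in [0, 15]
# even columns become the high nibble, odd columns the low nibble
packed = (w_int4[::, ::2] << 4 | w_int4[::, 1::2]).to(torch.uint8)
print(packed)        # tensor([[18, 52]], dtype=torch.uint8), i.e. 0x12, 0x34
print(packed.shape)  # torch.Size([1, 2]): half the original width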
torchao/dtypes/affine_quantized_tensor.py (6 additions, 2 deletions)

@@ -24,6 +24,7 @@
 )
 from typing import ClassVar
 from dataclasses import dataclass
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 aten = torch.ops.aten

@@ -500,8 +501,11 @@ def from_plain(
         layout_type: LayoutType
     ):
         assert isinstance(layout_type, TensorCoreTiledLayoutType)
-        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
-        assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
+        if TORCH_VERSION_AFTER_2_5:
+            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
+        else:
+            assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
         packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
         scale = scale.reshape(int_data.shape[0], -1)
         zero_point = zero_point.reshape(int_data.shape[0], -1)
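Taken together, the from_plain change makes the input contract of torch.ops.aten._convert_weight_to_int4pack version-dependent: uint8 with two nibbles per byte on torch 2.5+, plain int32 on torch 2.4. A hedged sketch of a caller following that contract (pack_int4_weight is a hypothetical helper, not from the commit; int_data and inner_k_tiles are placeholders):

import torch
from torchao.utils import TORCH_VERSION_AFTER_2_5

def pack_int4_weight(int_data: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    if TORCH_VERSION_AFTER_2_5:
        # torch >= 2.5: the op consumes uint8, two 4-bit values per byte
        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
    # torch 2.4: the op consumes int32, one 4-bit value per element
    return torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)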
torchao/prototype/hqq/hqq_tinygemm_linear.py (3 additions, 1 deletion)

@@ -12,6 +12,7 @@
 from hqq.core.utils import *
 
 import torch.nn.functional as F
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,7 +199,8 @@ def hqq_quants_to_torch_quants(
             .reshape(shape)
             .contiguous()
         )
-        W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
+        if TORCH_VERSION_AFTER_2_5:
+            W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
         # W_r = W_q*scales + min_val
torchao/quantization/utils.py (3 additions, 1 deletion)

@@ -17,6 +17,7 @@
     dequantize_affine,
     int_scaled_matmul,
 )
+from torchao.utils import TORCH_VERSION_AFTER_2_5
 
 __all__ = [
     "compute_error",
@@ -349,7 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
     quant_max = 2 ** n_bit - 1
 
     int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
-    int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+    if TORCH_VERSION_AFTER_2_5:
+        int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
 
 def groupwise_affine_dequantize_tensor_from_qparams(
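For completeness, the inverse of the packing, an assumption sketched for illustration rather than code from the commit: each byte splits back into its high and low nibble, restoring the original width.

import torch

packed = torch.tensor([[18, 52]], dtype=torch.uint8)  # 0x12, 0x34 from the sketch above
high = (packed >> 4).to(torch.int32)   # even columns (high nibble)
low = (packed & 0xF).to(torch.int32)   # odd columns (low nibble)
unpacked = torch.stack([high, low], dim=-1).reshape(packed.shape[0], -1)
print(unpacked)  # tensor([[1, 2, 3, 4]], dtype=torch.int32)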
