Add layout option to woq int4 api (#670)
* feat: add layout option to woq int4 api

* chore: update tests

* chore: move imports to top of the file
Diogo-V authored Aug 14, 2024
1 parent 174e630 commit 009f55f
Showing 2 changed files with 16 additions and 30 deletions.
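In short, the tiling knob that used to be passed directly as inner_k_tiles is now carried by a layout object. Below is a minimal sketch of the call-site change, using only names that appear in this diff; the group_size value is illustrative:

from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import int4_weight_only

# Old API (before this commit):
#   int4_weight_only(group_size=128, inner_k_tiles=8)
# New API: the tiling factor lives on the layout object. The default is
# TensorCoreTiledLayoutType(inner_k_tiles=8), so int4_weight_only(group_size=128)
# keeps its previous behavior.
apply_int4 = int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8))
# The returned inserter is then applied with quantize_(model, apply_int4),
# as the updated test below does.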
9 changes: 6 additions & 3 deletions test/integration/test_integration.py
@@ -19,6 +19,7 @@
 from torchao.quantization.dynamic_quant import (
     DynamicallyPerAxisQuantizedLinear,
 )
+from torchao.dtypes import TensorCoreTiledLayoutType
 from torchao.quantization.quant_api import (
     int4_weight_only,
     int8_weight_only,
@@ -852,18 +853,20 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         for test_shape in ([(256, 256, 16)] + ([(256, 256, 8)] if device=='cuda' else [])):
             for groupsize in [64, 32]:
                 for inner_k_tiles in [4, 2]:
-                    kwargs = {"groupsize": groupsize, "inner_k_tiles": inner_k_tiles}
+                    kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}

                     def api(mod):
+                        kwargs_copy = kwargs.copy()
                         if TORCH_VERSION_AFTER_2_4:
-                            kwargs_copy = kwargs.copy()
                             kwargs_copy["group_size"] = groupsize
                             del kwargs_copy["groupsize"]
                             quantize_(mod, int4_weight_only(**kwargs_copy))
                             if not TORCH_VERSION_AFTER_2_5:
                                 unwrap_tensor_subclass(mod)
                         else:
-                            change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+                            kwargs_copy["inner_k_tiles"] = inner_k_tiles
+                            del kwargs_copy["layout_type"]
+                            change_linear_weights_to_int4_woqtensors(mod, **kwargs_copy)

                     self._test_lin_weight_subclass_api_impl(
                         api,
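Because the hunk above interleaves removed and added lines, the kwargs bookkeeping is easy to misread. The sketch below restates just that bookkeeping with concrete values from the test's parameter grid; it is a reading aid, not code from the commit, and mirrors what the api helper now does on each code path:

from torchao.dtypes import TensorCoreTiledLayoutType

groupsize, inner_k_tiles = 64, 4  # one point of the test's parameter grid
kwargs = {"groupsize": groupsize, "layout_type": TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)}

# torch >= 2.4 path: rename groupsize -> group_size and keep the layout object,
# then call quantize_(mod, int4_weight_only(**new_kwargs)).
new_kwargs = dict(kwargs)
new_kwargs["group_size"] = new_kwargs.pop("groupsize")

# Legacy path: change_linear_weights_to_int4_woqtensors predates layout objects,
# so the tiling factor goes back in as inner_k_tiles and layout_type is dropped.
legacy_kwargs = dict(kwargs)
legacy_kwargs["inner_k_tiles"] = inner_k_tiles
del legacy_kwargs["layout_type"]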
37 changes: 10 additions & 27 deletions torchao/quantization/quant_api.py
@@ -21,7 +21,14 @@
 import torch.nn.functional as F
 from typing import Any, Callable, Union, Dict, Optional

-from torchao.dtypes import PlainLayoutType
+from torchao.dtypes.uintx.Uintx import UintxLayoutType
+from torchao.dtypes import (
+    to_affine_quantized,
+    TensorCoreTiledLayoutType,
+    PlainLayoutType,
+    AffineQuantizedTensor,
+    SemiSparseLayoutType
+)
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_4,
     unwrap_tensor_subclass,
@@ -182,9 +189,6 @@ def _replace_with_custom_fn_if_matches_filter(


 def _is_linear(mod, *args):
-    # avoid circular dep
-    from torchao.dtypes import AffineQuantizedTensor
-
     # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
     # when it is shared by multiple linear modules
     return (
@@ -328,9 +332,6 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
     )

 def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int8
     return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -339,9 +340,6 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     if weight.shape[-1] % group_size != 0:
         return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
-
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
@@ -373,7 +371,7 @@ def insert_subclass(lin):
     return insert_subclass


-def int4_weight_only(group_size=128, inner_k_tiles=8):
+def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -389,16 +387,12 @@ def int4_weight_only(group_size=128, inner_k_tiles=8):
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller
          size is more fine grained, choices are [256, 128, 64, 32]
-        `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
+        `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
     """
     def apply_int4_weight_only_quant(weight):
         if weight.shape[-1] % group_size != 0:
             return weight

-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-        from torchao.dtypes import TensorCoreTiledLayoutType
-
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
         target_dtype = torch.int32
@@ -408,7 +402,6 @@ def apply_int4_weight_only_quant(weight):
         preserve_zero = False
         zero_point_dtype = torch.bfloat16
         zero_point_domain = ZeroPointDomain.FLOAT
-        layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

     return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
@@ -419,9 +412,6 @@ def int8_weight_only():
     Applies int8 weight-only symmetric per-channel quantization to linear layers.
     """
     def apply_int8wo_quant(weight):
-        # avoid circular dep
-        from torchao.dtypes import to_affine_quantized
-
         mapping_type = MappingType.SYMMETRIC
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
@@ -432,8 +422,6 @@ def apply_int8wo_quant(weight):
     return _get_linear_subclass_inserter(apply_int8wo_quant)

 def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
     eps = 1e-5
@@ -453,8 +441,6 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
     if in_features <= 16:
         return weight

-    # avoid circular dep
-    from torchao.dtypes import to_affine_quantized
     # weight settings
     mapping_type = MappingType.SYMMETRIC
     def get_weight_block_size(x):
@@ -479,7 +465,6 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
     quantization + 2:4 sparsity to linear layers.
     """
-    from torchao.dtypes import SemiSparseLayoutType
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


@@ -495,8 +480,6 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
         quantize_affine,
         dequantize_affine,
     )
-    from torchao.dtypes.uintx.Uintx import UintxLayoutType
-    from torchao.dtypes import to_affine_quantized
     from torchao.quantization.quant_api import _get_linear_subclass_inserter
     def apply_uintx_weight_only_quant(weight):

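Taken together, the quant_api.py changes hoist the previously function-local imports to module scope and make layout selection part of the public signature. A short end-to-end sketch under the assumption of a CUDA device with bfloat16 weights, which is the setup the tinygemm int4 path targets; the toy model and sizes are illustrative and not taken from this commit:

import torch
from torchao.dtypes import TensorCoreTiledLayoutType
from torchao.quantization.quant_api import int4_weight_only, quantize_

# Illustrative model; int4 weight-only quantization targets nn.Linear weights.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")

# Default layout, equivalent to layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8):
quantize_(model, int4_weight_only(group_size=128))

# A non-default tiling is now spelled on the layout object rather than as inner_k_tiles=4:
# quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=4)))

# Other wrappers in this file follow the same convention, e.g. the semi-sparse helper above
# returns int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType()).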
