Add quantize #256

Merged (3 commits) on May 24, 2024
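This PR adds a `quantize` API to `torchao.quantization.quant_api` that swaps the weight of each `nn.Linear` module (or of any module selected by a custom `filter_fn`) for a quantized tensor subclass, and exposes `to_aqt` / `to_laqt` as shorthands for `AffineQuantizedTensor.from_float` and `LinearActQuantizedTensor.from_float`. A minimal sketch of the new call pattern, pieced together from the diff below (the toy model and the eps/zero-point settings are illustrative, not taken from this PR):

import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.subclass import to_aqt
from torchao.quantization.quant_primitives import MappingType

def apply_weight_quant(weight):
    # symmetric int8 weight-only quantization, one block per output channel
    block_size = (1, weight.shape[1])
    return to_aqt(weight, MappingType.SYMMETRIC, block_size, torch.int8, eps=1e-5, zero_point_dtype=torch.int64)  # illustrative settings

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()
m = quantize(m, apply_weight_quant)  # each linear weight becomes an AffineQuantizedTensor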
115 changes: 62 additions & 53 deletions test/quantization/test_quant_api.py
@@ -18,12 +18,24 @@
get_symmetric_quantization_config,
)

+from torchao.quantization.subclass import (
+    to_aqt,
+    to_laqt,
+    AffineQuantizedTensor,
+    LinearActQuantizedTensor,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    ZeroPointDomain,
+)

from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
apply_dynamic_quant,
apply_weight_only_int8_quant,
Quantizer,
TwoStepQuantizer,
+    quantize,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
@@ -32,6 +44,7 @@
from pathlib import Path
from sentencepiece import SentencePieceProcessor
from model import Transformer, prepare_inputs_for_model
+import copy


def dynamic_quant(model, example_inputs):
@@ -92,8 +105,8 @@ def __init__(self, m=64, n=32, k=64):
self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

-    def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1):
+        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
@@ -395,13 +408,6 @@ def test_eval_wrapper(self):
# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
-        from torchao.quantization.subclass import (
-            AffineQuantizedTensor,
-            LinearActQuantizedTensor,
-        )
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
# weight settings
groupsize = 32
mapping_type = MappingType.SYMMETRIC
@@ -423,20 +429,26 @@ def get_per_token_block_size(x):
# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)

+        def apply_weight_quant(weight):
+            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+
+        def apply_act_quant(weight):
+            return to_laqt(weight, input_quant_func)
+
+        # note: order is important
+        m = quantize(m, apply_weight_quant)
+        m = quantize(m, apply_act_quant)

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -454,11 +466,6 @@ def dynamic_quant(linear):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int4(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-        import copy
-
# weight settings
groupsize = 32
mapping_type = MappingType.ASYMMETRIC
@@ -469,22 +476,17 @@ def test_quantized_tensor_subclass_int4(self):
eps = 1e-6
preserve_zero = False
zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

-        def to_quantized(weight):
-            return AffineQuantizedTensor.from_float(
-                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=preserve_zero,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-            )
+        def apply_weight_quant(weight):
+            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)

-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        m = quantize(m, apply_weight_quant)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -501,10 +503,6 @@ def to_quantized(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
# weight settings
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
@@ -515,12 +513,12 @@ def test_quantized_tensor_subclass_int8(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

-        def to_quantized(weight):
+        def apply_weight_quant(weight):
            block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

+        m = quantize(m, apply_weight_quant)

-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -537,12 +535,6 @@ def to_quantized(weight):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.subclass import LinearActQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-        import copy
-
# weight settings
mapping_type = MappingType.SYMMETRIC
def get_weight_block_size(x):
@@ -563,20 +555,24 @@ def get_per_token_block_size(x):
input_eps = 1e-5
input_quant_min = -127
input_quant_max = 127
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
+        # setting batch_size to 20 to be compatible with the kernel
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))

+        def apply_weight_quant(weight):
+            block_size = get_weight_block_size(weight)
+            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        def apply_act_quant(weight):
+            return to_laqt(weight, input_quant_func)

+        m = quantize(m, apply_weight_quant)
+        m = quantize(m, apply_act_quant)

-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +587,19 @@ def dynamic_quant(linear):

self.assertTrue(torch.equal(res, ref))

+        # workaround for export path
+        from torchao.quantization.utils import unwrap_tensor_subclass
+        m_unwrapped = unwrap_tensor_subclass(m)
+
+        m = torch.export.export(m_unwrapped, example_inputs).module()
+        exported_model_res = m(*example_inputs)
+
+        self.assertTrue(torch.equal(exported_model_res, ref))
+
+        # make sure it compiles
+        torch._export.aot_compile(m_unwrapped, example_inputs)



if __name__ == "__main__":
unittest.main()
50 changes: 49 additions & 1 deletion torchao/quantization/quant_api.py
@@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+from typing import Any, Callable

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -48,7 +49,8 @@
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"autoquant"
"quantize",
"autoquant",
]

if TORCH_VERSION_AFTER_2_3:
@@ -215,3 +217,49 @@ def replace_conv2d_1x1(conv):
_replace_with_custom_fn_if_matches_filter(
model, replace_conv2d_1x1, filter_fn=filter_fn
)


+def _get_linear_subclass_inserter(constructor):
+    def insert_subclass(lin):
+        lin.weight = torch.nn.Parameter(constructor(lin.weight), requires_grad=False)
+        return lin
+
+    return insert_subclass
+
+def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn=None) -> torch.nn.Module:
+    """Convert the weights of linear modules in the model with `apply_tensor_subclass`
+
+    Args:
+        model: input model
+        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that converts a floating point Tensor to a (quantized) tensor subclass instance
+        filter_fn: used to select the modules that the tensor subclass is applied to
+
+    Example::
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.ASYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+
+        apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+
+        # apply to modules under block0 submodule
+        def filter_fn(module, fqn):
+            return fqn == "block0"
+
+        m = MyModel(...)
+        m = quantize(m, apply_weight_quant, filter_fn)
+    """
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        _get_linear_subclass_inserter(apply_tensor_subclass),
+        _is_linear if filter_fn is None else filter_fn,
+    )
+    return model
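The tests above apply dynamic activation + weight quantization by calling `quantize` twice: the first pass installs the weight subclass, the second wraps it in `LinearActQuantizedTensor` so activations are quantized on the fly (the wrapper captures whatever tensor is currently stored as the weight, hence the "order is important" note in the tests). A condensed sketch of that two-pass pattern; the per-token block-size helper, the toy model, and the numeric settings mirror the test code but are illustrative:

import torch
from torchao.quantization.quant_api import quantize
from torchao.quantization.subclass import to_aqt, to_laqt
from torchao.quantization.quant_primitives import MappingType

def get_per_token_block_size(x):
    # one quantization group per token: every dim is 1 except the last
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

# dynamic per-token asymmetric int8 activation quantization
input_quant_func = lambda x: to_aqt(x, MappingType.ASYMMETRIC, get_per_token_block_size(x), torch.int8)

def apply_weight_quant(weight):
    # per-channel symmetric int8 weights (illustrative settings)
    return to_aqt(weight, MappingType.SYMMETRIC, (1, weight.shape[1]), torch.int8, eps=1e-5)

def apply_act_quant(weight):
    return to_laqt(weight, input_quant_func)

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()
m = quantize(m, apply_weight_quant)  # order matters: install the weight subclass first
m = quantize(m, apply_act_quant)     # then wrap it for dynamic activation quantization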
35 changes: 6 additions & 29 deletions torchao/quantization/subclass.py
@@ -35,6 +35,7 @@
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"AffineQuantizedTensor",
"LinearActQuantizedTensor",
]


@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
return super().__new__(cls, int_data, transposed, shape, **kwargs) # type: ignore[attr-defined]

    def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
-
self.q_scales = q_scales
super().__init__(int_data, transposed)

@@ -629,32 +629,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

-def to_aqt(
-    input_float,
-    mapping_type,
-    block_size,
-    target_dtype,
-    quant_min = None,
-    quant_max = None,
-    eps = None,
-    scale_dtype = None,
-    zero_point_dtype = None,
-    preserve_zero = True,
-    zero_point_domain = ZeroPointDomain.INT,
-):
-    return AffineQuantizedTensor.from_float(
-        input_float,
-        mapping_type,
-        block_size,
-        target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        eps=eps,
-        scale_dtype=scale_dtype,
-        zero_point_dtype=zero_point_dtype,
-        preserve_zero=preserve_zero,
-        zero_point_domain=zero_point_domain
-    )

# TODO: merge with nf4 implements decorator
# aten ops to their __torch_dispatch__ implementations for the tensor subclass
@@ -777,7 +751,7 @@ def dequantize(self, output_dtype=None):
return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

def __tensor_flatten__(self):
return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@classmethod
def __tensor_unflatten__(
@@ -1091,7 +1065,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
-        input_quant_func = tensor_attributes
+        input_quant_func, = tensor_attributes
return cls(
original_weight_tensor,
input_quant_func,
@@ -1176,3 +1150,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
raise NotImplementedError(
f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

+to_aqt = AffineQuantizedTensor.from_float
+to_laqt = LinearActQuantizedTensor.from_float
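These aliases also make one-off conversions terse. A small round-trip sketch (shapes and settings are arbitrary; `dequantize` is the method shown in the diff above, and quant_min/quant_max are left to be derived from the target dtype):

import torch

w = torch.randn(32, 64)
qw = to_aqt(w, MappingType.SYMMETRIC, (1, 64), torch.int8)  # one block per row
assert isinstance(qw, AffineQuantizedTensor)
w_hat = qw.dequantize(output_dtype=torch.float32)           # approximate reconstruction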