Skip to content

Commit

Permalink
Refactor int8 dynamic quantization with call to quantize (#294)
Browse files Browse the repository at this point in the history
Summary:
Previously we added `quantize` as a general API (#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored May 31, 2024
1 parent e7837d7 commit 68ce5b8
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 52 deletions.
1 change: 1 addition & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,7 @@ def _test_lin_weight_subclass_api_impl(


@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
change_linear_weights_to_int8_dqtensors, device, 35, test_dtype=dtype
Expand Down
62 changes: 60 additions & 2 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,26 @@ def forward(self, x):
x = self.linear2(x)
return x


def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
"""
The deprecated implementation for int8 dynamic quant API, used as a reference for
numerics and performance
"""
from torchao.quantization.quant_api import _in_features_greater_than_16
from torchao.quantization.quant_api import _is_linear
from torchao.quantization.quant_api import _get_subclass_inserter
from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

if filter_fn is None:
filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = ToyLinearModel().eval()
Expand Down Expand Up @@ -492,8 +512,8 @@ def test_quantized_tensor_subclass_int8(self):
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
# use multiples of 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
Expand Down Expand Up @@ -525,6 +545,44 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
# make sure it compiles
torch._export.aot_compile(m_unwrapped, example_inputs)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skip("This perf test is supposed to be run locally for sanity check performance when there is a change of int8 dynamic quant implementation")
def test_quantized_tensor_subclass_int8_dyn_quant_perf(self):
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m_ref = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")

from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
change_linear_weights_to_int8_dqtensors(m)

# reference
_ref_change_linear_weights_to_int8_dqtensors(m_ref)

res = m(*example_inputs)
ref = m_ref(*example_inputs)

self.assertTrue(torch.equal(res, ref))

# perf comparison
from torchao.utils import benchmark_model
# warmup
WARMUP = 5
RUNS = 100
input_tensor = example_inputs[0]
m = torch.compile(m, mode='max-autotune', fullgraph=True)

benchmark_model(m, WARMUP, input_tensor)
elapsed_time = benchmark_model(m, RUNS, input_tensor)

m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
benchmark_model(m_ref, WARMUP, input_tensor)
ref_elapsed_time = benchmark_model(m_ref, RUNS, input_tensor)

print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
self.assertTrue(elapsed_time < 1.05 * ref_elapsed_time)



if __name__ == "__main__":
Expand Down
118 changes: 81 additions & 37 deletions torchao/dtypes/aqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
fn(self.zero_point),
)

def _change_shape(self, shape):
return self.__class__(
self.int_data.view(shape), self.scale, self.zero_point
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
Expand All @@ -186,6 +191,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
return return_and_correct_aliasing(func, args, kwargs, new)

raise NotImplementedError(
f"PlainAQTLayout dispatch: attempting to run {func}, this is not supported"
)
Expand Down Expand Up @@ -245,6 +255,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
# TODO: fix the unflatten logic
return cls(packed_weight, scale_and_zero)

def to(self, *args, **kwargs):
Expand Down Expand Up @@ -470,6 +481,11 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

def _change_shape(self, shape, block_size):
return self.__class__(
self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
Expand All @@ -491,13 +507,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
Expand All @@ -508,14 +518,10 @@ def functional_linear(*args, **kwargs):
# if input tensor is quantized, either dispatch to the int8 mm kernel
# or just dequantize the input tensor
input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
input_tensor_dtype_is_expected = input_tensor.dtype in [
torch.float,
torch.bfloat16
]
if (
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected and
input_tensor.dtype == weight_qtensor.dtype and
input_tensor.layout == "plain" and
weight_qtensor.layout == "plain"
):
Expand Down Expand Up @@ -576,45 +582,83 @@ def functional_linear(*args, **kwargs):
weight_qtensor.block_size[1] == weight_qtensor.shape[1] and
weight_qtensor.layout == "plain"
):
# TODO: enable mps path as well
# TODO: enable cpu and mps efficient path
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
else:
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t().contiguous()
orig_dtype = input_tensor.dtype
y = (
torch.mm(
input_tensor.reshape(-1, input_tensor.shape[-1]),
w_vals_int8_t.to(input_tensor.dtype),
)
* weight_qtensor.scale
)
y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
if bias is not None:
y += bias
return y.to(orig_dtype)

# is_cpu and is_mps only, some issue with is_contiguous() currently
# return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)

raise NotImplementedError("No specialized dispatch found for quantized linear op")


@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

# using try/except here so that we can have a general fallback when input_tensor/weight_tensor
# is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to
# make the branches easier to understand in `_quantized_linear_op`
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor, bias = (
args[1],
args[2],
args[0],
)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return func(bias, input_tensor, weight_tensor)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
None
)
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)
try:
return _quantized_linear_op(input_tensor, weight_tensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return func(input_tensor, weight_tensor)

@implements_aqt_aten_ops([aten.detach.default])
def detach(func, *args, **kwargs):
Expand All @@ -641,10 +685,10 @@ def _to_copy(func, *args, **kwargs):

@implements_aqt_aten_ops([aten.t.default])
def t(func, *args, **kwargs):
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
raise Exception("transpose not implemented yet")
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
return return_and_correct_aliasing(func, args, kwargs, new)

to_aq = AffineQuantizedTensor.from_float
18 changes: 13 additions & 5 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
from typing import Any, Callable

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
from .utils import (
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
unwrap_tensor_subclass,
)

from .subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Expand Down Expand Up @@ -187,9 +191,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
)
if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int8dyn_quant(), filter_fn)
unwrap_tensor_subclass(model, filter_fn)
else:
_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
)


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
Expand Down Expand Up @@ -282,7 +290,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tens
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT
apply_weight_quant = lambda x: to_aqt(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
apply_weight_quant = lambda x: to_aq(x, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
# apply to modules under block0 submodule
def filter_fn(module, fqn):
Expand Down
Loading

0 comments on commit 68ce5b8

Please sign in to comment.