Skip to content

Commit

Permalink
Replace implementation for int8 dynamic quantization with call to `qu…
Browse files Browse the repository at this point in the history
…antize`

Summary:
Previously we added `quantize` as a general API (pytorch#256) for
Affine Quantized tensor subclass, and also tensor subclass based dtype conversion in general.

The plan is to use this to replace existing quant APIs including int4 weight only, int8 weight only, int8 dynamic quant
and 8da4w (for executorch).

This PR we started replacing the implementation of int8 dynamic quant API with `quantize` API with affine quantized tensor
subclass. We'll make sure the performance does not regress for vit model.

Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py

reference: elapsed_time:  1.4821058654785155  milliseconds
after refactor: elapsed_time:  1.4804757690429688  milliseconds

generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 30, 2024
1 parent 374fec4 commit 22ca192
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 44 deletions.
2 changes: 1 addition & 1 deletion test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def test_quantized_tensor_subclass_int8(self):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_quantized_tensor_subclass_int8_dyn_quant(self):
# use 1024 so that we don't need padding
m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
m_copy = copy.deepcopy(m)
# setting batch_size to 20 to be compatible with the kernel
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
Expand Down
156 changes: 125 additions & 31 deletions torchao/dtypes/aqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
fn(self.zero_point),
)

def _change_shape(self, shape):
return self.__class__(
self.int_data.view(shape), self.scale, self.zero_point
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
Expand Down Expand Up @@ -245,6 +250,7 @@ def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
# TODO: fix the unflatten logic
return cls(packed_weight, scale_and_zero)

def to(self, *args, **kwargs):
Expand Down Expand Up @@ -282,6 +288,74 @@ def get_plain(self):
f"Unpacking for tensor core tiled storage is not yet implemented"
)

@register_aqt_layout_cls("transposed")
class TransposedAQTLayout(PlainAQTLayout):
"""
Layout storage class for transposed layout for affine quantized tensor, it's the same as
plain layout but stores transposed int_data.
fields:
int_data (torch.Tensor): the transposed quantized integer data Tensor
scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
"""
def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
):
self.int_data = int_data.contiguous().t()
self.scale = scale
self.zero_point = zero_point

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], []

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
return cls(int_data.t(), scale, zero_point)

def _change_shape(self, shape):
return self.__class__(
self.int_data.t().view(shape), self.scale, self.zero_point
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.view.default:
assert len(args) == 2
new = args[0]._change_shape(args[1])
return return_and_correct_aliasing(
func, args, kwargs, new
)

raise NotImplementedError(
f"TransposedAQTLayout dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self):
return self.int_data.t(), self.scale, self.zero_point

class AffineQuantizedTensor(torch.Tensor):
"""
Base affine quantized tensor subclass. When the from_float method is used,
Expand Down Expand Up @@ -356,7 +430,7 @@ def __init__(

def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
f"{self.__class__.__name__}(data={self.dequantize(self.dtype)}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
)

Expand Down Expand Up @@ -470,6 +544,11 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

def _change_shape(self, shape, block_size):
return self.__class__(
self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
Expand All @@ -491,13 +570,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
Expand All @@ -516,9 +589,14 @@ def functional_linear(*args, **kwargs):
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected and
input_tensor.dtype == weight_qtensor.dtype and
input_tensor.layout == "plain" and
weight_qtensor.layout == "plain"
weight_qtensor.layout == "transposed"
):
assert input_tensor.shape[-1] == weight_qtensor.layout_tensor.int_data.shape[0], (
f"need mat1 shape: {input_tensor.shape} final "
f"dim to match mat2 shape: {weight_qtensor.layout_tensor.int_data.shape} first dim "
)
#
# 1. do the matrix form of dot(X_i, W_j)
#
Expand All @@ -532,7 +610,7 @@ def functional_linear(*args, **kwargs):

x_vals_int8 = input_tensor.layout_tensor.int_data
x_scales = input_tensor.layout_tensor.scale
w_vals_int8_t = weight_qtensor.layout_tensor.int_data.contiguous().t()
w_vals_int8_t = weight_qtensor.layout_tensor.int_data
w_scales = weight_qtensor.layout_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
Expand Down Expand Up @@ -579,42 +657,58 @@ def functional_linear(*args, **kwargs):
# TODO: enable mps path as well
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
else:
raise NotImplementedError("No specialized dispatch found for quantized linear op")


@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
try:
return _quantized_linear_op(input_tensor, weight_qtensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


# TODO: add dispatch to efficient kernels
@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
args[1],
args[2],
args[0],
)
try:
return _quantized_linear_op(input_tensor, weight_qtensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_qtensor.dequantize()
return func(bias, input_tensor, weight_tensor)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
None
)
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)
try:
return _quantized_linear_op(input_tensor, weight_qtensor, bias)
except:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_qtensor.dequantize()
return func(bias, input_tensor, weight_tensor)

@implements_aqt_aten_ops([aten.detach.default])
def detach(func, *args, **kwargs):
Expand All @@ -641,10 +735,10 @@ def _to_copy(func, *args, **kwargs):

@implements_aqt_aten_ops([aten.t.default])
def t(func, *args, **kwargs):
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
raise Exception("transpose not implemented yet")
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
return return_and_correct_aliasing(func, args, kwargs, new)

to_aq = AffineQuantizedTensor.from_float
11 changes: 7 additions & 4 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,12 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
*args
)

_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
)
if TORCH_VERSION_AFTER_2_4:
quantize(model, get_apply_int8dyn_quant(), filter_fn)
else:
_replace_with_custom_fn_if_matches_filter(
model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
)


def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
Expand Down Expand Up @@ -393,7 +396,7 @@ def get_per_token_block_size(x):
input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)

block_size = get_weight_block_size(weight)
weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, extended_layout="transposed")
weight = to_laq(weight, input_quant_func)
return weight
return apply_int8dyn_quant
53 changes: 46 additions & 7 deletions torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ def __new__(
dtype = original_weight_tensor.dtype
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["device"] = original_weight_tensor.device
shape = original_weight_tensor.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

Expand Down Expand Up @@ -664,6 +665,27 @@ def _apply_fn_to_data(self, fn):
self.input_quant_func,
)

def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.original_weight_tensor.to(**kwargs),
self.input_quant_func,
)

def __torch_dispatch__(cls, func, types, args, kwargs):
if (
func in [aten.mm.default, aten.addmm.default]
Expand All @@ -674,25 +696,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor, bias = (
args[1],
args[2],
args[0],
)
aqt = self.input_quant_func(input_tensor)
return func(bias, aqt, weight_tensor)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return func(bias, aqt, original_weight_tensor)
else:
# aten.mm.default
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
input_tensor, weight_tensor = (
args[0],
args[1],
None if len(args) == 2 else args[2],
)
aqt = self.input_quant_func(input_tensor)
return func(aqt, weight_tensor, bias)
input_quant_func = weight_tensor.input_quant_func
original_weight_tensor = weight_tensor.original_weight_tensor
aqt = input_quant_func(input_tensor)
return func(aqt, original_weight_tensor)

if func is aten.detach.default:
return return_and_correct_aliasing(
Expand All @@ -704,6 +730,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten._to_copy.default:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

if func is aten.t.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
)

raise NotImplementedError(
f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)
Expand Down
7 changes: 6 additions & 1 deletion tutorials/quantize_vit/run_vit_b_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
inductorconfig.force_fuse_int_mm_with_mul = True
## Quantization code - end

model = torch.compile(model, mode='max-autotune')
## workaround for tensor subclass
from torchao.quantization.utils import unwrap_tensor_subclass
model = unwrap_tensor_subclass(model)
## workaround for tensor subclass end

model = torch.compile(model, mode='max-autotune', fullgraph=True)

# Must run with no_grad when optimizing for inference
with torch.no_grad():
Expand Down

0 comments on commit 22ca192

Please sign in to comment.