Fix affine quantized tensor to device calls #726

Merged · 4 commits · Aug 22, 2024
17 changes: 17 additions & 0 deletions test/dtypes/test_affine_quantized.py
@@ -51,6 +51,23 @@ def test_weights_only(self):
                else:
                    _ = torch.load(f, weights_only=False)

    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_to_device(self):
        from torchao.quantization import (
            int8_weight_only,
            int8_dynamic_activation_int4_weight,
            int8_dynamic_activation_int8_weight,
        )
        # each quantized linear should support all three ways of moving to a device
        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            ql = apply_quant(l)
            ql.to("cuda")

            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            ql = apply_quant(l)
            ql.to(device="cuda")

            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
            ql = apply_quant(l)
            ql.cuda()



if __name__ == "__main__":
    run_tests()
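To illustrate the behavior this test locks in at the module level, here is a usage sketch (not part of the PR) that quantizes a toy model with torchao's `quantize_` API and then moves it to CUDA; the layer shapes are arbitrary.

```python
# Usage sketch (not part of the PR): the end-to-end flow the new test exercises.
import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(128, 256, dtype=torch.bfloat16))
quantize_(model, int8_weight_only())  # weights become affine quantized tensor subclasses

# All three spellings route through the `to`/`cuda` handling fixed in this PR:
model.to("cuda")
model.to(device="cuda")
model.cuda()
```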
58 changes: 21 additions & 37 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -16,7 +16,6 @@
    pack_tinygemm_scales_and_zeros,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import find_multiple
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
@@ -29,14 +28,18 @@
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import (
    find_multiple,
    TorchAOBaseTensor,
    TORCH_VERSION_AT_LEAST_2_5,
)

aten = torch.ops.aten

###############################
# Base Layout Tensor Subclass #
###############################
class AQTLayout(torch.Tensor):
class AQTLayout(TorchAOBaseTensor):
"""
Base class for the layout tensor for `AffineQuantizedTensor`
"""
Expand All @@ -61,19 +64,6 @@ def __repr__(self):
        layout_type = self.get_layout_type()
        return f"{self.__class__.__name__}(int_data={int_data}, scale={scale}, zero_point={zero_point}, layout_type={layout_type})"

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        memory_format = (
            memory_format if memory_format is not None else torch.preserve_format
        )
        kwargs = {
            "device": device,
            "dtype": dtype,
            "memory_format": memory_format,
        }
        return kwargs

##############################
# Tensor Subclass Definition #
@@ -83,7 +73,7 @@ def _get_to_kwargs(self, *args, **kwargs):
def _register_quantized_linear_dispatch(dispatch_condition, impl):
    _QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl

class AffineQuantizedTensor(torch.Tensor):
class AffineQuantizedTensor(TorchAOBaseTensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
quantized_tensor = float_tensor / scale + zero_point
Expand Down Expand Up @@ -223,7 +213,7 @@ def from_float(
        input_float = layout_type.pre_process(input_float)
        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)

        int_data = layout_type.post_process(int_data)
        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
@@ -273,25 +263,9 @@ def from_float_static(
    def layout_type(self) -> LayoutType:
        return self.layout_tensor.layout_type

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        memory_format = (
            memory_format if memory_format is not None else torch.preserve_format
        )
        kwargs = {
            "device": device,
            "dtype": dtype,
            "memory_format": memory_format,
        }
        return kwargs

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs.pop("device")
        # not supported yet
        kwargs.pop("memory_format")
        return self.__class__(
            self.layout_tensor.to(device),
            self.block_size,
@@ -446,6 +420,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.t.default:
            tensor = args[0]
            new = tensor.__class__(
@@ -576,10 +555,10 @@ def from_plain(
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        layout_type: LayoutType
    ):
    ):

        assert isinstance(layout_type, TensorCoreTiledLayoutType)

        if TORCH_VERSION_AT_LEAST_2_5:
            int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
            assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
@@ -617,6 +596,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.t.default:
            """we don't need to repack the weight and just rely on external
            shape being changed and record the status of transpose/no-transpose
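For context on the `AffineQuantizedTensor` docstring above, the sketch below spells out the per-tensor affine mapping it describes; `affine_quantize`/`affine_dequantize` are hypothetical helpers for illustration only, not torchao's `quantize_affine`/`choose_qparams_affine`, which additionally handle block sizes, zero-point domains, and layout packing.

```python
# Illustrative per-tensor affine quantization, matching the formula in the
# AffineQuantizedTensor docstring (hypothetical helpers, not torchao APIs).
import torch

def affine_quantize(x, scale, zero_point, quant_min=-128, quant_max=127):
    # quantized_tensor = round(float_tensor / scale + zero_point), clamped to range
    return torch.clamp(torch.round(x / scale + zero_point), quant_min, quant_max).to(torch.int8)

def affine_dequantize(q, scale, zero_point):
    # float_tensor = (quantized_tensor - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale

x = torch.randn(4, 4)
scale = x.abs().max() / 127.0
q = affine_quantize(x, scale, torch.tensor(0))
x_hat = affine_dequantize(q, scale, torch.tensor(0))  # approximate reconstruction
```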
21 changes: 5 additions & 16 deletions torchao/quantization/linear_activation_quantized_tensor.py
@@ -6,7 +6,10 @@
)
from typing import Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.utils import (
    TorchAOBaseTensor,
    TORCH_VERSION_AT_LEAST_2_5,
)

__all__ = [
    "LinearActivationQuantizedTensor",
@@ -15,7 +18,7 @@

aten = torch.ops.aten

class LinearActivationQuantizedTensor(torch.Tensor):
class LinearActivationQuantizedTensor(TorchAOBaseTensor):
"""
Applies activation quantization for linear operator
"""
@@ -74,20 +77,6 @@ def _apply_fn_to_data(self, fn):
            self.input_quant_func,
        )

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        memory_format = (
            memory_format if memory_format is not None else torch.preserve_format
        )
        kwargs = {
            "device": device,
            "dtype": dtype,
            "memory_format": memory_format,
        }
        return kwargs

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
25 changes: 25 additions & 0 deletions torchao/utils.py
@@ -20,6 +20,7 @@
"_register_custom_op",
"get_model_size_in_bytes",
"unwrap_tensor_subclass",
"TorchAOBaseTensor",
"TORCH_VERSION_AT_LEAST_2_2",
"TORCH_VERSION_AT_LEAST_2_3",
"TORCH_VERSION_AT_LEAST_2_4",
@@ -281,6 +282,30 @@ def unwrap_tensor_subclass(model, filter_fn=None):
        unwrap_tensor_subclass(child)
    return model

class TorchAOBaseTensor(torch.Tensor):
    """A util tensor subclass that provides commonly used functions
    """
    def _get_to_kwargs(self, *args, **kwargs):
        # `torch._C._nn._parse_to` can't handle the `layout` argument, so drop
        # it before parsing (`args` is a tuple, so build a filtered copy
        # instead of mutating it while iterating)
        args = [arg for arg in args if not isinstance(arg, torch.layout)]
        if "layout" in kwargs:
            kwargs.pop("layout")
        # ignoring `non_blocking` and `memory_format` args since these are not
        # very useful for most of the tensor subclasses
        # if in the future there are use cases that need these, we'd recommend
        # to override `_get_to_kwargs` and return these args
        device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        kwargs = {
            "device": device,
            "dtype": dtype,
        }
        return kwargs



def parse_version(version_string):
    # Extract just the X.Y.Z part from the version string
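Finally, a hypothetical sketch of how a `TorchAOBaseTensor` subclass is expected to use `_get_to_kwargs`, mirroring the `to()` overrides of `AffineQuantizedTensor` and `LinearActivationQuantizedTensor` in this PR; `MyPackedTensor` and its `packed_data` attribute are made up for illustration.

```python
# Hypothetical sketch (not torchao code): the `to()` pattern that
# TorchAOBaseTensor._get_to_kwargs supports.
import torch
from torchao.utils import TorchAOBaseTensor

class MyPackedTensor(TorchAOBaseTensor):
    """Illustrative wrapper only; the real subclasses also define __new__,
    __tensor_flatten__/__tensor_unflatten__ and __torch_dispatch__."""

    def to(self, *args, **kwargs):
        # `_get_to_kwargs` normalizes `t.to("cuda")`, `t.to(device="cuda")`,
        # `t.to(dtype=...)` etc. into a {"device": ..., "dtype": ...} dict,
        # dropping the `layout` argument that `torch._C._nn._parse_to` rejects.
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs.pop("device")
        # move the packed payload and rebuild the wrapper on the target device
        # (`packed_data` is a hypothetical attribute for this sketch)
        return self.__class__(self.packed_data.to(device))
```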