Skip to content

Fix slice and padding for TensorCoreTiledLayout #2015

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
float8_weight_only,
int4_dynamic_activation_int4_weight,
Expand All @@ -27,7 +28,7 @@
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.testing.utils import skip_if_rocm
from torchao.testing.utils import skip_if_no_cuda, skip_if_rocm
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
Expand Down Expand Up @@ -307,6 +308,19 @@ def test_alias(self, device, dtype):
quantize_(dummy, Int8DynamicActivationInt8WeightConfig())
_ = dummy.weight[...]

@common_utils.parametrize("device", ["cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16])
@skip_if_no_cuda()
def test_slice(self, device, dtype):
    """Slicing an int4 weight-only quantized weight must handle padding.

    Shapes are chosen so in_features (256) is not a multiple of 1024 and
    out_features (321) is not a multiple of 8, which forces the
    TensorCoreTiledLayout padding path during quantization.
    """
    linear = nn.Linear(256, 321, dtype=dtype, device=device)
    quantize_(linear, Int4WeightOnlyConfig())
    # Slicing along either dimension should run without raising.
    _ = linear.weight.narrow(0, 0, 64)
    _ = linear.weight.narrow(1, 0, 128)


common_utils.instantiate_parametrized_tests(TestAffineQuantized)
common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
Expand Down
10 changes: 7 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def from_hp_to_intx(
)
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
data, scale, zero_point = _layout.post_process(
data, scale, zero_point, block_size
)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
return cls(
Expand Down Expand Up @@ -335,7 +337,7 @@ def from_hp_to_intx_static(
zero_point_domain,
)

int_data = _layout.post_process(int_data)
int_data, scale, zero_point = _layout.post_process(int_data, scale, zero_point)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout)
Expand Down Expand Up @@ -429,7 +431,9 @@ def from_hp_to_fpx(
# Note: these ops are hardcoded to have per axis quantization (axis=1) right now
scale = choose_qparams_affine_floatx(input_float, ebits, mbits)
floatx_unpacked = quantize_affine_floatx(input_float, scale, ebits, mbits)
floatx_packed = _layout.post_process(floatx_unpacked)
floatx_packed, scale, _ = _layout.post_process(
floatx_unpacked, scale, None, block_size
)

tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout)
Expand Down
25 changes: 20 additions & 5 deletions torchao/dtypes/uintx/gemlite_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,14 +297,31 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
if func is aten.slice.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we update this class to use the @dispatch pattern

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, can I do this in a separate PR?

self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
assert step == 1, "Only step == 1 is supported in slicing right now"
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
param_dim = 1 - dim
scale_len = scale.shape[param_dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
int_data = self._layout.post_process(int_data)
scale = aten.slice.Tensor(
scale, param_dim, start_scale, end_scale, step
)
if zero_point is not None and zero_point.numel() > 0:
zero_point = aten.slice.Tensor(
zero_point, param_dim, start_scale, end_scale, step
)
else:
zero_point = None

sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
elif dim == 1:
int_data, scale, zero_point = self.get_plain()
assert step == 1, "Only step == 1 is supported in slicing right now"
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
# scale and zero_point are transposed compared to int_data
param_dim = 1 - dim
Expand All @@ -314,7 +331,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
scale = aten.slice.Tensor(
scale, param_dim, start_scale, end_scale, step
)
Expand All @@ -324,9 +340,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
)
else:
zero_point = None
# import fbvscode; fbvscode.set_trace()
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return sliced
return return_and_correct_aliasing(func, args, kwargs, sliced)
else:
raise NotImplementedError(
f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
Expand Down
31 changes: 19 additions & 12 deletions torchao/dtypes/uintx/int4_cpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,31 +192,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

if func is aten.slice.Tensor:
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for this

if dim == 0:
int_data, scale, zero_point = self.get_plain()
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
elif dim == 1:
int_data, scale, zero_point = self.get_plain()
if dim in [0, 1]:
assert step == 1, "Only step == 1 is supported in slicing right now"
int_data, scale, zero_point = self.get_plain()
data_len = int_data.shape[dim]
scale_len = scale.shape[dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
zero_point = aten.slice.Tensor(
zero_point, dim, start_scale, end_scale, step
)
# this is to handle padding
int_data, scale, zero_point = self._layout.post_process(
int_data, scale, zero_point, self.block_size
)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return sliced
return return_and_correct_aliasing(func, args, kwargs, sliced)
else:
raise NotImplementedError(
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
Expand All @@ -228,6 +223,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

@property
def block_size(self):
    """Reconstruct the quantization block size as ``(1, groupsize)``.

    Derives the original (unpacked) weight shape from the packed 4-D
    tensor-core-tiled shape, then divides in_features by the number of
    quantization groups (one scale row per group) to recover groupsize.
    """
    from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

    scale, _ = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
    cur_shape = self.shape
    assert len(cur_shape) == 4
    # NOTE(review): the factors 8 and inner_k_tiles * 16 come from the
    # tinygemm int4 packing layout — confirm against
    # torch._convert_weight_to_int4pack if the packing ever changes.
    inner_k_tiles = cur_shape[-1] * 2
    original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
    # scale.shape[-2] is the number of groups along in_features; the
    # division is exact, so use integer floor division directly.
    groupsize = original_shape[1] // scale.shape[-2]
    return (1, groupsize)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
Expand Down
1 change: 0 additions & 1 deletion torchao/dtypes/uintx/marlin_qqq_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def from_hp_to_intx(
data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
input_float, nbits, group_size
)
data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
return cls(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,6 @@ def from_hp_to_intx(
)
# Note: output will be uint8 tensor for sub byte tensors for now

data = _layout.post_process(data)
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(
data, scale, zero_point, _layout, **(tensor_impl_ctr_kwargs or {})
Expand Down
49 changes: 35 additions & 14 deletions torchao/dtypes/uintx/tensor_core_tiled_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,30 @@ def pre_process_static(
zero_point = torch.nn.functional.pad(zero_point, padding_changes)
return input, scale, zero_point

def post_process(self, input: torch.Tensor) -> torch.Tensor:
def post_process(
    self,
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: Optional[torch.Tensor],
    block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Pad quantized data and its qparams to tensor-core-tiled sizes.

    Args:
        input: 2-D quantized weight of shape (out_features, in_features).
        scale: per-group scales, padded in lockstep with ``input``.
        zero_point: per-group zero points, or None (some callers, e.g.
            the floatx path, have no zero point).
        block_size: 2-element quantization block size (rows, cols).

    Returns:
        Tuple of (padded input, padded scale, padded zero_point or None).
    """
    orig_out_features, orig_in_features = input.shape
    # The tinygemm kernel requires in_features % 1024 == 0 and
    # out_features % 8 == 0, so pad up to the next multiples.
    in_features = find_multiple(orig_in_features, 1024)
    out_features = find_multiple(orig_out_features, 8)
    input = torch.nn.functional.pad(
        input,
        (0, in_features - orig_in_features, 0, out_features - orig_out_features),
    )
    assert (
        len(block_size) == 2
    ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
    # Pad scale/zero_point by the number of quantization groups added
    # on each dimension so they stay consistent with the padded data.
    scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0]
    scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1]
    scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
    if zero_point is not None:
        # Guard: the base Layout contract allows zero_point=None
        # (F.pad would raise TypeError on None).
        zero_point = torch.nn.functional.pad(
            zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
        )
    return input, scale, zero_point

def extra_repr(self):
return f"inner_k_tiles={self.inner_k_tiles}"
Expand Down Expand Up @@ -335,31 +350,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

if func is aten.slice.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also same for this

self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
if dim == 0:
int_data, scale, zero_point = self.get_plain()
int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return return_and_correct_aliasing(func, args, kwargs, sliced)
elif dim == 1:
if dim in [0, 1]:
int_data, scale, zero_point = self.get_plain()
assert step == 1, "Only step == 1 is supported in slicing right now"
data_len = int_data.shape[dim]
scale_len = scale.shape[dim]
ratio = data_len / scale_len
start_scale = int(start / ratio)
end_scale = int(end / ratio)

int_data = aten.slice.Tensor(int_data, dim, start, end, step)
# this is to handle padding
int_data = self._layout.post_process(int_data)
scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
zero_point = aten.slice.Tensor(
zero_point, dim, start_scale, end_scale, step
)
# this is to handle padding
int_data, scale, zero_point = self._layout.post_process(
int_data, scale, zero_point, self.block_size
)
sliced = self.from_plain(int_data, scale, zero_point, self._layout)
return sliced
return return_and_correct_aliasing(func, args, kwargs, sliced)
else:
raise NotImplementedError(
f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
Expand All @@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

@property
def block_size(self):
    """Reconstruct the quantization block size as ``(1, groupsize)``.

    Derives the original (unpacked) weight shape from the packed 4-D
    tensor-core-tiled shape, then divides in_features by the number of
    quantization groups (one scale row per group) to recover groupsize.
    """
    from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

    scale, _ = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
    cur_shape = self.shape
    assert len(cur_shape) == 4
    # NOTE(review): the factors 8 and inner_k_tiles * 16 come from the
    # tinygemm int4 packing layout — confirm against
    # torch._convert_weight_to_int4pack if the packing ever changes.
    inner_k_tiles = cur_shape[-1] * 2
    original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
    # scale.shape[-2] is the number of groups along in_features; the
    # division is exact, so use integer floor division directly.
    groupsize = original_shape[1] // scale.shape[-2]
    return (1, groupsize)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
Expand Down
10 changes: 8 additions & 2 deletions torchao/dtypes/uintx/uintx_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,14 @@ class UintxLayout(Layout):
dtype: torch.dtype
pack_dim: int = -1

def post_process(self, input: torch.Tensor) -> torch.Tensor:
return to_uintx(input, self.dtype, self.pack_dim)
def post_process(
    self,
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pack quantized data into the uintx bit-packed representation.

    ``scale`` and ``zero_point`` pass through unchanged; ``block_size``
    is unused here and accepted only for interface parity with the base
    ``Layout.post_process``.
    """
    packed = to_uintx(input, self.dtype, self.pack_dim)
    return packed, scale, zero_point


@register_layout(UintxLayout)
Expand Down
10 changes: 8 additions & 2 deletions torchao/dtypes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,14 @@ class Layout:
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
return input

def post_process(self, input: torch.Tensor) -> torch.Tensor:
return input
def post_process(
    self,
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    block_size: Tuple[int, ...],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Default post-quantization hook: identity pass-through.

    Layout subclasses override this to pad or pack the quantized data
    and adjust scale/zero_point accordingly; the base layout returns
    its inputs unmodified.
    """
    unchanged = (input, scale, zero_point)
    return unchanged

def pre_process_static(
self,
Expand Down
14 changes: 14 additions & 0 deletions torchao/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def wrapper(*args, **kwargs):
return decorator


def skip_if_no_cuda():
    """Decorator factory that skips a test when CUDA is unavailable.

    Returns a decorator whose wrapper raises ``unittest.SkipTest`` at
    call time if ``torch.cuda.is_available()`` is False, and otherwise
    runs the wrapped test unchanged.
    """
    import functools
    import unittest

    def decorator(test_func):
        # functools.wraps preserves __name__/__doc__ so parametrized
        # test names and unittest/pytest reporting stay correct.
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            if not torch.cuda.is_available():
                raise unittest.SkipTest("No cuda available")
            return test_func(*args, **kwargs)

        return wrapper

    return decorator


# copied from https://github.com/pytorch/pytorch/blob/941d094dd1b507dacf06ddc6ed3485a9537e09b7/test/inductor/test_torchinductor.py#L11389
def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902
for name, value in my_cls.__dict__.items():
Expand Down
Loading