diff --git a/benchmarks/benchmark_bitpacking.py b/benchmarks/benchmark_bitpacking.py deleted file mode 100644 index 9e3d57d50..000000000 --- a/benchmarks/benchmark_bitpacking.py +++ /dev/null @@ -1,93 +0,0 @@ -from math import log -import torch - -from torchao.prototype.common.bitpacking import pack, unpack -from torchao.dtypes.uint4 import unpack_uint4, pack_uint4 - - -def benchmark(function, args, num_runs): - torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - - for _ in range(num_runs): - function(*args) - - end_event.record() - torch.cuda.synchronize() - return start_event.elapsed_time(end_event) / num_runs - - -def test_vs_existing(): - def new_(scale): - fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda() - packed = pack(fake_tensor, 4, dim=1) - unpacked = unpack(packed, 4, dim=1) - def old_(scale): - fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda() - packed = pack_uint4(fake_tensor) - unpacked = unpack_uint4(packed) - - - for scale in [256,512, 1024, 2048,4096, 8192]: - new_ = torch.compile(new_, fullgraph=True) - old_ = torch.compile(old_, fullgraph=True) - new_(scale) - old_(scale) - print("scale: ", scale) - print(f"new: {benchmark(new_,[scale], 10)} ms ") - print(f"old: {benchmark(old_,[scale], 10)} ms") - - -def compare_to_fp16(): - class Linear16(torch.nn.Module): - def __init__(self, scale): - super().__init__() - scale += scale % 2 - self.l1 = torch.nn.Linear(scale * 2, scale, bias=False,dtype=torch.float16).cuda() - self.l2 = torch.nn.Linear(scale, scale//2, bias=False,dtype=torch.float16).cuda() - - def forward(self, x): - return self.l2(self.l1(x)) - - class W4A16_symmetric_weight_only(torch.nn.Module): - def __init__(self, scale): - super().__init__() - assert scale % 4 == 0 - self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda() - self.s1 = torch.tensor((scale),dtype=torch.float16).cuda() - self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda() - self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda() - - - def forward(self, x): - w = unpack(self.l1.detach(), 4, output_dtype=torch.float16) - x = x * self.s1 - x = x @ w - w = unpack(self.l2.detach(), 4, output_dtype=torch.float16) - x = x * self.s2 - x = x @ w - - return x - - torch._dynamo.config.specialize_int = True - for scale in [256,512, 1024, 2048,4096, 8192]: - a = Linear16(scale) - b = W4A16_symmetric_weight_only(scale) - # a = torch.compile(a, fullgraph=True) - b = torch.compile(b, fullgraph=True) - - test_input = torch.randn(scale*2, dtype=torch.float16).cuda() - forward_args = [test_input] - b.forward(test_input) - print("scale: ", scale) - print("fp16 time: ", benchmark(a.forward, forward_args, 100)) - print("uint4 time: ", benchmark(b.forward, forward_args, 100)) - - - -if __name__ == "__main__": - compare_to_fp16() - test_vs_existing() - \ No newline at end of file diff --git a/benchmarks/benchmark_uintx.py b/benchmarks/benchmark_uintx.py new file mode 100644 index 000000000..9887fb8b4 --- /dev/null +++ b/benchmarks/benchmark_uintx.py @@ -0,0 +1,109 @@ +from math import log +from copy import deepcopy + +import torch +from torchao.utils import unwrap_tensor_subclass +from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack, pack_cpu, unpack_cpu +from torchao.quantization.quant_api import quantize_ + +class Linear16(torch.nn.Module): + def __init__(self, scale): + super().__init__() + 
self.net = torch.nn.Sequential( + torch.nn.Linear(scale*2, scale, bias=True, dtype=torch.float16).cuda(), + torch.nn.Linear(scale, scale, bias=True, dtype=torch.float16).cuda(), + torch.nn.Linear(scale, scale//2, bias=True, dtype=torch.float16).cuda(), + ) + + def forward(self, x): + return self.net(x) + + +def benchmark(function, args, num_runs): + # warmup + torch._dynamo.reset() + for i in range(100): + function(*args) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + + for _ in range(num_runs): + function(*args) + + end_event.record() + torch.cuda.synchronize() + return start_event.elapsed_time(end_event) / num_runs + + +def profile_bitpack(): + from torch.profiler import profile, record_function, ProfilerActivity + fake_tensor = [torch.randint(2**8, (512,512), dtype=torch.uint8).cuda()] + func = torch.compile(unpack_cpu, fullgraph=True) + with profile(activities=[ + ProfilerActivity.CPU, + ProfilerActivity.CUDA], + record_shapes=True, + with_stack=True + ) as prof: + + for _ in range(1000): + unpacked = func(fake_tensor, 4) + + # Print a summary + with open("profile-bitpack.txt", "a") as f: + print(f'{func}',file=f) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10), file=f) + prof.export_chrome_trace("trace.json") + ''' + CPU perf: + unpack_gpu + Self CPU time total: 602.501ms + + unpack_cpu + Self CPU time total: 415.469ms + GPU perf: + unpack_gpu on gpu: + Self CPU time total: 58.512ms + Self CUDA time total: 5.083ms + + unpack_cpu: + Self CPU time total: 96.947ms + Self CUDA time total: 5.253ms + ''' + +def uintx_vs_fp16(nbits= [1,2,3,4,5,6,7], scales=[256, 512, 1024], repeats=30): + results = [] + nbits.sort() + scales.sort() + for scale in scales: + test_input = torch.randn(scale*2, dtype=torch.float16).cuda() + forward_args = [test_input] + times = [scale] + + fp16 = Linear16(scale) + fp16c = torch.compile(fp16, fullgraph=True) + fp16_time = benchmark(fp16c.forward, forward_args, repeats) + times.append(fp16_time) + for bit_size in nbits: + m = deepcopy(fp16) + quantize_(m, uintx_affine_weight_only(bit_size)) + m = torch.compile(m, fullgraph=True) + uintx_time = benchmark(m.forward, forward_args, repeats) + times.append(uintx_time) + print(f'scale={scale} done') + + results.append(times) + print("----------- benchmark results -----------") + for result in results: + print(f"scale: {result[0]} fp16 time:{result[1]: .2f}ms speedups:") + for i in range(2, len(result)): + print(f"int{nbits[i-2]}: {result[1]/result[i]: .2f}x") + + + +if __name__ == "__main__": + uintx_vs_fp16(nbits=[4,7]) + + \ No newline at end of file diff --git a/test/prototype/test_bitpacking.py b/test/prototype/test_bitpacking.py index 52413bf4b..9cd81b35b 100644 --- a/test/prototype/test_bitpacking.py +++ b/test/prototype/test_bitpacking.py @@ -1,143 +1,66 @@ import torch -from torchao.prototype.common.bitpacking import pack, unpack +from torchao.prototype.uintx import pack, unpack, pack_cpu, unpack_cpu import pytest from torch.utils._triton import has_triton -from torchao.utils import TORCH_VERSION_AFTER_2_4 - -if not TORCH_VERSION_AFTER_2_4: - pytest.skip("Unsupported PyTorch version", allow_module_level=True) - -dtypes = ((2, 'trinary', 1), (2, None, 1), (3, None, 2), (4, None, 2), (5, None, 4), (6, None, 4), (7, None, 4)) -dimensions = (2, 1, 0, -1) -orders = (True, False) +element_bit_width = (1,2,3,4,5,6,7) +dimensions = (0, -1, 1) @pytest.fixture(autouse=True) def 
run_before_and_after_tests(): - # source: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test # noqa: E501 - - # setup (currently do nothing) - - # tests will run here yield + torch._dynamo.reset() # reset cache between tests - # teardown - # avoid dynamo cache limit issues - torch._dynamo.reset() - -@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("element_bit_width", element_bit_width) @pytest.mark.parametrize("dim", dimensions) -@pytest.mark.parametrize("order", orders) -def test_CPU(dtype, dim, order): - element_bit_width, element_type,expected_pack_size = dtype - shape = [4, 4, 4] - if element_type == "trinary": - test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8, device='cpu') - else: - test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8, device='cpu') - - packed = pack(test_tensor, - element_bit_width, - element_type=element_type, - dim = dim, - order = order, - container_dtype = torch.uint8) - assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack(packed, - element_bit_width, - element_type=element_type, - dim = dim, - order = order) +def test_CPU(element_bit_width, dim): + test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8, device='cpu') + packed = pack_cpu(test_tensor, element_bit_width, dim = dim) + unpacked = unpack_cpu(packed, element_bit_width, dim = dim) assert(unpacked.allclose(test_tensor)) - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("dtype", dtypes) + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("element_bit_width", element_bit_width) @pytest.mark.parametrize("dim", dimensions) -@pytest.mark.parametrize("order", orders) -def test_GPU(dtype, dim, order): - element_bit_width, element_type,expected_pack_size = dtype - shape = [4, 4, 4] - if element_type == "trinary": - test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() - else: - test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - - packed = pack(test_tensor, - element_bit_width, - element_type=element_type, - dim = dim, - order = order, - container_dtype = torch.uint8) - assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack(packed, - element_bit_width, - element_type=element_type, - order = order, - dim = dim) +def test_GPU(element_bit_width, dim): + test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, element_bit_width, dim = dim) + unpacked = unpack(packed, element_bit_width, dim = dim) assert(unpacked.allclose(test_tensor)) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("dim", dimensions) -@pytest.mark.parametrize("order", orders) -def test_padding(dtype, dim, order): - element_bit_width, element_type,expected_pack_size = dtype - torch._dynamo.config.specialize_int = True - shape =[4, 4, 4] - shape[dim] = 5 - - if element_type == "trinary": - test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() - else: - test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.uint8).cuda() - - packed = pack(test_tensor, - element_bit_width, - element_type=element_type, - dim = dim, - container_dtype = torch.uint8, - order = order, - pad= True) 
- assert packed.shape[dim] == expected_pack_size+1, f"packed.shape[dim] {packed.shape[dim]}" # +1 for this scenario - unpacked = unpack(packed, - element_bit_width, - element_type=element_type, - dim = dim, - order = order) - slices = [slice(None)] * packed.ndim - slices[dim] = slice(None, 5) - assert unpacked[slices].allclose(test_tensor) - - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("element_bit_width", element_bit_width) @pytest.mark.parametrize("dim", dimensions) -@pytest.mark.parametrize("order", orders) -def test_compile(dtype, dim, order): - pack_compile = torch.compile(pack, fullgraph=True, dynamic=True) - unpack_compile = torch.compile(unpack, fullgraph=True, dynamic=True) - element_bit_width, element_type,expected_pack_size = dtype +def test_compile(element_bit_width, dim): torch._dynamo.config.specialize_int = True - shape = [4, 4, 4] - if element_type == "trinary": - test_tensor = torch.randint(-1, 1, shape, dtype=torch.int8).cuda() - else: - test_tensor = torch.randint(0, 2**element_bit_width, shape, dtype=torch.int8).cuda() - - packed = pack_compile(test_tensor, element_bit_width, - element_type=element_type, - dim = dim, - container_dtype = torch.int8, - order = order) - assert(packed.shape[dim] == expected_pack_size) - unpacked = unpack_compile(packed, - element_bit_width, - element_type=element_type, - dim = dim, - order = order) + pack_compile = torch.compile(pack, fullgraph=True) + unpack_compile = torch.compile(unpack, fullgraph=True) + test_tensor = torch.randint(0, 2**element_bit_width, (32,32,32), dtype=torch.uint8).cuda() + packed = pack(test_tensor, element_bit_width, dim = dim) + unpacked = unpack(packed, element_bit_width, dim = dim) assert(unpacked.allclose(test_tensor)) + +# these test cases are for the example pack walk through in the bitpacking.py file +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_pack_example(): + test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8).cuda() + shard_4,shard_2 = pack(test_tensor, 6) + print(shard_4, shard_2) + assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4) + assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2) + unpacked = unpack([shard_4, shard_2], 6) + assert unpacked.allclose(test_tensor) + +def test_pack_example_CPU(): + test_tensor = torch.tensor([0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22], dtype=torch.uint8) + shard_4,shard_2 = pack(test_tensor, 6) + print(shard_4, shard_2) + assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).allclose(shard_4) + assert torch.tensor([39, 146], dtype=torch.uint8).allclose(shard_2) + unpacked = unpack([shard_4, shard_2], 6) + assert unpacked.allclose(test_tensor) + + \ No newline at end of file diff --git a/test/prototype/test_uintx.py b/test/prototype/test_uintx.py new file mode 100644 index 000000000..0a43e3d0c --- /dev/null +++ b/test/prototype/test_uintx.py @@ -0,0 +1,94 @@ +from math import log +from copy import deepcopy +import pytest + +import torch + +from torchao.prototype.uintx import uintx_affine_weight_only, to_uintx +from torchao.quantization.quant_api import quantize_ +from torchao.utils import TORCH_VERSION_AFTER_2_5 + +from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine, + quantize_affine, + 
dequantize_affine, + ) + +bit_sizes = (1,2,3,4,5,6,7) +group_sizes = [32,64,128] +devices = ["cpu", "cuda"] +@pytest.fixture(autouse=True) +def run_before_and_after_tests(): + yield + torch._dynamo.reset() # reset cache between tests + + + +class Linear16(torch.nn.Module): + def __init__(self, scale, device): + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(scale * 2, scale, bias=False, dtype=torch.float16, device=device), + torch.nn.Linear(scale, scale, bias=False, dtype=torch.float16, device=device), + torch.nn.Linear(scale, scale//2, bias=False, dtype=torch.float16, device=device), + ) + + def forward(self, x): + return self.net(x) + +@pytest.mark.parametrize("bit_size", bit_sizes) +@pytest.mark.parametrize("group_size", group_sizes) +@pytest.mark.parametrize("device", devices) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build") +def test_uintx_affine_weight_only_model_quant(bit_size, group_size, device): + scale = 512 + fp16 = Linear16(scale, device) + quantize_(fp16, uintx_affine_weight_only(bit_size, group_size=group_size)) + uintx = torch.compile(fp16, fullgraph=True) + test_input = torch.randn(scale*2, dtype=torch.float16, device=device) + output = uintx.forward(test_input) + assert output != None, "model quantization failed" + +@pytest.mark.parametrize("bit_size", bit_sizes) +@pytest.mark.parametrize("group_size", group_sizes) +@pytest.mark.parametrize("device", devices) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_5, reason="only works with fix in the nightly build") +def test_uintx_affine_weight_only_quant(bit_size, group_size, device): + input_float = torch.randn((1,256), dtype=torch.float16, device = device) + mapping_type = MappingType.SYMMETRIC + quant_min = 0 + quant_max = 2**bit_size - 1 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + target_dtype = torch.uint8 + block_size = (1, group_size) + + scale, zero_point = choose_qparams_affine( + input_float, mapping_type, block_size, + target_dtype, quant_min, quant_max, eps, torch.float32, + zero_point_dtype, True, zero_point_domain + ) + + aqt = quantize_affine( + input_float, block_size, scale, + zero_point, target_dtype, + quant_min = quant_min, + quant_max = quant_max, + zero_point_domain = zero_point_domain + ) + + q = to_uintx(aqt, bit_size, -1) + assert q != None, "quantization failed" + deqaunt = dequantize_affine( + q, block_size, scale, + zero_point, target_dtype, + quant_min = quant_min, + quant_max = quant_max, + zero_point_domain = zero_point_domain + ) + assert deqaunt != None, "deqauntization failed" diff --git a/torchao/prototype/common/bitpacking.py b/torchao/prototype/common/bitpacking.py deleted file mode 100644 index 867fecdda..000000000 --- a/torchao/prototype/common/bitpacking.py +++ /dev/null @@ -1,124 +0,0 @@ -import torch -from typing import Optional, Union - -def mod_shape(shape, mod, dim): - """changes a select dimension of the input shape to mod""" - a = list(shape) - a[dim] = mod - return tuple(a) - -def unpack(data: torch.Tensor, - element_bit_width: int, - element_type: Optional[str] = None, - dim: Optional[int] = 0, - order: Optional[bool] = True, - output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Unpacks small dtype elements from a larger dtype. 
- - Inputs: - data: - a tensor of packed elements - element_bit_width: the size in bits of the elements to unpack - element_type: the dtype of the elements to unpack (uint,trinary,float, etc) - dim: the dimension to unpack along - output_dtype: specify the dtype of the output tensor if it is not the same as the input tensor - order: make sure it matches the value set in the pack function - - Returns: torch.Tensor - a tensor of the unpacked elements. - """ - container_size = torch.iinfo(data.dtype).bits - scale = container_size // element_bit_width - device = data.device - - unpacked = _unpack(data, element_bit_width, container_size, scale, order, dim, device) - - if element_type == "trinary": - unpacked = unpacked.to(torch.int8) - 1 - elif output_dtype is not None: - unpacked = unpacked.to(output_dtype) - - return unpacked - -def _unpack(data, element_size, container_size, scale, order, dim, device): - shape = data.shape - unpacked_data = torch.zeros(mod_shape(shape, shape[dim]*scale, dim), dtype=data.dtype).to(device) - nbits = (1 << element_size) - 1 # mask for the last dtype_size bits - for i in range(scale): - if order: - shift_amt = container_size - element_size * (i + 1) - else: - shift_amt = element_size * i - slices = [slice(None)] * unpacked_data.ndim - slices[dim] = slice(i, None, scale) - unpacked_data[slices] = ((data >> shift_amt) & (nbits)).to(data.dtype) - - # stack the unpacked data and reshape to the original shape - return unpacked_data.view(mod_shape(shape,scale*shape[dim], dim)) - - -def pack(data: torch.Tensor, - element_bit_width: int, - element_type: Optional[str] = None, - dim: Optional[int] = 0, - container_dtype: Optional[torch.dtype] = None, - pad: Optional[bool] = False, - order: Optional[bool] = True) -> torch.Tensor: - """ - Packs small dtype elements into a container of a larger dtype. - - Inputs: - data: a tensor of unpacked elements of a small dtype. The dtype used for the data will be used for the container. - dim: the dimension to pack along - element_dtype: the dtype of the elements to pack - container_dtype: specify the dtype of the container if the data is not already inside a tensor of that dtype - pad: if set to true, pads the dimension to be divisible by the scale - order: if set to true, packs elements such that the lower index elements occupy the most significant bits - - Returns: torch.Tensor - a tensor of packed elements. - - - For example, packing 4-bit elements into 8-bit containers. 
- along dimension 0: along dimension 1: - (0, 9, B, 4) --> ( 9, B4) - (3, 8, F, C) --> (38, FC) - | | | | - v v v v - (3, 98, BF, 4C) - - if order was set to false: - (30, 89, FB, C4) - """ - - if element_type == "trinary": - data = data + 1 - - if container_dtype is not None: - data = data.to(container_dtype) - - device = data.device - - container_size = torch.iinfo(data.dtype).bits - scale = container_size // element_bit_width - - if pad and data.shape[dim] % scale != 0: - padding = torch.zeros(mod_shape(data.shape, scale - data.shape[dim] % scale, dim), dtype=data.dtype).to(device) - data = torch.cat([data, padding], dim=dim).to(device) - - - torch._assert(data.shape[dim] >= scale, f"not enough values to pack along dimension {dim}") - torch._assert(data.shape[dim] % scale == 0, "size of pack dimension not divisble by scale") - return _pack(data, container_size, element_bit_width, scale, dim, order, device) - - - -def _pack(data, container_size, element_bit_width, scale, dim, order, device) -> torch.Tensor: - packed = torch.zeros(mod_shape(data.shape, data.shape[dim] // scale, dim), dtype=data.dtype).to(device) - slices = [slice(None)] * packed.ndim - for i in range(scale): - slices[dim] = slice(i, None, scale) - if order: - packed |= data[slices] << container_size-element_bit_width*(i+1) - else: - packed |= data[slices] << element_bit_width*i - return packed - \ No newline at end of file diff --git a/torchao/prototype/uintx/Uintx.py b/torchao/prototype/uintx/Uintx.py new file mode 100644 index 000000000..bb923132b --- /dev/null +++ b/torchao/prototype/uintx/Uintx.py @@ -0,0 +1,207 @@ +import functools +import math +from collections import defaultdict +from typing import Any, Callable, Dict, Optional, Tuple, Union, List +from dataclasses import dataclass +import torch +from torch._dynamo.comptime import comptime + +from torch.utils._python_dispatch import return_and_correct_aliasing +from .bitpacking import pack, unpack, numbits +from torchao.dtypes.utils import ( + LayoutType, + _implements, + _register_layout_cls, + _dispatch__torch_function__, + _dispatch__torch_dispatch__, +) +from torchao.dtypes.affine_quantized_tensor import PlainAQTLayout, register_layout_cls + + +aten = torch.ops.aten + +class UintxTensor(torch.Tensor): + """ + Splits int data into packed shards based on bit size + fields: + int4_shard (torch.Tensor): 4 bit packed shard + int2_shard (torch.Tensor): 2 bit packed shard + int1_shard (torch.Tensor): 1 bit packed shard + bit_size (int): element size in bits + pack_dim: (int) dimension to pack along + """ + bits_to_shard = { + 1: ["int1_shard"], + 2: ["int2_shard"], + 3: ["int2_shard", "int1_shard"], + 4: ["int4_shard"], + 5: ["int4_shard", "int1_shard"], + 6: ["int4_shard", "int2_shard"], + 7: ["int4_shard", "int2_shard", "int1_shard"], + } + def __new__( + cls, + shards: List[torch.Tensor], + packed_shape: List[int], + bit_size: int, + pack_dim: int = -1, + ): + kwargs = {"device": shards[0].device} + kwargs["device"] = shards[0].device + kwargs["layout"] = shards[0].layout + kwargs["requires_grad"] = False + kwargs["dtype"] = torch.uint8 + return torch.Tensor._make_wrapper_subclass(cls, packed_shape, **kwargs) + + def __init__( + self, + shards: List[torch.Tensor], + packed_shape: List[int], + bit_size: int, + pack_dim: int = -1, + ): + for i, attrib in enumerate(self.bits_to_shard[bit_size]): + setattr(self, attrib, shards[i]) + + self.packed_shape = packed_shape + self.bit_size = bit_size + self.pack_dim = pack_dim + + def get_shards(self): + return 
[getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_size]] + + def __repr__(self): + return f"Int{self.bit_size}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_size, dim = self.pack_dim)})" + + def __tensor_flatten__(self): + return self.__class__.bits_to_shard[self.bit_size], [self.packed_shape, self.bit_size, self.pack_dim] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + shards = list(tensor_data_dict.values()) + packed_shape, bit_size, pack_dim = tensor_attributes + return cls(shards, packed_shape, bit_size, pack_dim) + + implements = classmethod(_implements) + __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__) + __torch_function__ = classmethod(_dispatch__torch_function__) + + def get_plain(self): + return unpack(self.get_shards(), self.bit_size, dim = self.pack_dim) + + # temporary until kernels on packed tensors are created + def apply_transformation(self, fn): + og = self.get_plain() + new = fn(og) + return self.from_uint8(new, self.bit_size, self.pack_dim) + + # temporary until kernels on packed tensors are created + def apply_fn_to_shards(self, fn): + new_shards = [fn(shard) for shard in self.get_shards()] + return self.__class__(new_shards, self.packed_shape, self.bit_size, self.pack_dim) + + @classmethod + def from_uint8(cls, int_data: torch.Tensor, bit_size, pack_dim: int = -1): + shards = pack(int_data, bit_size, dim=pack_dim) + shape = list(int_data.shape) + shape[pack_dim] = shape[pack_dim] * bit_size // 8 + return cls(shards, int_data.shape, bit_size, pack_dim) + + +implements = UintxTensor.implements + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) + ) + +@implements(aten.view.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) + ) + +@implements(aten._to_copy.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0] + ) + +@implements(aten.sub.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)) + ) + +@implements(aten.mul.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)) + ) +# quantization api integrations +to_uintx = UintxTensor.from_uint8 + +@dataclass(frozen=True) +class UintxLayoutType(LayoutType): + bit_size: int + pack_dim: int = -1 + + def post_process(self, input: torch.Tensor) -> torch.Tensor: + return to_uintx(input, self.bit_size, self.pack_dim) + +@register_layout_cls(UintxLayoutType) +class UintxAQTLayout(PlainAQTLayout): + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return self.int_data.get_plain(), self.scale, self.zero_point + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + layout_type: LayoutType, + ): + assert isinstance(layout_type, UintxLayoutType) + return cls(int_data, scale, zero_point, layout_type) + + +def uintx_affine_weight_only(bit_size, group_size=64, pack_dim=-1): + """ + Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where + x is the 
number of bits specified by the `nbits` argument + """ + from torchao.quantization.quant_primitives import ( + MappingType, + ZeroPointDomain, + choose_qparams_affine, + quantize_affine, + dequantize_affine, + ) + from torchao.dtypes import to_affine_quantized + from torchao.quantization.quant_api import _get_linear_subclass_inserter + def apply_uintx_weight_only_quant(weight): + + layout_type = UintxLayoutType(bit_size=bit_size, pack_dim=pack_dim) + mapping_type = MappingType.ASYMMETRIC + block_size = (1, group_size) + quant_min = 0 + quant_max = 2**bit_size - 1 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int32 + zero_point_domain = ZeroPointDomain.INT + + return to_affine_quantized( + weight, mapping_type, block_size, torch.uint8, + quant_min = quant_min, quant_max = quant_max, + eps = eps, zero_point_dtype=zero_point_dtype, + zero_point_domain=zero_point_domain, + layout_type=layout_type, + ) + + return _get_linear_subclass_inserter(apply_uintx_weight_only_quant) \ No newline at end of file diff --git a/torchao/prototype/uintx/__init__.py b/torchao/prototype/uintx/__init__.py new file mode 100644 index 000000000..610f244f0 --- /dev/null +++ b/torchao/prototype/uintx/__init__.py @@ -0,0 +1,2 @@ +from .Uintx import UintxTensor, to_uintx, uintx_affine_weight_only +from .bitpacking import pack, unpack, pack_cpu, unpack_cpu, numbits diff --git a/torchao/prototype/uintx/bitpacking.py b/torchao/prototype/uintx/bitpacking.py new file mode 100644 index 000000000..13568e899 --- /dev/null +++ b/torchao/prototype/uintx/bitpacking.py @@ -0,0 +1,223 @@ +import torch +import numpy as np +from typing import Optional, Union, List +from functools import reduce + +# for selecting the shards from 8 bits +maskbits = { + 1: (0x01,), + 2: (0x03,), + 3: (0x03, 0x04), + 4: (0x0f,), + 5: (0x0f, 0x10), + 6: (0x0f, 0x30), + 7: (0x0f, 0x30, 0x40), +} + +unpack_mask = { + 1: (0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80), + 2: (0x03,0x0c,0x30,0xc0), + 4: (0x0f,0xf0), +} + +# size of each shard +numbits = { + 1: (1,), + 2: (2,), + 3: (2, 1), + 4: (4,), + 5: (4, 1), + 6: (4, 2), + 7: (4, 2, 1), +} + +# shift amount for each shard +shifts = { + 1: (0,), + 2: (0,), + 3: (0, 2), + 4: (0,), + 5: (0, 4), + 6: (0, 4), + 7: (0, 4, 6), +} + +# for shifting groups left but right if shift is negative +def abs_lsh(data, shift): + if shift == 0: + return data + elif shift < 0: + return data >> -shift + else: + return data << shift + + +# inverse of abs_lsh for unpacking +def abs_rsh(data, shift): + if shift == 0: + return data + elif shift < 0: + return data << -shift + else: + return data >> shift + + +def pack_cpu(data: torch.Tensor, + elem_size: int, + dim: Optional[int] = -1) -> List[torch.Tensor]: + """ + Inputs: + data: a tensor of sub byte elements in uint8 + elem_size: the size in bits of the elements to pack + dim: the dimension to pack along + Returns: a list of packed shards + + ================================================================================================== + given an array such as [0x30,0x29,0x17,0x5,0x20,0x16,0x9,0x22] which are 8 uint6 elements + first seperate into two shards: the upper 2 bits and the lower 4 bits by using a mask (0x30 and 0x0f respectively) + 2 bit shard: + mask: 0x30 + [0x30, 0x20, 0x10, 0x00, 0x00, 0x10, 0x00, 0x20 ] + [0b00110000, 0b00100000, 0b00010000, 0b00000000, 0b00100000, 0b00010000, 0b00000000, 0b00100000] + + Group elements into subsets that will be shifted to the same position within the 8bit container + group1 >> 4, group2 >> 2, group3 >> 0, group4 
<< 2 + + [0b00000011, 0b00000010, 0b00000100, 0b00000000, 0b00100000, 0b00010000, 0b00000000, 0b10000000] + |------ group 1 ------| |------ group 2 ------| |------ group 3 ------| |------ group 4 ------| + + Finally bitwise-or the groups together + [0b00000011, 0b00000010, + 0b00000100, 0b00000000, + 0b00100000, 0b00010000, + 0b00000000, 0b01000000] + + [0b00100111, 0b10010010] + ================================================================================================== + Similarly for 4 bit shards: + mask: 0x0f + [0x00, 0x09, 0x07, 0x05, 0x00, 0x16, 0x9, 0x02] + [0b00000000, 0b00001001, 0b00000111, 0b00000101, 0b00000000, 0b00000110, 0b00001001, 0b00000010] + + group1 << 0, group2 << 4 + [0b00000000, 0b00001001, 0b00000111, 0b00000101, 0b00000000, 0b01100000, 0b10010000, 0b00100000] + |------------------ group 1 ------------------| |------------------ group 2 ------------------| + + bitwise-or: + [0b00000000, 0b00001001, 0b00000111, 0b00000101, + 0b00000000, 0b01100000, 0b10010000, 0b00100000] + + [0b00000000, 0b01101001, 0b10010111, 0b00100101] + ================================================================================================== + After pack, data went from 8 elements to 6: [[0, 105, 151, 37], [39, 146]] + In general this means pack reduces input tensor size from n * 8 to n * elem_size + """ + torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + torch._assert(data.dtype == torch.uint8, "data must be uint8") + output_shape = list(data.shape) + + output = [] + for i in range(len(numbits[elem_size])): + output_shape[dim] = data.shape[dim] * numbits[elem_size][i] // 8 + shard = torch.zeros(output_shape, dtype=torch.uint8, device=data.device) + bit_size = numbits[elem_size][i] + rel_pos = shifts[elem_size][i] + bits = data & maskbits[elem_size][i] + scale = 8 // bit_size + slice_len = bits.shape[dim] // scale + for j in range(scale): + bit_slice = bits.narrow(dim, slice_len * j, slice_len) + shard = torch.bitwise_or(shard, abs_lsh(bit_slice, j * bit_size - rel_pos)) + output.append(shard) + return output + + +def unpack_cpu(data: List[torch.Tensor], + elem_size: int, + dim: Optional[int] = -1) -> torch.Tensor: + """ + Unpacks small dtype elements from a larger dtype. + + Inputs: + data: - a list of packed shards + elem_size: the size in bits of the elements to unpack + dim: the dimension to unpack along + + Returns: torch.Tensor - a tensor of the unpacked elements. 
+ """ + # define the output tensor + output_shape = list(data[0].shape) + output_shape[dim] = data[0].shape[dim] * 8 // numbits[elem_size][0] + output = torch.zeros(output_shape, dtype=torch.uint8, device=data[0].device) + + for i in range(len(numbits[elem_size])): + # define variables for the current shard + bit_size = numbits[elem_size][i] + rel_pos = shifts[elem_size][i] + scale = 8 // bit_size + group_size = bit_size * output_shape[dim] // 8 + # mask and shift every group of bits to the correct position + for j in range(scale): + output_narrow = output.narrow(dim, j * group_size, group_size) + group = data[i] & unpack_mask[bit_size][j] + shift_amt = j * bit_size - rel_pos + output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos))) + return output + +# these are faster on the GPU + +def _pack(data, elem_size, scale, dim): + ''' + Inner for loop from above pack function + ''' + packed_shape = list(data.shape) + packed_shape[dim] = packed_shape[dim] // scale + + packed = torch.zeros(packed_shape, dtype=data.dtype, device=data.device) + + for i in range(scale): + narrow_slice = data.narrow(dim, data.shape[dim]*i//scale, data.shape[dim] // scale) + packed |= narrow_slice << (elem_size * i) + + return packed + +def _unpack(data, element_size, scale, dim): + ''' + Inner for loop from above unpack function + ''' + unpacked_shape = list(data.shape) + unpacked_shape[dim] *= scale + + nbits = (1 << element_size) - 1 # mask for the last element_size bits + + unpacked_data = torch.zeros(unpacked_shape, dtype=data.dtype, device=data.device) + + for i in range(scale): + shift_amt = element_size * i + chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits) + + return unpacked_data + + +def pack(data: torch.Tensor, + elem_size: int, + dim: Optional[int] = -1) -> List[torch.Tensor]: + ''' + a less branching but more compute version so better for gpu + ''' + torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + torch._assert(data.dtype == torch.uint8, "data must be uint8") + container_size = 8 + shards = [(data & maskbits[elem_size][i]) >> shifts[elem_size][i] for i in range(len(maskbits[elem_size]))] + return tuple([_pack(shards[i], numbits[elem_size][i], container_size//numbits[elem_size][i], dim) for i in range(len(maskbits[elem_size]))]) + +def unpack(data: List[torch.Tensor], + elem_size: int, + dim: Optional[int] = 0) -> torch.Tensor: + ''' + a less branching but more compute version so better for gpu + ''' + container_size = 8 + # unpack each 4,2,1 bit shard and unshift them back to the correct position + data = [_unpack(data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim) << shifts[elem_size][i] for i in range(len(data))] + return reduce(torch.bitwise_or, data)
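

Reviewer note: below is a minimal usage sketch assembled from the tests and benchmark added in this PR (test_bitpacking.py, test_uintx.py, benchmark_uintx.py). It is not part of the patch. It assumes a CUDA device and, for the quantization path, a nightly build (per the TORCH_VERSION_AFTER_2_5 skip in test_uintx.py); the model shape and bit width chosen here are illustrative.

    import torch
    from torchao.prototype.uintx import uintx_affine_weight_only, pack, unpack
    from torchao.quantization.quant_api import quantize_

    # Bit-pack / unpack round trip on a uint8 tensor holding sub-byte (here 6-bit) values.
    # pack() returns a tuple of shards (a 4-bit and a 2-bit shard for uint6); unpack()
    # takes that list and reassembles the original uint8 values.
    t = torch.randint(0, 2**6, (32, 32, 32), dtype=torch.uint8).cuda()
    shards = pack(t, 6, dim=-1)
    assert unpack(shards, 6, dim=-1).allclose(t)

    # Weight-only uintx quantization of an fp16 linear model, mirroring
    # test_uintx_affine_weight_only_model_quant: quantize_ swaps the Linear weights
    # for UintxTensor-backed affine-quantized weights, and the result still compiles.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 512, bias=False, dtype=torch.float16),
    ).cuda()
    quantize_(model, uintx_affine_weight_only(4, group_size=64))
    compiled = torch.compile(model, fullgraph=True)
    out = compiled(torch.randn(1024, dtype=torch.float16).cuda())

The pack dimension of the weight must be divisible by 8 and the group size must divide the quantized dimension, as asserted in bitpacking.py and implied by the (1, group_size) block size used in uintx_affine_weight_only.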