Commit 38dad9b: Bitpacking (#291)

vayuda committed May 30, 2024
1 parent 5485929 commit 38dad9b
Showing 2 changed files with 171 additions and 0 deletions.
70 changes: 70 additions & 0 deletions test/prototype/test_bitpacking.py
@@ -0,0 +1,70 @@
import torch
from torchao.prototype.common.bitpacking import pack, unpack
import pytest
from torch.utils._triton import has_triton
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4

if not TORCH_VERSION_AFTER_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

def test_uint4_to_uint8_CPU():
    test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8)
    packed = pack(test_tensor, 8, 4, device='cpu')
    unpacked = unpack(packed, 4, device='cpu')
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

def test_uint3_to_int16_col_wise_cpu():
    test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16)
    packed = pack(test_tensor, 16, 3, False, device='cpu')
    unpacked = unpack(packed, 3, False, device='cpu')
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint4_to_uint8():
    test_tensor = torch.randint(0, 15, (4, 4), dtype=torch.uint8).cuda()
    packed = pack(test_tensor, 8, 4)
    unpacked = unpack(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint4_to_uint8_compile():
    torch._dynamo.config.specialize_int = True
    pack_compiled = torch.compile(pack, fullgraph=True)
    unpack_compiled = torch.compile(unpack, fullgraph=True)
    test_tensor = torch.randint(0, 15, (3, 4), dtype=torch.uint8).cuda()
    packed = pack_compiled(test_tensor, 8, 4)
    unpacked = unpack_compiled(packed, 4)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16():
    test_tensor = torch.randint(0, 7, (5, 8), dtype=torch.int16).cuda()
    packed = pack(test_tensor, 16, 3)
    unpacked = unpack(packed, 3)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
def test_uint2_to_uint8_col_wise_compile():
    torch._dynamo.config.specialize_int = True
    pack_compiled = torch.compile(pack, fullgraph=True)
    unpack_compiled = torch.compile(unpack, fullgraph=True)
    test_tensor = torch.randint(0, 3, (8, 8), dtype=torch.uint8).cuda()
    packed = pack_compiled(test_tensor, 8, 2, False)
    unpacked = unpack_compiled(packed, 2, False)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_uint3_to_int16_col_wise():
    test_tensor = torch.randint(0, 7, (8, 5), dtype=torch.int16).cuda()
    packed = pack(test_tensor, 16, 3, False)
    unpacked = unpack(packed, 3, False)
    unpadded = unpacked[:test_tensor.shape[0], ...]
    assert unpadded.allclose(test_tensor)
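As a quick sanity check on the row-wise layout these tests exercise (an illustrative sketch, not part of the committed test file): two 4-bit values 3 and 10 should share a single uint8 container as (3 << 4) | 10 == 58, and unpacking should recover both values.

import torch
from torchao.prototype.common.bitpacking import pack, unpack

# illustrative only: two 4-bit values packed into one uint8 container
example = torch.tensor([[3], [10]], dtype=torch.uint8)
packed = pack(example, 8, 4, device='cpu')              # tensor([[58]]) since (3 << 4) | 10 == 58
assert packed.item() == (3 << 4) | 10
assert unpack(packed, 4, device='cpu').equal(example)   # round-trips back to the original values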
101 changes: 101 additions & 0 deletions torchao/prototype/common/bitpacking.py
@@ -0,0 +1,101 @@
import torch
from functools import reduce



def unpack(data, data_size, by_rows=True, device="cuda"):
    """
    Unpacks small dtype elements from a larger dtype.

    Inputs:
    data: torch.Tensor - a tensor of packed elements of a small dtype within a larger dtype.
    data_size: int - the size of the small dtype in bits.

    optional:
    by_rows: bool - specifies whether to unpack...
        by rows: tensor(n,m) -> tensor(n*scale, m)
        or by columns: tensor(n,m) -> tensor(n, m*scale)

        Defaults to rows because quantization is typically done by rows,
        but choose the layout that matches how you quantize, since this
        improves memory access patterns and performance.

    Returns: torch.Tensor - a tensor of the unpacked elements.
    """
    if by_rows:
        return _unpack_by_rows(data, data_size, device)
    else:
        return _unpack_by_cols(data, data_size)

def pack(data, container_size, data_size, by_rows=True, device="cuda"):
    """
    Packs small dtype elements into a larger dtype.
    Pads the packed dimension to be divisible by the scale.

    Inputs:
    data: torch.Tensor - a tensor of unpacked elements of a small dtype.
    container_size: int - the size of the large dtype in bits.
    data_size: int - the size of the small dtype in bits.

    optional:
    by_rows: bool - specifies whether to pack values...
        by rows: tensor(n,m) -> tensor(n//scale, m)
        or by columns: tensor(n,m) -> tensor(n, m//scale)

        Defaults to rows because quantization is typically done by rows,
        but choose the layout that matches how you quantize, since this
        improves memory access patterns and performance.

    Returns: torch.Tensor - a tensor of packed elements.
    """
    if by_rows:
        return _pack_by_rows(data, container_size, data_size, device)
    else:
        return _pack_by_cols(data, container_size, data_size, device)

def _unpack_by_rows(data, data_size, device) -> torch.Tensor:
    shape = data.shape
    scale = data.element_size() * 8 // data_size

    unpacked_data = torch.zeros((shape[0] * scale, *shape[1:]), dtype=data.dtype).to(device)
    nbits = (1 << data_size) - 1  # mask for the last data_size bits
    for i in range(scale):
        shift_amt = data.element_size() * 8 - data_size * (i + 1)  # how much to shift to extract the ith element
        unpacked_data[i::scale] = ((data >> shift_amt) & nbits)
    return unpacked_data

def _unpack_by_cols(data, data_size) -> torch.Tensor:
    shape = data.shape
    scale = data.element_size() * 8 // data_size
    unpacked_data = []
    nbits = (1 << data_size) - 1  # mask for the last data_size bits
    for i in range(scale):
        shift_amt = data.element_size() * 8 - data_size * (i + 1)  # how much to shift to extract the ith element
        unpacked_data.append(((data >> shift_amt) & nbits).to(data.dtype))
    return torch.stack(unpacked_data, dim=-1).view(*shape[:-1], shape[-1] * scale)  # stack the unpacked values and reshape to the original layout

def _pack_by_rows(data, container_size, data_size, device) -> torch.Tensor:
    scale = container_size // data_size
    assert scale > 1, f"container_size ({container_size}) must be at least twice data_size ({data_size})"
    assert data.shape[0] >= scale, f"not enough values to pack, data.shape[0] ({data.shape[0]}) < scale ({scale})"
    # pad the row dimension so it is divisible by scale
    if data.shape[0] % scale != 0:
        padding = torch.zeros((scale - data.shape[0] % scale, *data.shape[1:],), dtype=data.dtype).to(device)
        data = torch.cat([data, padding], dim=0).to(device)

    shape = data.shape
    # shift each group of `scale` rows into its slot within the container and OR them together
    ret = reduce(lambda x, y: x | y, [data[i::scale, ...] << container_size - data_size * (i + 1) for i in range(scale)])
    return ret.view(shape[0] // scale, *shape[1:]).to(device)

def _pack_by_cols(data, container_size, data_size, device) -> torch.Tensor:
    scale = container_size // data_size
    assert scale > 1, f"container_size ({container_size}) must be at least twice data_size ({data_size})"
    # pad the last dimension so it is divisible by scale
    if data.shape[-1] % scale != 0:
        padding = torch.zeros((*data.shape[:-1], scale - data.shape[-1] % scale), dtype=data.dtype).to(device)
        data = torch.cat([data, padding], dim=-1).to(device)

    shape = data.shape
    data = data.contiguous().view(-1)
    # shift the data to the different indexes within the larger dtype and then union them together
    ret = reduce(lambda x, y: x | y, [data[i::scale] << container_size - data_size * (i + 1) for i in range(scale)])
    return ret.view(*shape[:-1], shape[-1] // scale).to(device)
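A usage note (not from the commit itself): the tests slice the unpacked result with unpacked[:test_tensor.shape[0], ...] because pack pads the packed dimension up to a multiple of scale. A rough sketch of that behavior, assuming the padding path honors the device argument as written above:

import torch
from torchao.prototype.common.bitpacking import pack, unpack

t = torch.randint(0, 15, (3, 4), dtype=torch.uint8)  # 3 rows; scale = 8 // 4 = 2, so one zero row is padded on
packed = pack(t, 8, 4, device='cpu')                  # shape (2, 4): the padded 4 rows fold into 2 containers
unpacked = unpack(packed, 4, device='cpu')            # shape (4, 4): still includes the zero padding row
assert unpacked[:t.shape[0], ...].equal(t)            # slicing off the padding recovers the input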
