[feature request] np.packbits / np.unpackbits, general BitTensors (maybe can be just tensors with dtype torch.bits8 or have a new dtype torch.bits introduced) and bit packed tensors utilities for saving memory / accesses, support for BitTensors wherever BoolTensors are used #292
One difficulty is that in many points of our code we assume all tensor elements are addressable (for example, for views), and this would not be the case with bit-packed tensors.
I wonder if we could design some explicit pack/unpack/load/store/index util methods that would be enough for basic usage (like NumPy does with packbits/unpackbits). Maybe we could have some unpack method that is optimal if the user themselves provided dtype-aligned indexes. These new bit tensors wouldn't be first-class objects, but utility methods could still be enough for first experimentation. Maybe the unpack method could be some variant of
A simple interface could be packbits / unpackbits like in NumPy, with an additional bitness argument (to support 1-bit, 2-bit and 4-bit) and a dim argument. It should maybe support an out argument for the unpacked uint8 tensor. The unpacked dimension could always be a new zeroth dimension.
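For intuition, here is a minimal illustration (not part of any proposed API) of what packing eight 1-bit values into a single uint8 byte means, MSB-first:

```python
import torch

# eight 1-bit values -> one byte, most significant bit first (illustrative only)
bits = torch.tensor([1, 0, 1, 1, 0, 0, 1, 0], dtype=torch.uint8)
weights = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8)
byte = (bits * weights).sum()   # tensor(178) == 0b10110010
```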
https://www.microsoft.com/en-us/research/uploads/prod/2018/02/KoPhiliposeTashevZarar_ICASSP_2018.pdf suggests that XNOR and POPCNT functionality is useful for 1-bit networks.
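As a rough sketch of why XNOR + POPCNT matters here: a dot product between two {-1,+1} vectors packed 8 values per byte (bit 1 encoding +1) reduces to XNOR plus a population count. PyTorch has no popcount op, so the sketch below (my naming, not an existing API) uses a 256-entry lookup table and assumes the packed length is an exact multiple of 8:

```python
import torch

# popcount lookup table for one byte
_POPCOUNT = torch.tensor([bin(i).count("1") for i in range(256)], dtype=torch.uint8)

def xnor_popcount_dot(a_packed, b_packed, n):
    # a_packed, b_packed: uint8 tensors of shape (..., n // 8); bit 1 encodes +1, bit 0 encodes -1
    xnor = (a_packed ^ b_packed).bitwise_not()    # matching bit positions become 1
    matches = _POPCOUNT[xnor.long()].sum(dim=-1)  # number of positions that agree
    return 2 * matches - n                        # agreements minus disagreements

a = torch.randint(0, 256, (16,), dtype=torch.uint8)   # 128 packed signs
print(xnor_popcount_dot(a, a, 128))                   # tensor(128): identical vectors
```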
Arguments that can be helpful for packbits/unpackbits:
I guess on CPU packbits/unpackbits can be implemented with those compress/expand SIMD instructions, if the op is performed across a contiguous dimension (actually contiguity doesn't matter once a load into a vector register is done).
In my code I'd do something like:
I think one can think about this feature request as surfacing compress/expand SIMD functionality to user land and reimplementing it on GPU.
pack and unpack seem worth doing. The other parts (i.e. compress/expand) could be useful, but I'm not sure it's worth doing -- it seems like at that point you'd be writing specialized ops in C++ anyway.
I thought that, given a mask, pack/unpack are precisely equivalent to SIMD compress/expand? Aren't they?
We haven't looked! You're probably right :)
General int4 support request in pytorch/pytorch#33859 is also related.
Another useful feature would be scatter/gather-like functionality for compressing index tensors. In a practical usecase it can help to compress the hybrid sparse+dense tensor indices by a lot: https://discuss.pytorch.org/t/sparse-torch-topk/71832/4
It seems that <8-bit quantization is starting to appear: pytorch/pytorch#34783 seems somewhat related to this discussion.
cc @jspark1105
related: pytorch/pytorch#36380
I renamed the issue to enlarge the scope a little bit :) BitTensors could be very helpful for binary neural networks. Even if only a few operators on them are supported (such as packbits/unpackbits, binary operations, popcount), they are already useful for reducing memory footprint, e.g. for storing masks instead of full inputs when that is sufficient for backward ops. E.g. in pytorch/pytorch#41034, if a mask is stored, the silu/swish operation would become bijective if an additional bit is stored to represent the direction (half-space) away from the function minimum.
I made a draft (https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a):

```python
import math
import torch

def tensor_dim_slice(tensor, dim, s):
    # apply slice s along dimension dim
    return tensor[(slice(None),) * (dim if dim >= 0 else dim + tensor.dim()) + (s, )]

def packshape(shape, dim, mask, dtype):
    nbits_element = torch.iinfo(dtype).bits
    nbits = 1 if mask == 0b00000001 else 2 if mask == 0b00000011 else 4 if mask == 0b00001111 else 8 if mask == 0b11111111 else None
    assert nbits is not None and nbits <= nbits_element and nbits_element % nbits == 0
    packed_size = nbits_element // nbits
    shape = list(shape)
    shape[dim] = int(math.ceil(shape[dim] / packed_size))
    return shape, packed_size, nbits

def packbits(tensor, dim = -1, mask = 0b00000001, out = None, dtype = torch.uint8):
    shape, packed_size, nbits = packshape(tensor.shape, dim = dim, mask = mask, dtype = dtype)
    out = out.zero_() if out is not None else torch.zeros(shape, device = tensor.device, dtype = dtype)
    assert tuple(out.shape) == tuple(shape)
    for e in range(packed_size):
        sliced_input = tensor_dim_slice(tensor, dim, slice(e, None, packed_size))
        # cast to the container dtype before shifting so the in-place OR below does not hit type-promotion errors
        compress = sliced_input.to(out.dtype) << (nbits * (packed_size - e - 1))
        sliced_output = out.narrow(dim, 0, sliced_input.shape[dim])
        sliced_output |= compress
    return out

def unpackbits(tensor, shape, dim = -1, mask = 0b00000001, out = None, dtype = torch.uint8):
    _, packed_size, nbits = packshape(shape, dim = dim, mask = mask, dtype = tensor.dtype)
    out = out.zero_() if out is not None else torch.zeros(shape, device = tensor.device, dtype = dtype)
    assert tuple(out.shape) == tuple(shape)
    for e in range(packed_size):
        sliced_output = tensor_dim_slice(out, dim, slice(e, None, packed_size))
        expand = (tensor >> (nbits * (packed_size - e - 1))) & ((1 << nbits) - 1)
        sliced_input = expand.narrow(dim, 0, sliced_output.shape[dim])
        sliced_output.copy_(sliced_input)
    return out

if __name__ == '__main__':
    shape = (10, 17)
    K = 10
    for nbits in [1, 2, 4, 8]:
        mask = (1 << nbits) - 1
        for dtype in [torch.uint8, torch.int32, torch.int64]:
            for k in range(K):
                x = torch.randint(0, 1 << nbits, shape, dtype = dtype)
                y = packbits(x, mask = mask)
                z = unpackbits(y, mask = mask, dtype = x.dtype, shape = x.shape)
                assert torch.allclose(x, z)
```
Discussion about including tensor_slice in core is at https://discuss.pytorch.org/t/use-python-like-slice-indexing-across-a-given-dimension/89606/7
Any advice on how to fuse this properly, and for CUDA? Should torch.jit.script just work?
An efficient version for padded tensors, where the relevant dim is a multiple of 8 for BoolTensors and a multiple of 4 for 2-bit-valued tensors:

```python
def packbits_padded(tensor, dim = -1, mask = 0b1, out = None, dtype = torch.uint8):
    dim = dim if dim >= 0 else dim + tensor.dim()
    nbits_element, nbits = 8, (1 if mask == 0b00000001 else 2 if mask == 0b00000011 else 4 if mask == 0b00001111 else 8 if mask == 0b11111111 else None)
    nibbles = nbits_element // nbits
    assert tensor.shape[dim] % nibbles == 0
    out = out if out is not None else torch.empty(*tensor.shape[:dim], tensor.shape[dim] // nibbles, *tensor.shape[1 + dim:], dtype = dtype, device = tensor.device)
    shift = torch.arange(nbits_element - nbits, -1, -nbits, dtype = torch.uint8, device = tensor.device)
    shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
    torch.sum(tensor.view(*tensor.shape[:dim], -1, nibbles, *tensor.shape[1 + dim:]) << shift, dim = 1 + dim, out = out)
    return out

def unpackbits_padded(tensor, dim = -1, mask = 0b1, out = None):
    dim = dim if dim >= 0 else dim + tensor.dim()
    nbits_element, nbits = 8, (1 if mask == 0b00000001 else 2 if mask == 0b00000011 else 4 if mask == 0b00001111 else 8 if mask == 0b11111111 else None)
    nibbles = nbits_element // nbits
    out = out if out is not None else torch.empty(*tensor.shape[:dim], tensor.shape[dim] * nibbles, *tensor.shape[1 + dim:], dtype = torch.uint8, device = tensor.device)
    shift = torch.arange(nbits_element - nbits, -1, -nbits, dtype = torch.uint8, device = tensor.device)
    shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
    torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out = out)
    return out
```

A couple of issues that would bring further speed-up:
If the bitwise ops you need are supported by the fuser, then yes, the JIT fuser should be able to speed some of these things up. I didn't read your code, so I don't know for certain whether this is the case or not.
Is it possible to ask the fuser to do loop unrolling? It could hopefully exploit op reordering / asynchrony / parallelism, and in this case the number of loop iterations is known semi-statically. Practically speaking, I'm using the padded version now, which runs much faster. Some problems encountered with JIT:
^ @suo
(code in https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a) about shapes:
I think packbits/unpackbits + bit ops + some fused kernels that currently produce only BoolTensor would already be useful. Going forward, it might be nice to have everything that currently supports torch.bool also support torch.uint1.

Btw, if we had some generic tiled dtypes, the dispatch system for sub-byte dtypes could also become simpler, as they are naturally tiled (e.g. think of the usecase torch.gt(float_tensor, 4, dtype = torch.bitmap), where a tile of 8 floats could ideally be read at once and processed into a single byte (a tile of 8 bits) on the output). These are probably already dispatched in a similar way for the vectorized kernels on CPU? (as tiled vec types exist)

Also, I must bikeshed-confess that uint1 is a super unintuitive name for the high-level bitmap/bitmask/bittensor concept, so it may be worth naming it something else.

For the subclass, I think it would be good to have auto-upcast to torch.bool, so that existing bool kernels can consume bitmaps right away, hiding a reallocation in eager mode (and ideally optimized away by Inductor for the compiled case?)
torch.bit is in, thanks to @jerryzh168's pytorch/pytorch#117208 :)
I might be missing the actual changes, but as I understand it, pytorch/pytorch#117208 only declares the existence of uint1 etc. at a high level. Am I misunderstanding this?
BTW, here is a PyTorch (not native CUDA) implementation of arbitrary k-bit bitpacking (into int64s) in the context of quantization:
Yeah, none of the actual implementation exists yet. Also, the implementation is going to be all in Python, so to get low-overhead performance you are going to have to use torch.compile eventually.
Would we really expect torch.compile to be able to perform efficient bitpacking? Single-bit accesses seem like something that ideally should be optimized at a very low level (eventually potentially at the instruction level). At the end of the day, a lot depends on the specifics of memory accesses, but for any low-bit case we need to access at least 8 bits to read a single bit, and memory-address-wise the data type should probably not exist outside of 1+ dimensional tensors. I believe it requires significant consideration to support it properly (and hardware considerations should be taken into account if it becomes part of a stable release). Just to clarify, the plan is still that we would be storing 8 bits in 1 byte (in contrast to the current 1 bit per byte for bool)?
Yes.
So, at least in the short term, the way I would expect to implement these operators is in terms of regular bitwise operations on the bit tensor reinterpreted as a uint8_t (or larger) tensor. The planned implementation for these sub-byte tensors doesn't really allow for non-aligned accesses anyway (see https://github.com/albanD/subclass_zoo/blob/main/uint4_tensor.py for a PoC).
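A minimal sketch of that approach, assuming 8 booleans are packed per byte: elementwise logical ops on the underlying bits are just the corresponding bitwise ops on the uint8 storage, with no unpacking needed.

```python
import torch

a = torch.randint(0, 256, (1024,), dtype=torch.uint8)   # 8192 packed booleans
b = torch.randint(0, 256, (1024,), dtype=torch.uint8)

both    = a & b      # logical_and of all 8192 underlying bits
either  = a | b      # logical_or
neither = ~(a | b)   # bitwise_not flips every bit
```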
Also, it could be good to support many ops relevant for torch.bit tensors (bit ops / pack ops etc.) on all tensors as well, using reinterpret, but encourage/advertise their use only in torch.bit examples (e.g. in pytorch/pytorch#105465 it might make sense to support bit flips on float tensors as well, as this could be used for flipping the sign or setting the LSB to some external value). At the end of the day, IMHO, what matters is the existence of these ops with nice, clear, specific naming and description. So maybe forcing a subclass or reinterpret could even be avoided, although the op examples could promote the use of torch.bit for bitset/bitmap/compressed-BoolTensor usecases. I think sub-byte indexing helpers per se are not the most pressing ops anyway :)
I guess it would be good to have benchmarks in the tests comparing, at least initially, these Python-compiled bit ops with some manual CUDA loops, to be confident in them, as unrolling the per-bit loop iterations into a single memory store per byte (or even per several bytes) might be a non-trivial optimization. Maybe for this it would be good to have wide dtypes in Python as well (e.g. something like float32x8: read an 8-tuple of floats from float32tensor.view(torch.float32x8), then store a single byte of produced output as a single store). Or maybe even wider dtypes as outputs (and inputs) would be a good abstraction to help torch.compile produce effective code for these kinds of patterns (read, compute/pack/reduce, write).
Regarding sub-byte access, I would propose to treat this by introducing some variants of getbit/setbit/togglebit ops available for any dtype: pytorch/pytorch#105465. At the least this would unblock things and make access more elegant and error-free.
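For illustration, here is roughly what such helpers could look like on a flat, contiguous uint8-packed tensor (MSB-first within a byte). The names mirror the proposal, but everything below is a hypothetical sketch, not the API from pytorch/pytorch#105465:

```python
import torch

def getbit(packed, i):
    # read bit i (counting MSB-first from the start of the flat uint8 storage)
    return (packed.view(-1)[i // 8] >> (7 - i % 8)) & 1

def setbit_(packed, i, value):
    flat = packed.view(-1)
    mask = torch.tensor(1 << (7 - i % 8), dtype=torch.uint8)
    flat[i // 8] = (flat[i // 8] & ~mask) | (mask if value else 0)
    return packed

def togglebit_(packed, i):
    packed.view(-1)[i // 8] ^= 1 << (7 - i % 8)
    return packed
```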
To chime in, my use case for this is to efficiently represent occupancy voxel grids, where each voxel is either occupied or not (so 8 bits of info per voxel are not needed). I'm running into memory issues when batching these voxel grids, so better memory efficiency would be great. However, these voxel grids are also frequently read from / written to, so performance is also a concern when considering packing/unpacking.
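As an illustration of what this looks like with the eager draft from earlier in the thread (assuming the `packbits`/`unpackbits` from the gist above are in scope), an occupancy grid packs 8x smaller, at the cost of a pack/unpack pass on each write/read:

```python
import torch

grid = (torch.rand(64, 64, 64) > 0.5).to(torch.uint8)             # 262144 bytes as uint8 (or bool)
packed = packbits(grid, dim=-1, mask=0b1)                          # (64, 64, 8) uint8 -> 32768 bytes
restored = unpackbits(packed, shape=grid.shape, dim=-1, mask=0b1)  # back to one voxel per byte
assert torch.equal(grid, restored)
```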
Chiming in to show support for the feature. Similar use case to LemonPi above me - mask tensors using 8 bits per element chew up more memory than needed. In addition, I'm passing the mask to a CUDA kernel through pybind, and it's a little weird having to cast it as
Somewhat related (although it uses ternary weights) - BitNet: https://arxiv.org/abs/2402.17764
Maybe, given the revival of interest in extremely low-bit quantization methods, it would be nice to also include some eager methods in core (mainly for correct, tested semantics, not for actual speed - speed / Linear op support may come later). I personally am more in favor of adding them first as eager methods working with dense/bits tensors, without adding tensor subclasses. There is a lot of experimentation with quantization methods right now, it's developing very fast, and it's unclear whether any particular method wins. A lot of this experimentation works with very simple Linear module swaps, and for that, having some eager method under the hood and manually passing around a few tensors representing the qparams state is completely fine.
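To make the "simple Linear module swap + eager packing" pattern concrete, here is a hypothetical sketch (my naming; it assumes the eager packbits/unpackbits draft from earlier in this thread, and uses a plain sign + per-row-scale binarization rather than any particular published method):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BitLinear(nn.Module):
    """Stores binarized weights packed 8-per-byte and dequantizes on the fly in forward."""
    def __init__(self, linear):
        super().__init__()
        w = linear.weight.detach()
        self.scale = w.abs().mean(dim=1, keepdim=True)                     # qparam kept alongside the bits
        self.packed = packbits((w > 0).to(torch.uint8), dim=-1, mask=0b1)  # 1 bit per weight
        self.weight_shape = w.shape
        self.bias = linear.bias

    def forward(self, x):
        signs = unpackbits(self.packed, shape=self.weight_shape, dim=-1, mask=0b1).float() * 2 - 1
        return F.linear(x, signs * self.scale, self.bias)

# usage: a simple module swap
# model.fc = BitLinear(model.fc)
```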
Related, on packing/unpacking tuples of floats: #208. It would be nice to have various performant pack/unpack bit utils in core PyTorch (both in eager for semantics experiments + for fusion with Triton codegen).
@vadimkantorov as a baseline we were thinking of writing bit packing in pure PyTorch and codegening efficient kernels for it here: #291 - it's a good baseline and we'll look to see how those kernels could be made faster.
I think it's important to have a manual kernel as a baseline for perf comparison, at least for some cases, to be sure that the generated code is fast enough (or at least to review it / have some tests so that it doesn't evolve into something inefficient as Inductor develops - in terms of memory loads/stores and vectorization). After basic bitpacks, it might be good to be able to produce packed bit outputs directly from some select functions like torch.gt, instead of BoolTensors, or to be able to use them as input to
Maybe also worth supporting the original NumPy API as well: np.packbits/np.unpackbits - which could redirect under the hood to the more general, versatile interface.
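For reference, the NumPy API being mentioned (1-bit only, MSB-first by default):

```python
import numpy as np

a = np.array([[1, 0, 1, 1, 0, 0, 1, 0]], dtype=np.uint8)
packed = np.packbits(a, axis=1)                    # array([[178]], dtype=uint8)
unpacked = np.unpackbits(packed, axis=1, count=8)  # recovers the original row of bits
assert (unpacked == a).all()
```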
I guess for packbits/unpackbits, ao is indeed a better place, but for more general BitTensor support in ops, as a compressed, explicit-pack-call-free alternative to BoolTensor - I'm not sure whether this discussion should stay in ao or go back to core.
@vadimkantorov - I think wherever the discussion happens (here or in core), it becomes easier once we have a prototype.
BitTensor would also be useful for compact storage/encoding of the sign output (pytorch/pytorch#130215) - maybe not absolutely needed for compiled optimizers, but still useful for expressiveness - or a TritTensor could be added to store one of three values: -1, 0, 1.
https://github.com/microsoft/BitNet released code, so having some core support / basic ops support (including pack/unpack and quantize (= torch.gt/lt with bit tensor output instead of bool tensor output) / dequantize) for torch.bit (and ternary tensors) might promote more research into extremely compressed binary nets.
A usecase: storing a full backtracking pointer matrix can be okay for Needleman-Wunsch / CTC alignment if a 2-bit data type is used (4x memory saving compared to a uint8 representation). Currently it's possible to do this with bit-manipulation magic, but probably not very efficiently (store and load require masking and shifting, not fused); see the 2-bit packing sketch after this list.
Another usecase: compressed BoolTensor for binary neural networks
Another usecase: extremely low-bit quantized representations.
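The 2-bit packing sketch referenced in the first usecase, assuming the eager packbits/unpackbits draft from this thread is in scope (mask=0b11 selects 2-bit packing, 4 values per byte):

```python
import torch

# backtracking pointers taking one of 4 values (e.g. diagonal / up / left / stop)
backptr = torch.randint(0, 4, (512, 512), dtype=torch.uint8)
packed = packbits(backptr, dim=-1, mask=0b11)                      # (512, 128) uint8, 4x smaller
restored = unpackbits(packed, shape=backptr.shape, dim=-1, mask=0b11)
assert torch.equal(backptr, restored)
```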
Is something like this already implemented for quantization? Probably a simple version of this feature could be providing some explicit utility functions, like calculating the size of the holder uint8 tensor, and fused store and load functions (potentially explicitly batched, e.g. the actual store is delayed until some aligned number of memory lines has arrived).

In NumPy the related functionality is np.packbits and np.unpackbits, however these are designed to work only with a 1-bit contained type. 2-bit/4-bit would be cool as well.

On the 1-bit side, another related project is RoaringBitmap https://github.com/RoaringBitmap/RoaringBitmap (http://roaringbitmap.org/) - compressed bitsets for set operations.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @izdeby