FEAT : Adding BitNet quantization method to HFQuantizer #33410
Conversation
Force-pushed from ee0d770 to 3848966
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for your work @MekkCyber ! This looks pretty good ! I've left a few comments. Could you also add some tests and update the documentation about this new quantizer + fix CI ?
if not torch.cuda.is_available():
    raise RuntimeError("No GPU found. A GPU is needed for quantization.")
This should also run on CPU, no, since we are just using compile?
Yes it runs on cpu, but it's slow
if device_map is None:
    logger.warning_once(
        "You have loaded an BitNet model on CPU and have a CUDA device available, make sure to set "
        "your model on a GPU device in order to run your model."
    )
to update according to how we fix the above comment
Can you have a look at the BitLinear class, @dacorvo? I would love to have your insights!
Nice job for adding this @MekkCyber ! Just a few nits ! Could you also update the quantization overview and create a page for bitnet ? I would link to the nanotron PR for users who want to fine-tune their model, and link a script to perform the conversion to the right format (quantization + packing). This way, users will be able to load their 1.58-bit models with this quantizer.
packed_tensor_shape = (row_dim, *original_shape[1:])

packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
unpacked = quantized_weights.to(torch.uint8)
nit: for clarity I would have put this line immediately under line 46 as they are both related to the conversion from [-1, 0, 1]/int8 to [0, 1, 2]/uint8.
I did the opposite: I put the line where I add +1 immediately before this one, because `packed_tensor_shape` is not defined before.
packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
unpacked = quantized_weights.to(torch.uint8)

def lshift(t: torch.Tensor, bits: int):
This code reminds me of something ... ;-).
https://github.com/huggingface/optimum-quanto/blob/f62c887731cfc4800f930ba55c3da0262f10f84e/optimum/quanto/tensor/qbits/packed.py#L24
If you don't plan to support the MPS device, then you can inline the `<<` operation in the loop.
Yes it's based on quanto 😅
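For readers landing here, a minimal sketch of the quanto-style packing scheme this thread discusses (a rough reconstruction under stated assumptions, not the exact PR code; `pack_weights`, `lshift`, and `VALUES_PER_ITEM` are illustrative names):

```python
import torch

VALUES_PER_ITEM = 4  # four 2-bit values fit in one uint8 element


def lshift(t: torch.Tensor, bits: int) -> torch.Tensor:
    # Multiplying by a power of two instead of using `<<` keeps MPS compatibility, as in quanto;
    # if MPS is not a target, `t << bits` can be inlined directly in the loop below.
    return t * (2**bits)


def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
    # quantized_weights holds ternary int8 values in [-1, 0, 1]
    original_shape = quantized_weights.shape
    row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM
    packed_tensor_shape = (row_dim, *original_shape[1:])

    packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
    # shift [-1, 0, 1] to [0, 1, 2] so each value fits in two unsigned bits
    unpacked = (quantized_weights + 1).to(torch.uint8)

    for i in range(VALUES_PER_ITEM):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        packed[: end - start] |= lshift(unpacked[start:end], 2 * i)

    return packed
```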
self.register_buffer(
    "weight",
    torch.zeros(
        (out_features // 4, in_features),
Consider defining a constant for the number of values per item and use it in the whole file.
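Something along these lines, as a rough illustration of the suggestion (the skeleton class is hypothetical; the PR later adopts `VALUES_PER_ITEM` for this constant):

```python
import torch
import torch.nn as nn

VALUES_PER_ITEM = 4  # number of 2-bit ternary values packed into one uint8 element


class BitLinearSkeleton(nn.Module):
    # Illustrative skeleton only: the constant replaces the hard-coded `4` in the buffer shape.
    def __init__(self, in_features: int, out_features: int, device=None):
        super().__init__()
        self.register_buffer(
            "weight",
            torch.zeros((out_features // VALUES_PER_ITEM, in_features), dtype=torch.uint8, device=device),
        )
```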
        device=device,
    ),
)
self.register_buffer(
Quantizing per-tensor (i.e. using a single scale value for the whole tensor) at that level of quantization will greatly reduce the precision. Are the quantized weights supposed to be fine-tuned ?
Yes, BitNet is a QAT method, not PTQ, so the model is fine-tuned with fake quantization layers.
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
Note: this does not produce ternary outputs for bits=2 (but outputs in [-2, -1, 0, 1]).
Yes this is only for activations, as they are quantized in the range [-128, 127] for b = 8
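Concretely, the symmetric range implied by these two lines (a quick illustration, not PR code):

```python
# num_bits = 8 -> Qn = -128, Qp = 127  (the activation range mentioned above)
# num_bits = 2 -> Qn = -2,   Qp = 1    (hence {-2, -1, 0, 1}, not ternary)
for num_bits in (8, 2):
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    print(f"{num_bits} bits: [{Qn}, {Qp}]")
```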
@torch.compile
def activation_quant(self, x, num_bits=8):
    """
    Activation function : Performs symmetric, per-channel quantization on the input activations.
per-channel is a bit misleading here: by convention, 'channel' usually denotes the last dimension, so one would expect the quantized output to have one scale per 'channel', i.e. per slice along the last dimension, whereas here you obtain one scale for each slice along the first dimension (typically the tokens in language models).
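To make the distinction concrete, a hedged sketch of per-token (rather than per-channel) symmetric activation quantization; the epsilon clamp value and helper name are assumptions for illustration, not taken from the PR:

```python
import torch


def activation_quant_sketch(x: torch.Tensor, num_bits: int = 8):
    # One scale per slice along the first dimension (per token for a [tokens, hidden] input),
    # not one scale per channel along the last dimension.
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    scale = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_q = (x * scale).round().clamp(Qn, Qp)
    return x_q, scale


x = torch.randn(3, 8)            # 3 tokens, hidden size 8
x_q, scale = activation_quant_sketch(x)
print(scale.shape)               # torch.Size([3, 1]) -> one scale per token
```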
out = input / si
out = out / sw
This could be done more efficiently as:
out = input / (si * sw)
w_quant = unpack_weights(w, dtype=self.dtype)
x_quant, x_scale = self.activation_quant(x)
y = F.linear(x_quant.to(self.dtype), w_quant)
y = self.post_quant_process(y, self.weight_scale, x_scale)
This is prone to overflows in the matmul accumulator: when unpacking, you basically recreate a tensor of ternary values expressed in full precision, so not only does it occupy a large chunk of device memory, but it is also expressed in a range ([-1, 1]) that is usually larger than that of the original float16 weights.
You should therefore apply the weight scale immediately to come back into the expected computation range. Same thing for the activations.
In other words: apply the scales before the computation, not afterwards.
Yes, I agree, but for now we are doing the multiplication in bfloat16 or float32, which have a sufficient range not to overflow. The idea after that is to integrate a kernel for int8 × int2 multiplication, which is why we don't apply the scales directly after the unpacking.
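For clarity, a hedged sketch of the two orderings being discussed (function names are illustrative; both divide by the scales, since dequantization in this PR divides the output by the activation and weight scales):

```python
import torch
import torch.nn.functional as F


def linear_scales_after(x_q, x_scale, w_q, w_scale, dtype=torch.bfloat16):
    # Current approach: matmul on the raw quantized values, rescale afterwards.
    # Fine with bfloat16/float32 accumulation, and keeps the door open for an int8 x int2 kernel.
    y = F.linear(x_q.to(dtype), w_q.to(dtype))
    return y / (x_scale * w_scale)


def linear_scales_before(x_q, x_scale, w_q, w_scale, dtype=torch.bfloat16):
    # Reviewer's suggestion: bring operands back into their float range before the matmul,
    # which avoids accumulator overflow with lower-precision accumulation.
    return F.linear(x_q.to(dtype) / x_scale, w_q.to(dtype) / w_scale)
```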
def validate_environment(self, *args, **kwargs):
    if not is_accelerate_available():
        raise ImportError("Loading an BitNet quantized model requires accelerate (`pip install accelerate`)")
raise ImportError("Loading an BitNet quantized model requires accelerate (`pip install accelerate`)") | |
raise ImportError("Loading a BitNet quantized model requires accelerate (`pip install accelerate`)") |
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
    raise ValueError(
        "Converting into 8-bit weights from tf/flax weights is currently not supported, please make"
Ternary ?
Thanks for adding this new quantizer ! Excited to see how this quantizer will evolve in the future ! Just a few nits.
X_{dequantized} = X_q * scale_x
$$

To learn more about how we trained and fine-tuned BitNet models, check out the blogpost [here](https://)
To update when the blogpost is released
""" | ||
Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value. | ||
|
||
Parameters: | ||
----------- | ||
packed : torch.Tensor | ||
A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value). | ||
dtype : torch.dtype | ||
The dtype of the returned Tensor | ||
Returns: | ||
-------- | ||
torch.Tensor | ||
A tensor of unpacked weights, where each value is converted from its packed 2-bit representation. | ||
|
||
Example: | ||
-------- | ||
packed = torch.tensor([[0b10100001, 0b00011000], | ||
[0b10010000, 0b00001010]], dtype=torch.uint8) | ||
|
||
# Unpack the values | ||
unpacked = unpack_weights(packed) | ||
|
||
# Resulting unpacked tensor | ||
print(unpacked) | ||
# Output: tensor([[ 0, -1], | ||
[-1, 1], | ||
[-1, 1], | ||
[-1, 1], | ||
[ 1, 0], | ||
[ 0, -1], | ||
[ 1, -1], |
Nice ! Thanks for adding !
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
    logger.warning_once(
        "You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
        "This will degrade the inference speed because of weight unpacking"
    )
Does it really work with cpu or disk offload? Since BitNetLinear is essentially composed of buffers, make sure to use `with init_empty_weights(include_buffers=True)` in `_replace_with_bitnet_linear`, plus `device_map_kwargs["offload_buffers"] = True` in `modeling_utils.py` (check the fbgemm-fp8 code); otherwise, the buffers won't be offloaded correctly.
We can do that in a follow-up PR if needed and just raise an error here for now.
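As a rough, hypothetical illustration of the `include_buffers=True` point (`DummyBitLinear` is a stand-in for the PR's BitLinear, not its actual code):

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights


class DummyBitLinear(nn.Module):
    # Stand-in for a buffer-only module like BitLinear: the packed weight is a buffer, not a Parameter.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.zeros(out_features // 4, in_features, dtype=torch.uint8))


# Without include_buffers=True, buffers are materialized immediately and are not handled
# like parameters during offload; with it, they stay on the meta device until dispatch.
with init_empty_weights(include_buffers=True):
    layer = DummyBitLinear(16, 16)

print(layer.weight.device)  # meta
# The other half of the suggestion is setting device_map_kwargs["offload_buffers"] = True
# when dispatching the model in modeling_utils.py (see the fbgemm-fp8 integration).
```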
original_row_dim = packed_shape[0] * VALUES_PER_ITEM
unpacked_shape = (original_row_dim, *packed_shape[1:])

unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
Wouldn't it make more sense to have `dtype=dtype` here and then change line 136 to:
`unpacked[start:end] = torch.tensor(((packed & mask) >> (2 * i)) - 1, dtype=dtype)`
This way you could generalize this to other types than int8 (and possibly avoid a copy).
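To tie the two suggestions together, a hedged sketch of an unpacking helper that allocates directly in the target dtype (a reconstruction for illustration only, with explicit parentheses around the shift since `-` binds tighter than `>>` in Python):

```python
import torch

VALUES_PER_ITEM = 4  # four 2-bit values per uint8 element


def unpack_weights(packed: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Recover 4 ternary values from every uint8 element, writing straight into `dtype`
    # so no int8 intermediate (and extra copy) is needed.
    packed_shape = packed.shape
    original_row_dim = packed_shape[0] * VALUES_PER_ITEM
    unpacked = torch.zeros((original_row_dim, *packed_shape[1:]), device=packed.device, dtype=dtype)

    for i in range(VALUES_PER_ITEM):
        mask = 3 << (2 * i)                                # select the i-th pair of bits
        start, end = i * packed_shape[0], (i + 1) * packed_shape[0]
        # shift back to [0, 1, 2], then subtract 1 to return to the ternary range [-1, 0, 1]
        unpacked[start:end] = ((packed & mask) >> (2 * i)).to(dtype) - 1

    return unpacked
```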
🚀 let's make sure we link the resources for training: a gist, or a link to a repo!
Thanks for iterating ! Can you fix the merge conflicts ? Also, in a previous PR, we changed the `is_serializable` method: it's not a property anymore, so you need to change this here as well. Thanks !
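A minimal sketch of the shape of that change (standalone and hypothetical; the exact signature should follow the updated `HfQuantizer` base class):

```python
class BitNetQuantizerSketch:
    # Previously exposed as a property (`quantizer.is_serializable`); after the refactor,
    # it is a plain method that callers invoke (`quantizer.is_serializable(...)`).
    def is_serializable(self, safe_serialization=None):
        return True
```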
IMO missing one small test and good to go!
@slow
@require_torch_gpu
@require_accelerate
class BitNetTest(unittest.TestCase):
let's test serialization as well! 🤗
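For example, something along these lines (a sketch only: the checkpoint id is a placeholder and the exact assertions should mirror the other quantizer test suites):

```python
import tempfile
import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class BitNetSerializationSketch(unittest.TestCase):
    def test_save_and_reload(self):
        model_id = "<pre-quantized-bitnet-checkpoint>"  # placeholder, not an actual repo id
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cuda")

            inputs = tokenizer("Hello", return_tensors="pt").to(reloaded.device)
            with torch.no_grad():
                logits = reloaded(**inputs).logits
            self.assertTrue(torch.isfinite(logits).all())
```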
Thanks for iterating ! Merging !
…33410)

* rebasing changes
* fixing style
* adding some doc to functions
* remove bitblas
* change dtype
* fixing check_code_quality
* fixing import order
* adding doc to tree
* Small update on BitLinear
* adding some tests
* sorting imports
* small update
* reformatting
* reformatting
* reformatting with ruff
* adding assert
* changes after review
* update disk offloading
* adapting after review
* Update after review
* add is_serializable back
* fixing style
* adding serialization test
* make style
* small updates after review
What does this PR do?
This pull request introduces a new quantization method: BitNet quantization at 1.58 bits. It enables users to load and utilize quantized & packed models with ternary weights directly in Transformers, providing out-of-the-box inference.
Who can review?
@SunMarc