
FEAT : Adding BitNet quantization method to HFQuantizer #33410

Merged: 27 commits into huggingface:main, Oct 9, 2024

Conversation

MekkCyber (Contributor)

What does this PR do?

This pull request introduces a new quantization method: BitNet quantization at 1.58 bits. It enables users to load and utilize quantized & packed models with ternary weights directly in Transformers, providing out-of-the-box inference.
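
For illustration, a minimal sketch of what loading such a pre-quantized checkpoint could look like (the repo id below is a placeholder; since the quantization config ships with the checkpoint, no extra arguments should be needed):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "my-org/llama3-8b-bitnet-1.58"  # placeholder id for a quantized + packed checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```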

Who can review?

@SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment

Thanks for your work @MekkCyber! This looks pretty good! I've left a few comments. Could you also add some tests, update the documentation for this new quantizer, and fix the CI?

Comment on lines 61 to 62
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
Member

This should also run on CPU, no, since we are just using compile?

Contributor Author

Yes, it runs on CPU, but it's slow.

Comment on lines 65 to 69
if device_map is None:
logger.warning_once(
"You have loaded an BitNet model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
Member

To update according to how we resolve the comment above.

@SunMarc requested a review from dacorvo September 10, 2024 16:50
@SunMarc (Member) commented Sep 10, 2024

Can you have a look at the BitLinear class, @dacorvo? I would love to have your insights!

SunMarc previously approved these changes Sep 12, 2024

@SunMarc (Member) left a comment

Nice job adding this @MekkCyber! Just a few nits! Could you also update the quantization overview and create a page for BitNet? I would link to the nanotron PR for users who want to fine-tune their model, and link a script to perform the conversion to the right format (quantization + packing). This way, users will be able to load their 1.58-bit models with this quantizer.

@SunMarc dismissed their stale review September 12, 2024 16:11

Needs to add docs

packed_tensor_shape = (row_dim, *original_shape[1:])

packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
unpacked = quantized_weights.to(torch.uint8)
Contributor

nit: for clarity I would have put this line immediately under line 46 as they are both related to the conversion from [-1, 0, 1]/int8 to [0, 1, 2]/uint8.

Contributor Author

I did the opposite: I put the line where I add +1 immediately before this one, because packed_tensor_shape is not defined earlier.

packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
unpacked = quantized_weights.to(torch.uint8)

def lshift(t: torch.Tensor, bits: int):
Contributor

This code reminds me of something ... ;-).
https://github.com/huggingface/optimum-quanto/blob/f62c887731cfc4800f930ba55c3da0262f10f84e/optimum/quanto/tensor/qbits/packed.py#L24
If you don't plan to support the MPS device, then you can inline the << operation in the loop.

Contributor Author

Yes it's based on quanto 😅
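
For reference, a minimal sketch of the packing scheme being discussed, with four 2-bit values per uint8; the helper names and the exact slice ordering are illustrative, not necessarily the PR's:

```python
import torch

VALUES_PER_ITEM = 4  # four 2-bit values fit into one uint8

def pack_ternary(quantized_weights: torch.Tensor) -> torch.Tensor:
    """Pack int8 ternary weights in {-1, 0, 1} into uint8, four values per byte along dim 0."""
    # Shift {-1, 0, 1} -> {0, 1, 2} so every value fits into 2 unsigned bits
    unpacked = (quantized_weights + 1).to(torch.uint8)
    rows = unpacked.shape[0] // VALUES_PER_ITEM
    packed = torch.zeros((rows, *unpacked.shape[1:]), dtype=torch.uint8, device=unpacked.device)
    for i in range(VALUES_PER_ITEM):
        packed |= unpacked[i * rows : (i + 1) * rows] << (2 * i)
    return packed

def unpack_ternary(packed: torch.Tensor, dtype: torch.dtype = torch.int8) -> torch.Tensor:
    """Inverse of pack_ternary: recover the {-1, 0, 1} values."""
    rows = packed.shape[0]
    unpacked = torch.zeros((rows * VALUES_PER_ITEM, *packed.shape[1:]), dtype=dtype, device=packed.device)
    for i in range(VALUES_PER_ITEM):
        unpacked[i * rows : (i + 1) * rows] = ((packed >> (2 * i)) & 0b11).to(dtype) - 1
    return unpacked

# Round-trip check
w = torch.randint(-1, 2, (8, 3), dtype=torch.int8)
assert torch.equal(unpack_ternary(pack_ternary(w)), w)
```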

self.register_buffer(
"weight",
torch.zeros(
(out_features // 4, in_features),
Contributor

Consider defining a constant for the number of values per item and using it throughout the file.

device=device,
),
)
self.register_buffer(
Contributor

Quantizing per-tensor (i.e. using a single scale value for the whole tensor) at that level of quantization will greatly reduce precision. Are the quantized weights supposed to be fine-tuned?

Contributor Author

Yes, BitNet is a QAT method, not PTQ, so the model is fine-tuned with fake-quantization layers.
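
For context, a rough sketch of the fake-quantization idea used in this kind of QAT (BitNet b1.58-style absmean ternarization with a straight-through estimator); this illustrates the general recipe, not the PR's training code:

```python
import torch

def weight_fake_quant(w: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Quantize weights to scaled ternary levels in the forward pass, identity in the backward pass."""
    scale = 1.0 / w.abs().mean().clamp(min=eps)       # per-tensor absmean scale
    w_q = (w * scale).round().clamp(-1, 1) / scale    # quantize-dequantize ("fake" quantization)
    return w + (w_q - w).detach()                     # straight-through estimator for gradients
```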

Comment on lines +185 to +186
Qn = -(2 ** (num_bits - 1))
Qp = 2 ** (num_bits - 1) - 1
Contributor

Note: this does not produce ternary outputs for num_bits=2 (it outputs values in [-2, -1, 0, 1]).

Contributor Author

Yes, this is only for activations, as they are quantized in the range [-128, 127] for num_bits = 8.
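
A quick sanity check of that range formula (matching the two lines quoted above):

```python
def signed_range(num_bits: int):
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    return Qn, Qp

print(signed_range(8))  # (-128, 127): the activation range mentioned above
print(signed_range(2))  # (-2, 1): four levels, not ternary, hence the note above
```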

@torch.compile
def activation_quant(self, x, num_bits=8):
"""
Activation function : Performs symmetric, per-channel quantization on the input activations.
Contributor

per-channel is a bit misleading here: by convention, 'channel' usually denotes the last dimension, so one would expect the quantized output to have one scale per channel, i.e. per slice along the last dimension, whereas here you obtain one scale for each slice along the first dimension (typically the tokens in language models).
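
To make the distinction concrete, a minimal sketch of per-row (per-token) symmetric quantization, i.e. one scale per slice along the first dimension rather than per channel; the function name and eps value are illustrative:

```python
import torch

def activation_quant_per_row(x: torch.Tensor, num_bits: int = 8):
    """Symmetric int8-style quantization with one scale per row (e.g. per token)."""
    Qp = 2 ** (num_bits - 1) - 1
    # Reduce over the last (channel) dimension -> one scale per first-dim slice
    scale = Qp / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    x_q = (x * scale).round().clamp(-Qp - 1, Qp)
    return x_q, scale
```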

Comment on lines 193 to 194
out = input / si
out = out / sw
Contributor

This could be done more efficiently as:

out = input / (si * sw)

Comment on lines 199 to 202
w_quant = unpack_weights(w, dtype=self.dtype)
x_quant, x_scale = self.activation_quant(x)
y = F.linear(x_quant.to(self.dtype), w_quant)
y = self.post_quant_process(y, self.weight_scale, x_scale)
Contributor

This is prone to overflows in the matmul accumulator: when unpacking you basically recreate a tensor of ternary values expressed in full precision, so not only does it occupy a large chunk of device memory, but it is also expressed in a range ([-1, 1]) that is usually larger than that of the original float16 weights.
You should therefore apply the weight scale immediately to come back into the expected computation range. Same thing for the activations.

In other words: apply the scales before the computation, and not afterwards.

Contributor Author

Yes, I agree, but for now we are doing the multiplication in bfloat16 or float32, which have a sufficient range not to overflow. The plan afterwards is to integrate a kernel for int8 x int2 multiplication, which is why we don't apply the scales directly after the unpacking.
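
For reference, a small sketch contrasting the two orderings discussed here, with quantized tensors and their scales passed in directly; the division-by-scale convention follows the post_quant_process lines quoted above, and the function names are illustrative:

```python
import torch
import torch.nn.functional as F

def matmul_scales_first(x_q, x_scale, w_q, w_scale, dtype=torch.bfloat16):
    """Reviewer's suggestion: dequantize both operands before the matmul."""
    x_deq = x_q.to(dtype) / x_scale   # back into the original activation range
    w_deq = w_q.to(dtype) / w_scale   # back into the original weight range
    return F.linear(x_deq, w_deq)     # accumulator only sees float-range values

def matmul_scales_last(x_q, x_scale, w_q, w_scale, dtype=torch.bfloat16):
    """PR's current ordering: matmul on quantized values, rescale afterwards."""
    y = F.linear(x_q.to(dtype), w_q.to(dtype))
    return y / (x_scale * w_scale)
```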


def validate_environment(self, *args, **kwargs):
if not is_accelerate_available():
raise ImportError("Loading an BitNet quantized model requires accelerate (`pip install accelerate`)")
Contributor

Suggested change
raise ImportError("Loading an BitNet quantized model requires accelerate (`pip install accelerate`)")
raise ImportError("Loading a BitNet quantized model requires accelerate (`pip install accelerate`)")


if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Converting into 8-bit weights from tf/flax weights is currently not supported, please make"
Contributor

Ternary?

@SunMarc (Member) left a comment

Thanks for adding this new quantizer! Excited to see how it will evolve in the future! Just a few nits.

X_{dequantized} = X_q * scale_x
$$

To learn more about how we trained, and fine-tuned bitnet models checkout the blogpost [here](https://)
Member

To update when the blogpost is released

Comment on lines +74 to +104
"""
Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.

Parameters:
-----------
packed : torch.Tensor
A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
dtype : torch.dtype
The dtype of the returned Tensor
Returns:
--------
torch.Tensor
A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.

Example:
--------
packed = torch.tensor([[0b10100001, 0b00011000],
[0b10010000, 0b00001010]], dtype=torch.uint8)

# Unpack the values
unpacked = unpack_weights(packed)

# Resulting unpacked tensor
print(unpacked)
# Output: tensor([[ 0, -1],
[-1, 1],
[-1, 1],
[-1, 1],
[ 1, 0],
[ 0, -1],
[ 1, -1],
Member

Nice! Thanks for adding!

Comment on lines 70 to 74
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
logger.warning_once(
"You are attempting to load a BitNet model with a device_map that contains a CPU or disk device."
"This will degrade the inference speed because of weight unpacking"
)
Member

Does it really work with CPU or disk offload? Since BitNetLinear is essentially composed of buffers, make sure to use with init_empty_weights(include_buffers=True) in _replace_with_bitnet_linear and device_map_kwargs["offload_buffers"] = True in modeling_utils.py. Check the fbgemm-fp8 code. Otherwise, the buffers won't be offloaded correctly.
We can do that in a follow-up PR if needed and just raise an error here for now.
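
A small, runnable illustration of the first half of that suggestion, using a plain nn.Linear as a stand-in for the PR's BitLinear; the offload_buffers part is only shown as a comment since it lives inside modeling_utils.py:

```python
import torch
import torch.nn as nn
from accelerate import init_empty_weights

# With include_buffers=True, buffers registered inside the context also land on
# the meta device, so they are not materialized before dispatch/offload.
with init_empty_weights(include_buffers=True):
    layer = nn.Linear(4096, 4096, bias=False)              # stand-in for the PR's BitLinear
    layer.register_buffer("weight_scale", torch.zeros(1))

print(layer.weight.device)        # meta
print(layer.weight_scale.device)  # meta, thanks to include_buffers=True

# On the dispatch side, the loading code would also need to set
# device_map_kwargs["offload_buffers"] = True (as the fbgemm-fp8 integration does).
```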

original_row_dim = packed_shape[0] * VALUES_PER_ITEM
unpacked_shape = (original_row_dim, *packed_shape[1:])

unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)
@tengomucho (Contributor) commented Sep 19, 2024

Wouldn't it make more sense to have dtype=dtype here and then change line 136:

        unpacked[start:end] = torch.tensor((packed & mask) >> (2 * i) - 1, dtype=dtype)

This way you could generalize this to other types than int8 (and possibly avoid a copy).

@ArthurZucker (Collaborator) left a comment

🚀 let's make sure we link the resources for training, a gist or a link to a repo!

@SunMarc (Member) left a comment

Thanks for iterating! Can you fix the merge conflicts? Also, in a previous PR, we changed the is_serializable method: it's no longer a property, so you need to change that here as well. Thanks!
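
For reference, a sketch of what that change could look like in the BitNet quantizer; the safe_serialization argument is assumed from the updated HfQuantizer interface:

```python
# Before: is_serializable exposed as a property
# @property
# def is_serializable(self):
#     return True

# After: a plain method on the quantizer
def is_serializable(self, safe_serialization=None):
    return True
```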

@SunMarc requested a review from ArthurZucker October 1, 2024 14:47
@ArthurZucker (Collaborator) left a comment

IMO missing one small test and good to go!

@slow
@require_torch_gpu
@require_accelerate
class BitNetTest(unittest.TestCase):
Collaborator

let's test serialization as well! 🤗
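
A minimal sketch of what such a test could look like (the checkpoint id is a placeholder; the real test would reuse the class's model id and the decorators above):

```python
import tempfile
import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class BitNetSerializationTest(unittest.TestCase):
    model_id = "my-org/bitnet-test-model"  # placeholder checkpoint id

    def test_save_and_reload(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
        inputs = tokenizer("Hello", return_tensors="pt").to(model.device)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            reloaded = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map="cuda")

        # Greedy decoding is deterministic, so the reloaded model must match exactly
        original = model.generate(**inputs, max_new_tokens=10)
        restored = reloaded.generate(**inputs, max_new_tokens=10)
        self.assertTrue(torch.equal(original, restored))
```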

@SunMarc (Member) left a comment

Thanks for iterating! Merging!

@SunMarc merged commit 36d410d into huggingface:main Oct 9, 2024
24 checks passed
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…33410)

* rebasing changes

* fixing style

* adding some doc to functions

* remove bitblas

* change dtype

* fixing check_code_quality

* fixing import order

* adding doc to tree

* Small update on BitLinear

* adding some tests

* sorting imports

* small update

* reformatting

* reformatting

* reformatting with ruff

* adding assert

* changes after review

* update disk offloading

* adapting after review

* Update after review

* add is_serializable back

* fixing style

* adding serialization test

* make style

* small updates after review