
Key bit width calculation #2

Open
HunterHantao opened this issue Aug 6, 2024 · 3 comments

@HunterHantao

Hi,

Thank you for your great work and for sharing your code.

May I know how to compute the key bit width (e.g., 3-bit in Table 1 of your paper)?

If I want to repeat the experiment in Table 1, is this the correct setting?

python longbench.py --model_name "lmsys/longchat-7b-v1.5-32k" \
    --dtype "float16" \
    --key_quantization_bits 256 \
    --key_quantization_bits_initial_layers 512 \
    --initial_layers_count 15 \
    --outlier_count_general 8 \
    --outlier_count_initial_layers 8 \
    --value_quantization_bits 2 \
    --group_size 32 \
    --buffer_size 128 \
    --seed 42 \
    --dataset_name [dataset_name] \
    --n_data 150

Thank you very much and looking forward to your reply.

@majid-daliri
Collaborator

Hi @HunterHantao,

Thank you for your kind words and interest in our work.

Yes, that is the correct setting for reproducing the experiment in Table 1. Additionally, you can use the model manually by importing it directly from the model directory, without needing the provided code.

Feel free to reach out if you have further questions.

@HunterHantao
Author

HunterHantao commented Aug 9, 2024

I am not sure about the key bit width calculation. Here is my calculation under the aforementioned setting.

For the initial 15 layers, in each head and each layer we use uint8 to store 64 values (the last dimension of the variable key_states_quant) plus 16 outlier values. Since the Llama2-7B head dimension is 128 and ignoring the outlier norms,
the key bit width in the initial layers is
(64 × 8 + 16 × 8) / 128 = 5 bits

For the other layers, this setting uses uint8 to store 32 values, with 32 outlier values as well, so the key bit width is
(32 × 8 + 32 × 8) / 128 = 4 bits

The total average bit width would then be (15 × 5 + 17 × 4) / 32 = 143 / 32 ≈ 4.47 bits, which does not align with the 3-bit setting.
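
For reference, here is the same arithmetic written out in Python (head_dim = 128 and 32 layers are my assumptions for Llama2-7B):

    # My back-of-the-envelope estimate of the average key bit width
    # (this may well be where I am going wrong).
    head_dim = 128        # Llama2-7B head dimension (assumed)
    num_layers = 32       # Llama2-7B layer count (assumed)
    initial_layers = 15   # --initial_layers_count

    # uint8-packed quantized values plus outliers, per head and per layer
    bits_initial = (64 * 8 + 16 * 8) / head_dim   # 5.0 bits
    bits_rest = (32 * 8 + 32 * 8) / head_dim      # 4.0 bits

    avg_bits = (initial_layers * bits_initial
                + (num_layers - initial_layers) * bits_rest) / num_layers
    print(avg_bits)  # 4.46875, i.e. roughly 4.47 bits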

Could you help clarify my confusion?

@majid-daliri
Collaborator

Hi @HunterHantao,

Respectfully, the method of calculation you've suggested is not entirely accurate. We recommend referring to our paper for a detailed explanation of our methodology. The uint8 that you mentioned is merely a practical workaround in PyTorch, as there isn't a data type in PyTorch that uses just one bit of memory. Therefore, we need to manipulate the bits and make use of the uint8 data type.

To clarify, we employ 512-bit accuracy for the first 15 layers and 128-bit accuracy for the remaining layers. On average, this equates to:
(15 * 512 + 17 * 128) / 32 = 9856 / 32 = 308

This roughly translates to:
308 / 128 = 2.406 bits per channel
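
As a quick sanity check, the same arithmetic in Python (assuming Llama2-7B's 32 layers and a head dimension of 128):

    # Average key budget per token, averaged over layers, then per channel.
    initial_layers, remaining_layers = 15, 17   # 32 layers in total
    bits_initial, bits_rest = 512, 128          # bits per key in the two layer groups
    head_dim = 128

    avg_bits_per_key = (initial_layers * bits_initial
                        + remaining_layers * bits_rest) / 32   # 308.0
    bits_per_channel = avg_bits_per_key / head_dim             # 2.40625
    print(avg_bits_per_key, bits_per_channel)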
Please feel free to review the paper for further details.
