
[PHI] add int4 weight only quant kernel, add int4 weight only permute kernel #64094

Merged
merged 27 commits into PaddlePaddle:develop on Jul 24, 2024

Conversation

yinfan98
Contributor

@yinfan98 yinfan98 commented May 7, 2024

PR Category

Inference

PR Types

New features

Description

pcard-71500

Adds an int4 quantization kernel and an int4 quantization permute kernel to Paddle.

TL;DR

Adds a GPU kernel that performs int4 weight-only quantization and supports weight_only_linear.
(It also aligns with the dequantization API; if you just want to try a plain quantize/dequantize round trip, you can run the following code.)

import paddle
x = paddle.randn(shape=[4096, 2048], dtype=paddle.float16)
qt, scale = paddle.nn.quant.weight_quantize(x, algo='weight_only_int4')
## Paddle cannot infer the packed shape yet (a PR for that is in flight).
## Before the view, the shape is [1024, 4096], which is the layout used by weight_only_linear; later an option could be added to decide on the C++ side whether to reshape, depending on whether a matmul follows.
qt = qt.view([2048, 2048])
x_dq = paddle.nn.quant.weight_dequantize(qt, scale, algo='weight_only_int4')

Of course, weight_only_linear is also supported:

import paddle
from paddle.nn.quant import weight_only_linear, weight_quantize

x = paddle.rand(shape=(2, 4096), dtype='float16')

weight = paddle.randn(shape=(4096, 2048), dtype='float32')
weight = weight.astype('float16')

quant_weight, quant_scale = weight_quantize(x=weight, algo='weight_only_int4')
quant_out = weight_only_linear(x=x, weight=quant_weight, weight_scale=quant_scale, weight_dtype="int4")
## This should roughly match out = paddle.matmul(x=x, y=weight); only roughly, since int4 quantization loses a lot of precision.

Summary of int4 weight-only quant

Following the CPU implementation, the kernel for SM70 and above works in several steps:

  1. Pack row-wise (two int4 values packed into one int8).
  2. permute_B_rows_for_mixed_gemm: rearrange the elements along the column direction.
  3. subbyte_transpose: convert the column-major weight to row-major, switching from row-wise packing to column-wise packing.
  4. interleave_column_major_tensor: interleave every 64 elements.
  5. add_bias_and_interleave_int4s_inplace: convert int8 to uint8 (add 8).

But such a complex implementation is unnecessary: we can pack column-wise directly and get the same result with just two kernels (three, counting quantization itself). The approach is as follows:

int4 quantization kernel

For int4 quantization we implemented both row-wise and column-wise packing (so that SM70-era GPUs also work QAQ).
Column-wise packing combines two int4 values into one int8: in the code, two vertically adjacent rows form one int8 value, i.e. the pack runs along the column.
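The column-wise pack can be sketched in NumPy (`pack_int4_column_major` is a hypothetical helper for illustration, not the actual CUDA kernel):

```python
import numpy as np

def pack_int4_column_major(w):
    # Pack two vertically adjacent int4 values (range [-8, 7]) into one int8:
    # row 2k supplies the low nibble, row 2k+1 the high nibble, so each
    # output byte holds a column-direction pair.
    lo = w[0::2] & 0x0F           # low nibbles from even rows
    hi = (w[1::2] & 0x0F) << 4    # high nibbles from odd rows
    return (lo | hi).astype(np.int8)

w = np.array([[1, -2],
              [3,  4]])
packed = pack_int4_column_major(w)   # shape (1, 2): one byte per column pair
```

Each output byte is `(row1_nibble << 4) | row0_nibble`, so the row dimension halves while the column dimension is untouched.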

int4 permute kernel

For int4 quantization, the input data must be rearranged to match cutlass's fast dequantization kernel.
On the int4 dequantization side, inspecting the dequantization operator implementation shows that the required final output is:

0   1   8   9  16  17  24  25   2   3  10  11  18  19  26  27
4   5  12  13  20  21  28  29   6   7  14  15  22  23  30  31

Following cutlass's fast dequantization implementation:
int4 fast dequantization works on groups of four int8 values, converting the int8 data to fp16, but it changes the data layout:

0 2 4 6 1 3 5 7 -> 0 1 2 3 4 5 6 7

From this we can deduce that, before fast dequantization, the data we need is:

//  0   8  16  24   1   9  17  25   2  10  18  26   3  11  19  27
//  4  12  20  28   5  13  21  29   6  14  22  30   7  15  23  31
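This layout can be sanity-checked in Python: modeling the fast-dequant shuffle as a gather with indices [0, 4, 1, 5, 2, 6, 3, 7] per group of eight lanes (a sketch of the reordering only, not the fp16 conversion), applying it to the sequence above reproduces the required final output:

```python
# Fast dequant reads each group of 8 lanes laid out as
# [e0 e2 e4 e6 e1 e3 e5 e7] and emits [e0 .. e7], i.e. it gathers
# from positions [0, 4, 1, 5, 2, 6, 3, 7].
GATHER = [0, 4, 1, 5, 2, 6, 3, 7]

# The pre-dequant order required above:
pre = [0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27,
       4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31]

post = []
for g in range(0, len(pre), 8):
    post += [pre[g + i] for i in GATHER]
# post starts 0 1 8 9 16 17 24 25 ... -- the dequant output layout above
```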

The sequence above looks patternless, but with a few simple bit operations we can adjust it into the following form:

// 0 1 16 17 8 9 24 25 2 3 18 19 10 11 26 27
// 4 5 20 21 12 13 28 29 6 7 22 23 14 15 30 31

Since two int4 values are packed into one int8, we can also rewrite the indices above as int8 indices:

0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15

So the mapping from

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 -> 0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15

is realized by the index permutation

0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15

This new permute_kk (the variable in the code describing the permutation between columns) can be derived from the int8 permute_kk with a small tweak.
Converting the int8 permute into the int4 permute: the int8 permutation

0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15

can be regrouped as

0 2 4 6 1 3 5 7 8 10 12 14 9 11 13 15

then applying % 8 * 2 gives

0 4 8 12 2 6 10 14 0 4 8 12 2 6 10 14
and adding 1 to the second half yields 0 4 8 12 2 6 10 14 1 5 9 13 3 7 11 15, matching the permutation above.
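The derivation above can be written out in a few lines of Python to confirm that the tweak of the int8 permutation reproduces the int4 index sequence (a host-side sketch, not kernel code):

```python
# int8 column permutation used by the existing mixed-gemm layout
perm_int8 = [0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15]

# regroup: alternate blocks of four from the first and second halves
regrouped = perm_int8[0:4] + perm_int8[8:12] + perm_int8[4:8] + perm_int8[12:16]
# regrouped == [0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15]

# apply % 8 * 2, then add 1 to the second half
perm_int4 = [(v % 8) * 2 + (1 if i >= 8 else 0) for i, v in enumerate(regrouped)]
# perm_int4 == [0, 4, 8, 12, 2, 6, 10, 14, 1, 5, 9, 13, 3, 7, 11, 15]
```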

Simple bit-operation kernel (executed last)

// (0 1) (16 17) (8 9) (24 25) (2 3) (18 19) (10 11) (26 27)
// (4 5) (20 21) (12 13) (28 29) (6 7) (22 23) (14 15) (30 31)

//  0   8  16  24   1   9  17  25   2  10  18  26   3  11  19  27
//  4  12  20  28   5  13  21  29   6  14  22  30   7  15  23  31

We take the packed bytes four at a time and swap the low and high nibbles between bytes 0<->2 and 1<->3.
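A Python sketch of that nibble swap, operating on plain ints standing in for the packed bytes (the real kernel does this with device-side bit operations):

```python
def swap_nibbles_group4(buf):
    # Within each group of four packed bytes, exchange nibbles between
    # bytes 0<->2 and 1<->3:
    #   (lo_a | hi_a<<4), (lo_b | hi_b<<4) -> (lo_a | lo_b<<4), (hi_a | hi_b<<4)
    out = list(buf)
    for g in range(0, len(buf), 4):
        for i, j in ((g, g + 2), (g + 1, g + 3)):
            a, b = buf[i], buf[j]
            out[i] = (a & 0x0F) | ((b & 0x0F) << 4)
            out[j] = ((a >> 4) & 0x0F) | (b & 0xF0)
    return out

# bytes packing nibble pairs (0,1) (2,3) (8,9) (10,11)
swapped = swap_nibbles_group4([0x10, 0x32, 0x98, 0xBA])
# -> bytes packing pairs (0,8) (2,10) (1,9) (3,11)
```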

int4 row interleave

In the int8 case, the code interleaves every 64 elements across two adjacent rows; in the int4 case it interleaves every 32 elements across four adjacent rows. The permute step therefore computes

int permute_index = permute_kk % 32 + permute_kk / 32 * 128 +
                        32 * (n_id % 4) + total_k * 4 * (n_id / 4);

which matches the expected layout. (It was already dawn by the time I finished writing this zzz)
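The index formula can be mirrored on the host to check the four-row, 32-element interleave (variable names follow the C++ snippet; `total_k` is assumed to be the K dimension of the packed weight):

```python
def int4_permute_index(permute_kk, n_id, total_k):
    # Mirror of the C++ formula: within each group of four rows,
    # consecutive 32-element chunks from the four rows are interleaved,
    # so chunk 0 of rows 0..3 fills output slots 0..127, and so on.
    return (permute_kk % 32
            + permute_kk // 32 * 128
            + 32 * (n_id % 4)
            + total_k * 4 * (n_id // 4))
```

For example, with total_k = 64, row 0's elements 0..31 land at 0..31, row 1's first chunk starts at 32, and row 0's element 32 starts the next 128-wide group at 128.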


paddle-bot bot commented May 7, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 7, 2024

paddle-ci-bot bot commented May 19, 2024

Sorry to inform you that 19619b4's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.


paddle-ci-bot bot commented Jul 16, 2024

Sorry to inform you that 45b13ac's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

Contributor

@wanghuancoder wanghuancoder left a comment


LGTM

Contributor

@yuanlehome yuanlehome left a comment


LGTM

@yuanlehome yuanlehome merged commit 4991383 into PaddlePaddle:develop Jul 24, 2024
31 checks passed
co63oc pushed a commit to co63oc/Paddle that referenced this pull request Jul 25, 2024
… kernel (PaddlePaddle#64094)

* Add int4 quantzie kernel and permute kernel

* Update weight_quantize_kernel_gpu_impl.h

* dont reshape it version

* update kernel

* fix int4 quant kernel

* Update weight_quantize_kernel_gpu_impl.h

* fix conflicts

* fix int4 per channel quant row pack error

* fix int4 dequant launch kernel

* remove printf

* add int4 gpucpu check

* Update test_weight_only_linear.py

* Update weight_dequantize_kernel.cu

* fix compile error

* fix

* fix ci

* recommit

* fix code

---------

Co-authored-by: yuanlehome <yuanlehome@163.com>
inaomIIsfarell pushed a commit to inaomIIsfarell/Paddle that referenced this pull request Jul 31, 2024
Labels
contributor External developers
3 participants