Skip to content

Commit

Permalink
[PHI] add int4 weight only quant kernel, add int4 weight only permute…
Browse files Browse the repository at this point in the history
… kernel (#64094)

* Add int4 quantize kernel and permute kernel

* Update weight_quantize_kernel_gpu_impl.h

* don't reshape it version

* update kernel

* fix int4 quant kernel

* Update weight_quantize_kernel_gpu_impl.h

* fix conflicts

* fix int4 per channel quant row pack error

* fix int4 dequant launch kernel

* remove printf

* add int4 gpucpu check

* Update test_weight_only_linear.py

* Update weight_dequantize_kernel.cu

* fix compile error

* fix

* fix ci

* recommit

* fix code

---------

Co-authored-by: yuanlehome <yuanlehome@163.com>
  • Loading branch information
yinfan98 and yuanlehome authored Jul 24, 2024
1 parent 347bad6 commit 4991383
Show file tree
Hide file tree
Showing 4 changed files with 448 additions and 42 deletions.
10 changes: 6 additions & 4 deletions paddle/phi/kernels/funcs/weight_dequant_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,

int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;
int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
// Every two rows of the original weights are interleaved into a row with
// stride of 64, so if each thread processes 16 elements(for int8, we can use
// ldg.128 to load weights), then every group of four adjacent threads will
// alternately process two different row weights for example every 128
// Every 4 rows of the original weights are interleaved into a row with
// stride of 32, so if each thread processes 16 elements(for int8, we can use
// ldg.128 to load weights), then every group of two adjacent threads will
// alternately process four different row weights for example every 128
// consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave
// layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before
// interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1
Expand Down Expand Up @@ -383,6 +383,7 @@ void WeightDequantize(const Context& dev_ctx,
k,
group_size);
} else if (algo == "weight_only_int4" && group_size == -1) {
k *= 2;
grid.x /= 2;
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
Expand All @@ -391,6 +392,7 @@ void WeightDequantize(const Context& dev_ctx,
n,
k);
} else if (algo == "weight_only_int4" && group_size > 0) {
k *= 2;
grid.x /= 2;
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
Expand Down
28 changes: 22 additions & 6 deletions paddle/phi/kernels/gpu/weight_quantize_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,40 @@ void WeightQuantizeKernel(const Context& dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<float>(),
weight_shape);
weight_shape,
arch,
algo);
trans(dev_ctx, quanted_x, out, axis);
} else if (algo == "weight_only_int8") {
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
weight_shape);
weight_shape,
arch,
algo);
weight_permute_gpu<Context>(dev_ctx,
quanted_x.data<int8_t>(),
out->data<int8_t>(),
weight_shape,
arch);
arch,
algo);
} else if (algo == "weight_only_int4") {
PADDLE_FATAL(
"Weight quant gpu kernel currently don't support weight_only_int4 "
"algo, please use cpu version.");
dev_ctx.template Alloc<T>(scale);
weight_quant_gpu<T, Context>(dev_ctx,
x.data<T>(),
quanted_x.data<int8_t>(),
scale->data<T>(),
weight_shape,
arch,
algo);
weight_permute_gpu<Context>(dev_ctx,
quanted_x.data<int8_t>(),
out->data<int8_t>(),
weight_shape,
arch,
algo);
} else {
PADDLE_FATAL(
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
Expand Down
Loading

0 comments on commit 4991383

Please sign in to comment.