From 176b277316a5ddb8d79e6fee8edf293cf770fded Mon Sep 17 00:00:00 2001 From: freeliuzc Date: Wed, 27 Dec 2023 13:53:08 +0800 Subject: [PATCH] [inference] Support groupwise mode of gemv kernel (#60204) * support gemv-groupwise func && weightQuanter-groupwise && weightDeQuanter-groupwise * fix build bug * add unit_test && fix bug * delete useless code * fix ci build bug * fix ci && optimize * fix merge conflict * add op change info * fix weight_only_linear_pass * fix format * solve ci unit_test --- .../fusion/fused_weight_only_linear_pass.cc | 11 +- paddle/phi/api/yaml/backward.yaml | 4 +- paddle/phi/api/yaml/op_version.yaml | 24 + paddle/phi/api/yaml/ops.yaml | 6 +- paddle/phi/infermeta/backward.cc | 7 + paddle/phi/infermeta/backward.h | 1 + paddle/phi/infermeta/binary.cc | 51 +- paddle/phi/infermeta/binary.h | 1 + paddle/phi/infermeta/multiary.cc | 36 +- paddle/phi/infermeta/multiary.h | 1 + paddle/phi/infermeta/unary.cc | 17 +- paddle/phi/infermeta/unary.h | 1 + .../phi/kernels/cpu/weight_quantize_kernel.cc | 29 +- .../kernels/funcs/weight_dequant_functor.h | 142 ++- paddle/phi/kernels/funcs/weight_only_gemv.cu | 949 ++++++++++++------ paddle/phi/kernels/funcs/weight_only_gemv.h | 23 +- .../kernels/gpu/weight_dequantize_kernel.cu | 3 +- .../gpu/weight_only_linear_grad_kernel.cu | 16 +- .../kernels/gpu/weight_only_linear_kernel.cu | 49 +- .../phi/kernels/gpu/weight_quantize_kernel.cu | 7 + .../impl/weight_quantize_kernel_impl.h | 70 ++ paddle/phi/kernels/weight_dequantize_kernel.h | 1 + .../kernels/weight_only_linear_grad_kernel.h | 1 + .../phi/kernels/weight_only_linear_kernel.h | 1 + paddle/phi/kernels/weight_quantize_kernel.h | 1 + python/paddle/nn/quant/quantized_linear.py | 47 +- test/quantization/test_weight_only_linear.py | 263 ++++- 27 files changed, 1331 insertions(+), 431 deletions(-) diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index 57485355ad22d..fa83418e562ba 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -96,9 +96,14 @@ class FusedWeightOnlyLinearPattern return getSMVersion(); }); + const auto &group_size_attr = res.Attr( + [](const pir::drr::MatchContext &match_ctx) -> int { return -1; }); + const auto &weight_quantize = res.Op(paddle::dialect::WeightQuantizeOp::name(), - {{"algo", weight_only_int8_attr}, {"arch", arch_attr}}); + {{"algo", weight_only_int8_attr}, + {"arch", arch_attr}, + {"group_size", group_size_attr}}); weight_quantize({&res.Tensor("w")}, {&res.Tensor("quanted_weight_tensor"), &res.Tensor("weight_scale_tensor")}); @@ -110,7 +115,9 @@ class FusedWeightOnlyLinearPattern const auto &weight_only_linear = res.Op(paddle::dialect::WeightOnlyLinearOp::name(), - {{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}}); + {{"weight_dtype", weight_dtype_attr}, + {"arch", arch_attr}, + {"group_size", group_size_attr}}); weight_only_linear({&res.Tensor("x"), &res.Tensor("quanted_weight_tensor"), &res.Tensor("bias"), diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 21ec2126c8f94..938ea9d500046 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -2603,8 +2603,8 @@ no_need_buffer : input - backward_op : weight_only_linear_grad - forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch) -> Tensor(out) - args : (Tensor x, 
Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch) + forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch, int group_size) -> Tensor(out) + args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch, int group_size) output : Tensor(x_grad) infer_meta : func : WeightOnlyLinearGradInferMeta diff --git a/paddle/phi/api/yaml/op_version.yaml b/paddle/phi/api/yaml/op_version.yaml index bd296a6191de3..7c9618f52b17b 100644 --- a/paddle/phi/api/yaml/op_version.yaml +++ b/paddle/phi/api/yaml/op_version.yaml @@ -472,6 +472,30 @@ comment : The axis to apply unique. If None, the input will be flattened. default : std::vector{} +- op : weight_dequantize + version : + - checkpoint : Upgrade weight_dequantize, add a new attribute [group_size] + action : + - add_attr : group_size + comment : The group size of the dequantization scales. + default : -1 + +- op : weight_only_linear + version : + - checkpoint : Upgrade weight_only_linear, add a new attribute [group_size] + action : + - add_attr : group_size + comment : The group size of the dequantization scales. + default : -1 + +- op : weight_quantize + version : + - checkpoint : Upgrade weight_quantize, add a new attribute [group_size] + action : + - add_attr : group_size + comment : The group size of the quantization scales. + default : -1 + - op : yolo_box version : - checkpoint : Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor]. diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index c15fb2fdb1998..de7c49250ea16 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -2833,7 +2833,7 @@ backward : warprnnt_grad - op : weight_dequantize - args : (Tensor x, Tensor scale, str algo="weight_only_int8", DataType out_dtype=DataType::FLOAT16) + args : (Tensor x, Tensor scale, str algo = "weight_only_int8", DataType out_dtype = DataType::FLOAT16, int group_size = -1) output : Tensor(out) infer_meta : func : WeightDequantizeInferMeta @@ -2842,7 +2842,7 @@ data_type : out_dtype - op : weight_only_linear - args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80) + args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80, int group_size = -1) output : Tensor(out) infer_meta : func : WeightOnlyLinearInferMeta @@ -2853,7 +2853,7 @@ backward: weight_only_linear_grad - op : weight_quantize - args : (Tensor x, str algo = "weight_only_int8", int arch = 80) + args : (Tensor x, str algo = "weight_only_int8", int arch = 80, int group_size = -1) output : Tensor(out), Tensor(scale) infer_meta : func : WeightQuantizeInferMeta diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 6d6eab8097337..ee2388762668b 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -1191,12 +1191,19 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, MetaTensor* x_grad) { PADDLE_ENFORCE_EQ( ((arch == 80) || (arch == 86)), true, phi::errors::InvalidArgument( "Currently weightonly linear grad only support arch = 80 or 86. ")); + PADDLE_ENFORCE_EQ( + group_size, + -1, + phi::errors::InvalidArgument( + "Currently weightonly linear grad only support per-channel mode. 
")); + x_grad->set_dims(x.dims()); x_grad->set_dtype(x.dtype()); } diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index 86878c5feb082..922bafed0add8 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -469,6 +469,7 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, MetaTensor* x_grad); void YoloLossGradInferMeta(const MetaTensor& x, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 8b85a3efc4dd8..b771fba031317 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -3381,6 +3381,7 @@ void WeightDequantizeInferMeta(const MetaTensor& x, const MetaTensor& scale, const std::string& algo, DataType out_dtype, + const int32_t group_size, MetaTensor* out) { PADDLE_ENFORCE_EQ(x.dims().size(), 2UL, @@ -3388,18 +3389,44 @@ void WeightDequantizeInferMeta(const MetaTensor& x, "The x tensor of dequantize op must be 2D, but got[%d]", x.dims().size())); PADDLE_ENFORCE_EQ( - scale.dims().size(), - 1UL, - phi::errors::InvalidArgument( - "The scale tensor of dequantize op must be 1D, but got[%d]", - scale.dims().size())); - PADDLE_ENFORCE_EQ(scale.dims()[0], - x.dims()[0], - phi::errors::InvalidArgument( - "The scale tensor's shape must be equal to the x " - "tensor's shape, but got [%d] not equal to [%d]", - scale.dims()[0], - x.dims()[0])); + (group_size == -1 || group_size == 64 || group_size == 128), + true, + phi::errors::InvalidArgument("group_size must be -1, 64 or 128.")); + + auto dim_scale = scale.dims(); + + // per-channel dequantization + if (group_size == -1) { + PADDLE_ENFORCE_EQ( + dim_scale.size(), + 1UL, + phi::errors::InvalidArgument("The scale tensor of dequantize op must " + "be 1D in per-channel mode, but got[%d]", + scale.dims().size())); + PADDLE_ENFORCE_EQ(dim_scale[0], + x.dims()[0], + phi::errors::InvalidArgument( + "The scale tensor's shape must be equal to the x " + "tensor's shape, but got [%d] not equal to [%d]", + scale.dims()[0], + x.dims()[0])); + } else /* groupwise dequantization */ { + PADDLE_ENFORCE_EQ( + dim_scale.size(), + 2UL, + phi::errors::InvalidArgument("The scale tensor of dequantize op must " + "be 2D in group-wise mode, but got[%d]", + scale.dims().size())); + PADDLE_ENFORCE_EQ( + dim_scale[0], + (x.dims()[1] + (group_size - 1)) / group_size, + errors::InvalidArgument("The input(weight_scale) dim[0] must be equal " + "to (Input(weight).dim[1] + (group_size -1))" + " / group_size" + "But receive %d and %d", + dim_scale[0], + (x.dims()[1] + (group_size - 1)) / group_size)); + } int n = x.dims()[1]; int k = x.dims()[0]; out->set_dims(common::make_ddim({n, k})); diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index c081c1690c28d..82f5fc64d57a5 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -539,6 +539,7 @@ void WeightDequantizeInferMeta(const MetaTensor& x, const MetaTensor& scale, const std::string& algo, DataType out_dtype, + const int32_t group_size, MetaTensor* out); } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 0b2ef29389137..6250b3a3b23c8 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -3943,10 +3943,16 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, const MetaTensor& weight_scale, const std::string& weight_dtype, const int32_t arch, + const int32_t 
group_size, MetaTensor* out) { + PADDLE_ENFORCE((group_size == -1 || group_size == 64 || group_size == 128), + errors::InvalidArgument("group_size must be -1, 64 or 128.")); + + auto weight_scale_dims = weight_scale.dims(); + auto x_dims = x.dims(); auto w_dims = weight.dims(); - auto n = weight_scale.dims()[0]; + auto n = group_size == -1 ? weight_scale_dims[0] : weight_scale_dims[1]; PADDLE_ENFORCE( weight_dtype == "int8" || weight_dtype == "int4", errors::InvalidArgument("quant_method must be 'int8' or 'int4'.")); @@ -3954,10 +3960,6 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, w_dims.size(), 2UL, errors::InvalidArgument("The input(weight) must be a 2D Tensor.")); - PADDLE_ENFORCE_EQ( - weight_scale.dims().size(), - 1UL, - errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor.")); PADDLE_ENFORCE_EQ( w_dims[0] % 16, 0, @@ -3978,6 +3980,30 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", x_dims[x_dims.size() - 1], w_dims[1])); + + // per-channel dequantization + if (group_size == -1) { + PADDLE_ENFORCE_EQ( + weight_scale_dims.size(), + 1UL, + errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor." + "in per-channel mode.")); + } else /* groupwise dequantization */ { + PADDLE_ENFORCE_EQ( + weight_scale_dims.size(), + 2UL, + errors::InvalidArgument("The input(weight_scale) must be a 2D Tensor" + " in groupwise mode.")); + PADDLE_ENFORCE_EQ( + weight_scale_dims[0], + (w_dims[1] + (group_size - 1)) / group_size, + errors::InvalidArgument("The input(weight_scale) dim[0] must be equal " + "to Input(weight) dim[1] / group_size" + "But receive %d and %d", + weight_scale_dims[0], + (w_dims[1] + (group_size - 1)) / group_size)); + } + auto out_dims = x_dims; out_dims[out_dims.size() - 1] = n; out->set_dims(out_dims); diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index be3f1fba94a80..f51c3dacb1909 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -720,6 +720,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x, const MetaTensor& weight_scale, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, MetaTensor* out); void WeightedSampleNeighborsInferMeta(const MetaTensor& row, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 16da7fbc02128..af60d6ae8da5c 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -5203,6 +5203,7 @@ void UnStackInferMeta(const MetaTensor& x, void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, const int32_t arch, + const int32_t group_size, MetaTensor* out, MetaTensor* scale) { PADDLE_ENFORCE_EQ( @@ -5229,7 +5230,21 @@ void WeightQuantizeInferMeta(const MetaTensor& x, phi::errors::InvalidArgument( "The second dimension of input must be divisible by 16, but got[%d]", x_dims[1])); - std::vector dim_scale({x_dims[1]}); + PADDLE_ENFORCE_EQ( + ((group_size == -1) || (group_size == 64) || (group_size == 128)), + true, + phi::errors::InvalidArgument( + "Currently, group_size only support -1, 64 or 128.")); + + std::vector dim_scale; + if (group_size != -1) { + int64_t scale_dim0 = (x_dims[0] + (group_size - 1)) / group_size; + int64_t scale_dim1 = x_dims[1]; + dim_scale = std::vector({scale_dim0, scale_dim1}); + } else { + dim_scale = std::vector({x_dims[1]}); + } + std::vector dim_out; if (algo == "weight_only_int8" || algo == "llm.int8") { dim_out = std::vector({x_dims[1], x_dims[0]}); diff --git 
a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index f4fca6cd7770d..eae4614a8eb5c 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -477,6 +477,7 @@ void QuantizeXPUInferMeta(const MetaTensor& x, void WeightQuantizeInferMeta(const MetaTensor& x, const std::string& algo, const int32_t arch, + const int32_t group_size, MetaTensor* out, MetaTensor* scale); diff --git a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc index e85b83700b173..313c59e2e6676 100644 --- a/paddle/phi/kernels/cpu/weight_quantize_kernel.cc +++ b/paddle/phi/kernels/cpu/weight_quantize_kernel.cc @@ -28,7 +28,8 @@ void quant_compute(const DeviceContext& dev_ctx, DenseTensor* out, DenseTensor* scale, const std::string& algo, - const int32_t arch) { + const int32_t arch, + const int32_t group_size) { PADDLE_ENFORCE_EQ( ((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)), true, @@ -51,7 +52,8 @@ void quant_compute(const DeviceContext& dev_ctx, DenseTensor x_int(out->type()); - if ((arch == 80) || (arch == 75) || (arch == 86)) { + if ((arch == 80) || (arch == 75) || (arch == 86) || (arch == 89) || + (arch == 90)) { x_int.Resize({static_cast(m), static_cast(n)}); } else { // phi::Copy may change tensor meta info, here we transpose the quanted @@ -71,9 +73,19 @@ void quant_compute(const DeviceContext& dev_ctx, int_processed_2.Resize(out->dims()); dev_ctx.template Alloc(&int_processed_2); D* int_processed_2_data = int_processed_2.data(); - per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f); - - per_channel_quant(x_int_data, x_data, scale_data, m, n); + if (group_size == -1) { + per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f); + per_channel_quant(x_int_data, x_data, scale_data, m, n); + } else { + group_wise_scale(scale_data, + x_data, + m, + n, + bits == 8 ? 
127.0f : 7.0f, + static_cast(group_size)); + + group_wise_quant(x_int_data, x_data, scale_data, m, n, group_size); + } if (algo == "llm.int8") { std::vector axis = {1, 0}; funcs::Transpose trans; @@ -105,14 +117,17 @@ void WeightQuantizeKernel(const Context& dev_ctx, const DenseTensor& x, const std::string& algo, const int32_t arch, + const int32_t group_size, DenseTensor* out, DenseTensor* scale) { dev_ctx.template Alloc(out); dev_ctx.template Alloc(scale); if (algo == "weight_only_int8" || algo == "llm.int8") { - quant_compute(dev_ctx, x, out, scale, algo, arch); + quant_compute( + dev_ctx, x, out, scale, algo, arch, group_size); } else if (algo == "weight_only_int4") { - quant_compute(dev_ctx, x, out, scale, algo, arch); + quant_compute( + dev_ctx, x, out, scale, algo, arch, group_size); } else { phi::errors::Unimplemented( "The algo must be in ['weight_only_int8', 'weight_only_int4', " diff --git a/paddle/phi/kernels/funcs/weight_dequant_functor.h b/paddle/phi/kernels/funcs/weight_dequant_functor.h index 1728fa0577ab4..4eed94de7bf4d 100644 --- a/paddle/phi/kernels/funcs/weight_dequant_functor.h +++ b/paddle/phi/kernels/funcs/weight_dequant_functor.h @@ -231,12 +231,133 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight, } } +template +__global__ void int8_weight_only_dequant(const uint8_t* weight, + const T* scales, + T* output, + const int n, + const int k, + const int group_size) { + using Converter = FastWeightOnlyHalfConverter; + AlignedVector vec_weight; + T vec_weight_f16[16]; + AlignedVector vec_out; + + int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; + int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + // Every two rows of the original weights are interleaved into a row with + // stride of 64, so if each thread processes 16 elements(for int8, we can use + // ldg.128 to load weights), then every group of four adjacent threads will + // alternately process two different row weights for example every 128 + // consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave + // layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before + // interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 + // before interleaving. So if each thread loads 16 int8 elements, then the + // elements of the first four and last four threads of each 8 consecutive + // threads will come from row 2N and row 2N+1 respectively before + // interleaving. + int row_id = tile_id * 2 + ((lane_id % 8) > 3 ? 
1 : 0); + weight += tile_id * k * 2; + output += row_id * k; + + scales += row_id; +#pragma unroll + for (int i = lane_id * 16; i < k * 2; i += 16 * 32) { + int scale_offset = i / 2 / group_size; + float scale = static_cast(scales[scale_offset * n]); + Load(&weight[i], &vec_weight); +#pragma unroll + for (int p = 0; p < 16; p += Converter::kHalfLength) { + // The rearrangement here counteracts the effect of + // cutlass::add_bias_and_interleave_int8s_inplace Input int8 data layout + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + // + // Converted fp16 data layout + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 16 bits) + // vec_weight_f16[p] = static_cast(static_cast(vec_weight[p]) * + // scale); + // fast_cvt_4_packed_signed_i8s_to_2_half2s() + Converter::convert(vec_weight_f16 + p, &vec_weight[p], scale); + } +#pragma unroll + for (int p = 0; p < 16; ++p) { + // The index remapping here is to counteracts the effect of + // cutlass::permute_B_rows_for_mixed_gemm input 0 1 2 3 4 5 6 7 8 9 10 11 + // 12 13 14 15 weight 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 + vec_out[p] = vec_weight_f16[4 * ((p % 8) / 2) + p % 2 + 2 * (p / 8)]; + } + Store(vec_out, &output[i / 128 * 64 + (i % 64)]); + } +} + +template +__global__ void int4_weight_only_dequant(const uint8_t* weight, + const T* scales, + T* output, + const int n, + const int k, + const int group_size) { + using Converter = FastWeightOnlyHalfConverter; + + AlignedVector vec_weight; + T vec_weight_f16[32]; + AlignedVector vec_out; + + int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32; + int tile_id = blockIdx.x * blockDim.x / 32 + warp_id; + // Every two rows of the original weights are interleaved into a row with + // stride of 64, so if each thread processes 16 elements(for int8, we can use + // ldg.128 to load weights), then every group of four adjacent threads will + // alternately process two different row weights for example every 128 + // consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave + // layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before + // interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1 + // before interleaving. So if each thread loads 16 int8 elements, then the + // elements of the first four and last four threads of each 8 consecutive + // threads will come from row 2N and row 2N+1 respectively before + // interleaving. + int row_id = tile_id * 4 + ((lane_id % 8) / 2); + weight += tile_id * k / 2 * 4; + output += row_id * k; + scales += row_id; +#pragma unroll + for (int i = lane_id * 32; i < k * 4; i += 32 * 32) { + Load(&weight[i / 2], &vec_weight); + int scale_offset = i / 4 / group_size; + float scale = static_cast(scales[scale_offset * n]); +#pragma unroll + for (int p = 0; p < 32; p += Converter::kHalfLength) { + // The rearrangement here counteracts the effect of + // cutlass::add_bias_and_interleave_int4s_inplace Input int8 data layout + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt + // occupies 4 bits) + // + // Converted fp16 data layout + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt + // occupies 16 bits) + // vec_weight_f16[p] = + // static_cast(static_cast(vec_weight[p]) * scale); + Converter::convert(vec_weight_f16 + p, &vec_weight[p / 2], scale); + } +#pragma unroll + for (int p = 0; p < 32; ++p) { + // The index remapping here is to counteracts the effect of + // cutlass::permute_B_rows_for_mixed_gemm input 0 1 2 3 4 5 6 7 8 9 10 11 + // 12 13 14 15 ... 
31 weight 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 + // 12 13 20 21 28 29 6 7 14 15 22 23 30 31 + vec_out[p] = vec_weight_f16[8 * ((p % 8) / 2) + p % 2 + 2 * (p / 8)]; + } + Store(vec_out, &output[i / 256 * 64 + (i % 64)]); + } +} + template void WeightDequantize(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& scale, const std::string& algo, const bool transpose, + const int32_t group_size, DenseTensor* out) { using DataType = typename PDDataTypeTraits::DataType; @@ -246,14 +367,22 @@ void WeightDequantize(const Context& dev_ctx, dim3 grid(n / 32); auto stream = dev_ctx.stream(); - if (algo == "weight_only_int8") { + if (algo == "weight_only_int8" && group_size == -1) { int8_weight_only_dequant<<>>( reinterpret_cast(x.data()), reinterpret_cast(scale.data()), reinterpret_cast(out->data()), n, k); - } else if (algo == "weight_only_int4") { + } else if (algo == "weight_only_int8" && group_size > 0) { + int8_weight_only_dequant<<>>( + reinterpret_cast(x.data()), + reinterpret_cast(scale.data()), + reinterpret_cast(out->data()), + n, + k, + group_size); + } else if (algo == "weight_only_int4" && group_size == -1) { grid.x /= 2; int4_weight_only_dequant<<>>( reinterpret_cast(x.data()), @@ -261,6 +390,15 @@ void WeightDequantize(const Context& dev_ctx, reinterpret_cast(out->data()), n, k); + } else if (algo == "weight_only_int4" && group_size > 0) { + grid.x /= 2; + int4_weight_only_dequant<<>>( + reinterpret_cast(x.data()), + reinterpret_cast(scale.data()), + reinterpret_cast(out->data()), + n, + k, + group_size); } } diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.cu b/paddle/phi/kernels/funcs/weight_only_gemv.cu index aeccf6f2370cd..ff9285693b55f 100644 --- a/paddle/phi/kernels/funcs/weight_only_gemv.cu +++ b/paddle/phi/kernels/funcs/weight_only_gemv.cu @@ -367,6 +367,8 @@ __global__ void int8_weight_only_gemv(const T* input, enum class WeightOnlyQuantType { Int4b, Int8b }; +enum class WeightOnlyType { PerChannel, GroupWise }; + template struct WeightLayoutDetails; @@ -530,8 +532,6 @@ struct WeightOnlyKernelDetails { kElemsPerThread / kActivationElemNumPerAccess; }; -enum class WeightOnlyType { PerChannel, GroupWise }; - struct WeightOnlyPerChannel; template struct WeightOnlyGroupWise; @@ -551,13 +551,12 @@ struct WeightOnlyProperties> { static constexpr int kGroupSize = GS; }; -template + int BlockSize> struct WeightOnlyScaleLoader { - using ElemType = T; using Details = WeightOnlyKernelDetails; static constexpr bool kIsFineGrained = WeightOnlyProperties::kIsFineGrained; @@ -565,25 +564,19 @@ struct WeightOnlyScaleLoader { WeightOnlyProperties::kGroupSize; private: - const ElemType* _scales; - const ElemType* _zeros; + const T* _scales; + const T* _zeros; int _stride; int _offset; public: - __device__ __forceinline__ WeightOnlyScaleLoader(const ElemType* scales, - const ElemType* zeros, + __device__ __forceinline__ WeightOnlyScaleLoader(const T* scales, + const T* zeros, int initial_offset, int stride) : _scales(scales), _zeros(zeros), _stride(stride) { _scales += initial_offset; -#ifndef WIN32 - // linux - if constexpr (Zero) { -#else - // windows if (Zero) { -#endif _zeros += initial_offset; } // Calculate the k dimension index of the element processed by the current @@ -594,10 +587,10 @@ struct WeightOnlyScaleLoader { (threadIdx.x % Details::kThreadsNumPerTile) * Details::kElemsPerThread; } - __device__ __forceinline__ void load(ElemType& scale, // NOLINT - ElemType& zero, // NOLINT - int nid) { + __device__ __forceinline__ void load(T* scale, T* zero, int nid) 
{ int offset = nid * Details::kInterleave; + +// TODO(freeliuzc): cpplint has bug here #ifndef WIN32 if constexpr (kIsFineGrained) { #else @@ -605,15 +598,17 @@ struct WeightOnlyScaleLoader { #endif offset += _offset / kGroupSize * _stride; } - scale = _scales[offset]; + *scale = _scales[offset]; + +// TODO(freeliuzc): cpplint has bug here #ifndef WIN32 if constexpr (Zero) { #else if (Zero) { #endif - zero = _zeros[offset]; + *zero = _zeros[offset]; } else { - zero = static_cast(0.f); + *zero = static_cast(0.f); } } @@ -624,6 +619,272 @@ struct WeightOnlyScaleLoader { __device__ __forceinline__ int offset() { return _offset; } }; // NOLINT +template +struct WeightOnlyConverter {}; + +template <> +struct WeightOnlyConverter { + static __device__ inline void convert(half halves[4], + int8_t signed_chars[4]) { + uint32_t* h = reinterpret_cast(halves); + uint32_t i8s = *reinterpret_cast(signed_chars); + + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(h[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + } +}; + +template <> +struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int8b> { + static __device__ inline void convert(__nv_bfloat16 halves[4], + int8_t signed_chars[4]) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t* bf16_result_ptr = reinterpret_cast(halves); + uint32_t i8s = *reinterpret_cast(signed_chars); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; + } + +// Truncate the fp32 representation and pack up as bfloat16s. +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + assert(false); +#endif + } +}; + +template <> +struct WeightOnlyConverter { + static __device__ inline void convert(half halves[8], + int8_t signed_chars[4]) { + uint32_t* h = reinterpret_cast(halves); + uint32_t i4s = *reinterpret_cast(signed_chars); + + // First, we extract the i4s and construct an intermediate fp16 number. 
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is + // thanks to the register packing format and the fact that we force our + // integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput + // in order to convert elt_23 and elt_67 to fp16 without having to shift + // them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide + // RAW dependency if we issue immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), + "n"(BOTTOM_MASK), + "n"(I4s_TO_F16s_MAGIC_NUM), + "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit + // float2half instructions if I use the half2 ctor. In this case, I chose + // performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[0]) + : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[1]) + : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(h[2]) + : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(h[3]) + : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + } +}; + +template <> +struct WeightOnlyConverter<__nv_bfloat16, WeightOnlyQuantType::Int4b> { + static __device__ inline void convert(__nv_bfloat16 halves[8], + int8_t signed_chars[4]) { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t* h = reinterpret_cast(halves); + uint32_t const source_i4s = *reinterpret_cast(signed_chars); + + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, + // so we must loop. No shift needed for first item. 
+ uint32_t i4s = source_i4s; + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); +#pragma unroll + for (int ii = 1; ii < 4; ++ii) { + i4s >>= 4; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile( + "lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + +// Finally, we construct the output numbers. +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + // Since this section is for Ampere+, we use bf16 fma to do the bias + // subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(h[ii]) + : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } + +#else + // Disable this on architectures older than Ampere since they lack hardware + // for bf16 mma. If one wishes to use HMMA on older hardware, they should + // Convert directly to FP16 using FP16 converters. + assert(false); +#endif + } +}; + +template +__device__ __forceinline__ void load(T0* dst, T1* src, size_t offset = 0) { + *reinterpret_cast(dst) = + *(reinterpret_cast(src) + offset); +} + +template +struct WeightPostProcessor { + static __device__ __forceinline__ void run(T* weights_vec, + T* weights_f16, + T* scale, + T* zero, + int NPerBlock, + int idx) {} +}; + +template +struct WeightPostProcessor { + static __device__ __forceinline__ void run(T* weights_vec, + T* weights_f16, + T* scale, + T* zero, + int NPerBlock, + int idx) { + using HALF_2_TYPE = typename CUDA_HALF_2_TYPE_TARIS::type; +#pragma unroll + for (int i = 0; i < Details::kShuffleContinous; ++i) { +#pragma unroll + for (int j = 0; j < Details::kShuffleStrided; ++j) { + // Dequantize the weights and arrange the shuffled elements back to + // the correct order in the register array + HALF_2_TYPE v = *reinterpret_cast( + weights_vec + i * Details::kShuffleBasicTile + + j * Details::kShuffleContinous * Details::kShuffleBasicTile); + v = HalfMulAdd::apply( + v, + ConvertDstFunc_2::apply(scale[idx]), + ConvertDstFunc_2::apply(zero[idx])); + weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile + + j * Details::kShuffleBasicTile + 0) * + NPerBlock + + idx] = v.x; + weights_f16[(i * Details::kShuffleStrided * Details::kShuffleBasicTile + + j * Details::kShuffleBasicTile + 1) * + NPerBlock + + idx] = v.y; + } + } + } +}; + +template +struct WeightPostProcessor { + static __device__ __forceinline__ void run(T* weights_vec, + T* weights_f16, + T* scale, + T* zero, + int NPerBlock, + int idx) { +#pragma unroll + for (int p = 0; p < 16; ++p) { + weights_f16[p * NPerBlock + idx] = + weights_vec[p / 8 + (p % 8) * 2] * scale[idx]; + } + } +}; + template -__global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, +__global__ void weight_only_batched_gemv_multi_warp(const T* in, + const int8_t* qweight, + const T* bias, const T* scales, const T* zeros, - const T* in, - const T* bias, T* out, const int n, const int k) { @@ -650,8 +911,10 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, using CvtSrcType = int8_t; using CvtResType = T; using ScaleLoader = - WeightOnlyScaleLoader; - extern __shared__ int8_t shmem[]; + WeightOnlyScaleLoader; + using WeightProcessor = WeightPostProcessor; + + extern __shared__ uint8_t shmem[]; constexpr int Interleave = Details::kInterleave; constexpr int WarpSize = 32; constexpr int Num = 
Batch * NPerBlock; @@ -673,48 +936,47 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, // threads and fp32 for accumulation between threads. T accumulator[Num]; for (int i = 0; i < Num; ++i) { - accumulator[i] = ConvertFloatFunc::apply(0.f); + accumulator[i] = ConvertDstFunc::apply(0.f); } // Iteration in k dimensions for (int local_k = tid * Details::kElemsPerThread; local_k < k * Interleave; local_k += BlockSize * Details::kElemsPerThread) { - T weights_f16[Details::kElemsPerThread * NPerBlock]; // 16 * 2 = 32 + T weights_f16[Details::kElemsPerThread * NPerBlock]; T scale[NPerBlock], zero[NPerBlock]; #pragma unroll for (int idx = 0; idx < NPerBlock; ++idx) { // Load quantized weight and scales/zeros int8_t weights_quantized[Details::kBytePerThread]; - *reinterpret_cast(weights_quantized) = - *reinterpret_cast( - qweight + idx * Interleave * k / Details::kElemsPerByte + - local_k / Details::kElemsPerByte); - scale_loader.load(scale[idx], zero[idx], idx); + load(weights_quantized, + qweight + idx * Interleave * k / Details::kElemsPerByte + + local_k / Details::kElemsPerByte); + scale_loader.load(scale + idx, zero + idx, idx); T weights_vec[Details::kElemsPerThread]; + #pragma unroll for (int i = 0; i < Details::kConvertIters; ++i) { // Use cutlass::FastInterleavedAndBiasedNumericArrayConverter for I2F // type conversion - fast_cvt_4_packed_signed_i8s_to_2_half2s( + WeightOnlyConverter::convert( weights_vec + i * Details::kConvertCount, weights_quantized + i * Details::kConvertCount / Details::kElemsPerByte); } - // TODO(wangbojun) no zero support here -#pragma unroll - for (int p = 0; p < 16; ++p) { - weights_f16[p * NPerBlock + idx] = - weights_vec[p / 8 + (p % 8) * 2] * scale[idx]; - } + // Assign weight and apply scales. + // Currently not support zero. 
+ WeightProcessor::run( + weights_vec, weights_f16, scale, zero, NPerBlock, idx); } #pragma unroll for (int b = 0; b < Batch; ++b) { T in_v[Details::kElemsPerThread]; - // load activation elements - *(float4*)in_v = // NOLINT - *(float4*)(in + b * k + scale_loader.offset()); // NOLINT - *(float4*)(in_v + 8) = // NOLINT - *(float4*)(in + b * k + scale_loader.offset() + 8); // NOLINT +#pragma unroll + for (int idx = 0; idx < Details::kActivationAccessNum; ++idx) { + load(in_v + idx * Details::kActivationElemNumPerAccess, + in + b * k + scale_loader.offset() + + idx * Details::kActivationElemNumPerAccess); + } // Perform vector inner product and accumulate #ifndef WIN32 if constexpr (NPerBlock == 1) { @@ -729,7 +991,7 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, *reinterpret_cast(in_v + y), v); } - accumulator[b] = accumulator[b] + static_cast(v.x + v.y); + accumulator[b] = accumulator[b] + ConvertDstFunc::apply(v.x + v.y); } else { #pragma unroll for (int x = 0; x < NPerBlock / 2; ++x) { @@ -752,7 +1014,7 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, float reses[Num]; #pragma unroll for (int i = 0; i < Num; ++i) { - reses[i] = static_cast(accumulator[i]); + reses[i] = ConvertFloatFunc::apply(accumulator[i]); } // Each warp completes the internal reduce and writes the [Batch * NPerBlock * @@ -773,343 +1035,384 @@ __global__ void weight_only_batched_gemv_multi_warp(const int8_t* qweight, #else if (Bias) { #endif - bias_v = static_cast(bias[n_start_id + nid]); + bias_v = ConvertFloatFunc::apply(bias[n_start_id + nid]); } int b = i / NPerBlock / Interleave; - out[b * n + n_start_id + nid] = ConvertDstFunc::apply( GeluActivation::apply(v + bias_v)); } } - -#endif - -template -void int8_weight_only_gemv_launcher(const T* input, - const int8_t* weight, - const T* scale_list, - const T* bias, - T* output, - const int k, - const int n, - const bool gelu, - gpuStream_t stream) { -#ifdef PADDLE_WITH_CUDA - dim3 block(kWarpSize * kPerBlockWarpNum); // equal to 512; - dim3 grid(n / kPerBlockWarpNum / - 2); // Note(zhengzekang): Since each warp process 2 rows of matrix. - if (bias) { - if (gelu) { - int8_weight_only_gemv<<>>( - input, weight, scale_list, bias, output, k, n); - } else { - int8_weight_only_gemv<<>>( - input, weight, scale_list, bias, output, k, n); - } - } else { - if (gelu) { - int8_weight_only_gemv<<>>( - input, weight, scale_list, bias, output, k, n); - } else { - int8_weight_only_gemv<<>>( - input, weight, scale_list, bias, output, k, n); - } - } #endif -} - -template <> -void int8_weight_only_gemv_launcher(const float* input, - const int8_t* weight, - const float* scale_list, - const float* bias, - float* output, - const int k, - const int n, - const bool gelu, - gpuStream_t stream) { - // Weightonly GEMV do not support float. - assert(false); -} - -template <> -void int8_weight_only_gemv_launcher(const phi::dtype::bfloat16* input, - const int8_t* weight, - const phi::dtype::bfloat16* scale_list, - const phi::dtype::bfloat16* bias, - phi::dtype::bfloat16* output, - const int k, - const int n, - const bool gelu, - gpuStream_t stream) { - // Environment do not support bf16. 
- assert(false); -} template -void select_batch_gemv_multi_warp_by_batch(const T* input, - const int8_t* weight, - const T* scale_list, - const T* bias, - T* output, - const int m, - const int k, - const int n, - gpuStream_t stream) { +void select_activation_and_bias(const T* input, + const int8_t* weight, + const T* bias, + const T* scales, + const int m, + const int n, + const int k, + const std::string& act_method, + T* output, + cudaStream_t stream) { #ifdef PADDLE_WITH_CUDA - VLOG(3) << "launch batched gemv multi_block mnk:" << m << " " - << " " << n << " " << k; + static constexpr int kInterleave = WeightLayoutDetails::kInterleave; dim3 grid(n / NPerBlock / kInterleave); dim3 block(BlockSize); - int smem_size = sizeof(float) * BlockSize / 32 * m * NPerBlock * kInterleave; - switch (m) { - case 1: { + int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave; + if (bias) { + if (act_method == "gelu") { weight_only_batched_gemv_multi_warp - <<>>( - weight, scale_list, /*zeros*/ nullptr, input, bias, output, n, k); - break; - } - case 2: { + <<>>( + input, weight, bias, scales, /*zeros*/ nullptr, output, n, k); + } else if (act_method == "None") { weight_only_batched_gemv_multi_warp - <<>>( - weight, scale_list, /*zeros*/ nullptr, input, bias, output, n, k); - break; - } - case 3: { - weight_only_batched_gemv_multi_warp - <<>>( - weight, scale_list, /*zeros*/ nullptr, input, bias, output, n, k); - break; + <<>>( + input, weight, bias, scales, /*zeros*/ nullptr, output, n, k); + } else { + PADDLE_THROW( + errors::InvalidArgument("Currently, weightonly GEMV act_method " + "only support `gelu`, `None`. ")); } - case 4: { + } else { + if (act_method == "gelu") { weight_only_batched_gemv_multi_warp - <<>>( - weight, scale_list, /*zeros*/ nullptr, input, bias, output, n, k); - break; - } - case 5: { + <<>>( + input, weight, bias, scales, /*zeros*/ nullptr, output, n, k); + } else if (act_method == "None") { weight_only_batched_gemv_multi_warp - <<>>( - weight, scale_list, /*zeros*/ nullptr, input, bias, output, n, k); - break; - } - default: { - throw std::runtime_error("Use unsupported batch for gemv"); - break; + <<>>( + input, weight, bias, scales, /*zeros*/ nullptr, output, n, k); + } else { + PADDLE_THROW( + errors::InvalidArgument("Currently, weightonly GEMV act_method " + "only support `gelu`, `None`. 
")); } } #endif } -template -void batched_int8_weight_only_gemv_multi_warp_launcher(const T* input, - const int8_t* weight, - const T* scale_list, - const T* bias, - T* output, - const int m, - const int k, - const int n, - const bool gelu, - gpuStream_t stream) { +template +void weight_only_batched_gemv_launcher( + const T* input, + const int8_t* weight, + const T* bias, + const T* scales, + int m, + int n, + int k, + const std::string& weight_only_quant_type, + const std::string& act_method, + T* output, + cudaStream_t stream) { #ifdef PADDLE_WITH_CUDA - if (bias) { - if (gelu) { - select_batch_gemv_multi_warp_by_batch( - input, weight, scale_list, bias, output, m, k, n, stream); - } else { - select_batch_gemv_multi_warp_by_batch( - input, weight, scale_list, bias, output, m, k, n, stream); + if (weight_only_quant_type == "int4") { + switch (m) { + case 1: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 2: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 3: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 4: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + default: { + throw std::runtime_error( + "Weight only cuda kernel only supported bs <= 4"); + break; + } } - } else { - if (gelu) { - select_batch_gemv_multi_warp_by_batch( - input, weight, scale_list, bias, output, m, k, n, stream); - } else { - select_batch_gemv_multi_warp_by_batch( - input, weight, scale_list, bias, output, m, k, n, stream); + } else if (weight_only_quant_type == "int8") { + switch (m) { + case 1: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 2: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 3: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + case 4: { + select_activation_and_bias( + input, weight, bias, scales, m, n, k, act_method, output, stream); + break; + } + default: { + throw std::runtime_error( + "Weight only cuda kernel only supported bs <= 4"); + break; + } } + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "WeightOnlyGemvKernel quant_type only support 'int4' or 'int8'.")); } #endif } -template <> -void batched_int8_weight_only_gemv_multi_warp_launcher( - const phi::dtype::bfloat16* input, - const int8_t* weight, - const phi::dtype::bfloat16* scale_list, - const phi::dtype::bfloat16* bias, - phi::dtype::bfloat16* output, - const int m, - const int k, - const int n, - const bool gelu, - gpuStream_t stream) { - // Environment do not support bf16. 
- assert(false); -} - } // namespace template -void GemvWeightonlyInt8Wrapper(const Context& ctx, - const T* x, - const int8_t* weight, - const T* bias, - const T* weight_scale, - const int m, - const int n, - const int k, - const std::string& act_method, - T* output) { +void WeightOnlyGemvWrapper(const Context& dev_ctx, + const T* input, + const int8_t* weight, + const T* bias, + const T* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + T* output) { using DataType = typename PDDataTypeTraits::DataType; - - bool gelu = false; - if (act_method == "gelu") { - gelu = true; - } else if (act_method == "None") { - gelu = false; - } else { - PADDLE_THROW( - errors::InvalidArgument("Currently, Int8 weightonly GEMV act_method " - "only support `gelu`, `None`. ")); - } - if (m < 1) { - // should no go here since m >=1 - // multi_warp is slightly faster even in m == 1. we don't dispatch to this - // kernel but keep it for future use. - int8_weight_only_gemv_launcher( - reinterpret_cast(x), - weight, - reinterpret_cast(weight_scale), - reinterpret_cast(bias), - reinterpret_cast(output), - k, - n, - gelu, - ctx.stream()); - } else { - batched_int8_weight_only_gemv_multi_warp_launcher( - reinterpret_cast(x), - weight, - reinterpret_cast(weight_scale), + if (weight_only_type == "per_channel") { + PADDLE_ENFORCE_EQ(group_size, + -1, + phi::errors::InvalidArgument( + "group size must be -1 in per-channel mode.")); + + weight_only_batched_gemv_launcher( + reinterpret_cast(input), + reinterpret_cast(weight), reinterpret_cast(bias), - reinterpret_cast(output), + reinterpret_cast(scales), m, - k, n, - gelu, - ctx.stream()); + k, + weight_only_quant_type, + act_method, + reinterpret_cast(output), + dev_ctx.stream()); + } else if (weight_only_type == "group_wise") { + if (group_size == 64) { + weight_only_batched_gemv_launcher>( + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), + reinterpret_cast(scales), + m, + n, + k, + weight_only_quant_type, + act_method, + reinterpret_cast(output), + dev_ctx.stream()); + } else if (group_size == 128) { + weight_only_batched_gemv_launcher>( + reinterpret_cast(input), + reinterpret_cast(weight), + reinterpret_cast(bias), + reinterpret_cast(scales), + m, + n, + k, + weight_only_quant_type, + act_method, + reinterpret_cast(output), + dev_ctx.stream()); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "WeightOnlyGemvKernel group_size only support 64 or 128.")); + } + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("WeightOnlyGemvKernel type only support " + "'per_channel' or 'group_wise'.")); } } +template <> +void WeightOnlyGemvWrapper(const phi::GPUContext& dev_ctx, + const float* input, + const int8_t* weight, + const float* bias, + const float* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + float* output) { + PADDLE_THROW(phi::errors::Unimplemented( + "WeightOnlyGemvKernel type only support 'float16' and 'bfloa16." 
+ "Not support float32.")); +} + template -void GemvWeightonlyInt8Kernel(const Context& dev_ctx, - const DenseTensor& x, - const DenseTensor& weight, - const paddle::optional& bias, - const DenseTensor& weight_scale, - const std::string& act_method, - DenseTensor* out) { +void WeightOnlyGemvKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& weight, + const paddle::optional& bias, + const DenseTensor& weight_scale, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + DenseTensor* out) { const T* x_data = x.data(); - const int8_t* weight_data = - weight.data(); // Actually, we pass the weight datatype is - // uint8_t type. + const int8_t* weight_data = weight.data(); + // Actually, we pass the weight datatype is uint8_t type. const T* bias_data = bias ? bias.get().data() : nullptr; const T* weight_scale_data = weight_scale.data(); T* out_data = dev_ctx.template Alloc(out); int m = x.dims()[0]; int k = x.dims()[1]; int n = weight.dims()[0]; - GemvWeightonlyInt8Wrapper(dev_ctx, - x_data, - weight_data, - bias_data, - weight_scale_data, - m, - n, - k, - act_method, - out_data); -} -template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx, - const phi::dtype::float16* x, - const int8_t* weight, - const phi::dtype::float16* bias, - const phi::dtype::float16* weight_scale, - const int m, - const int n, - const int k, - const std::string& act_method, - phi::dtype::float16* output); - -template void GemvWeightonlyInt8Wrapper( - const phi::GPUContext& ctx, - const phi::dtype::bfloat16* x, - const int8_t* weight, - const phi::dtype::bfloat16* bias, - const phi::dtype::bfloat16* weight_scale, - const int m, - const int n, - const int k, - const std::string& act_method, - phi::dtype::bfloat16* output); - -// template void GemvWeightonlyInt8Wrapper(const phi::GPUContext& ctx, -// const float* x, -// const int8_t* weight, -// const float* bias, -// const float* weight_scale, -// const int m, -// const int n, -// const int k, -// const std::string& act_method, -// float* output); + WeightOnlyGemvWrapper(dev_ctx, + x_data, + weight_data, + bias_data, + weight_scale_data, + m, + n, + k, + group_size, + weight_only_quant_type, + weight_only_type, + act_method, + out_data); +} +template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, + const float* input, + const int8_t* weight, + const float* bias, + const float* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + float* output); + +template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, + const phi::dtype::float16* input, + const int8_t* weight, + const phi::dtype::float16* bias, + const phi::dtype::float16* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + phi::dtype::float16* output); + +template void WeightOnlyGemvWrapper(const phi::GPUContext& ctx, + const phi::dtype::bfloat16* input, + const int8_t* weight, + const phi::dtype::bfloat16* bias, + const phi::dtype::bfloat16* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + phi::dtype::bfloat16* output); } // namespace phi diff --git a/paddle/phi/kernels/funcs/weight_only_gemv.h 
b/paddle/phi/kernels/funcs/weight_only_gemv.h index 8a2cb1d5b4f34..7f0b4aa7fbc2c 100644 --- a/paddle/phi/kernels/funcs/weight_only_gemv.h +++ b/paddle/phi/kernels/funcs/weight_only_gemv.h @@ -19,15 +19,18 @@ limitations under the License. */ namespace phi { template -void GemvWeightonlyInt8Wrapper(const Context& ctx, - const T* x, - const int8_t* weight, - const T* bias, - const T* weight_scale, - const int m, - const int n, - const int k, - const std::string& act_method, - T* output); +void WeightOnlyGemvWrapper(const Context& dev_ctx, + const T* input, + const int8_t* weight, + const T* bias, + const T* scales, + int m, + int n, + int k, + int group_size, + const std::string& weight_only_quant_type, + const std::string& weight_only_type, + const std::string& act_method, + T* output); } // namespace phi diff --git a/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu b/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu index fce785804c344..77e71b950ddfa 100644 --- a/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_dequantize_kernel.cu @@ -29,11 +29,12 @@ void WeightDequantizeKernel(const Context& dev_ctx, const DenseTensor& scale, const std::string& algo, DataType out_dtype, + int32_t group_size, DenseTensor* out) { #if defined(PADDLE_WITH_CUTLASS) auto out_dims = out->dims(); dev_ctx.template Alloc(out); - WeightDequantize(dev_ctx, x, scale, algo, true, out); + WeightDequantize(dev_ctx, x, scale, algo, true, group_size, out); out->Resize({{out_dims[1], out_dims[0]}}); auto out_tmp = Transpose(dev_ctx, *out, {1, 0}); out->ShareDataWith(out_tmp); diff --git a/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu index c5dc7a15db6e4..de6c2742590b3 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu @@ -33,6 +33,7 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, DenseTensor* x_grad) { #if defined(PADDLE_WITH_CUTLASS) PADDLE_ENFORCE_EQ( @@ -41,6 +42,12 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx, phi::errors::InvalidArgument( "Currently weightonly linear grad only support arch = 80 or 86. ")); + PADDLE_ENFORCE_EQ( + group_size, + -1, + phi::errors::InvalidArgument( + "Currently weightonly linear grad only support per-channel mode. ")); + int n = weight_scale.dims()[0]; int k = weight.dims()[1]; dev_ctx.template Alloc(x_grad); @@ -49,8 +56,13 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(&weight_dequantized); std::string algo = weight_dtype == "int8" ? 
"weight_only_int8" : "weight_only_int4"; - WeightDequantize( - dev_ctx, weight, weight_scale, algo, true, &weight_dequantized); + WeightDequantize(dev_ctx, + weight, + weight_scale, + algo, + true, + group_size, + &weight_dequantized); MatmulKernel( dev_ctx, out_grad, weight_dequantized, false, false, x_grad); #else diff --git a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu index 32fb9951bfa47..c41b86148291d 100644 --- a/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_only_linear_kernel.cu @@ -31,6 +31,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, const DenseTensor& weight_scale, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, DenseTensor* out) { #if defined(PADDLE_WITH_CUTLASS) PADDLE_ENFORCE_EQ( @@ -50,12 +51,12 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, T* out_data = out->data(); const auto x_dims = x.dims(); const auto w_dims = weight.dims(); - int n = weight_scale.dims()[0]; + int n = group_size > 0 ? weight_scale.dims()[1] : weight_scale.dims()[0]; int k = w_dims[1]; int m = x.numel() / k; // m > 3: run gemm. - if (m > 3 || weight_dtype == "int4" || (arch == 70)) { + if (m > 3 || (arch == 70)) { /* Note(Zhengzekang): If using arch = 70, we always dispatch to weightonly Gemm, @@ -157,19 +158,39 @@ we havenot support sm70 weightonly gemv, because sm70 weight layout is RowMajor. PADDLE_THROW(phi::errors::Unimplemented( "Please compile with cutlass to make cutlass available")); #endif - } else { // m == 1: gemv + } else { // m <= 3: gemv if (weight_dtype == "int8") { - GemvWeightonlyInt8Wrapper(dev_ctx, - x_data, - weight_data, - bias_data, - weight_scale_data, - m, - n, - k, - "None", - out->data()); - } // TODO(lizhenyun) support weight_only_gemv for int4. + WeightOnlyGemvWrapper( + dev_ctx, + x_data, + weight_data, + bias_data, + weight_scale_data, + m, + n, + k, + group_size, + "int8", + group_size > 0 ? "group_wise" : "per_channel", + "None", + out->data()); + + } else if (weight_dtype == "int4") { + WeightOnlyGemvWrapper( + dev_ctx, + x_data, + weight_data, + bias_data, + weight_scale_data, + m, + n, + k, + group_size, + "int4", + group_size > 0 ? 
"group_wise" : "per_channel", + "None", + out->data()); + } } } } // namespace phi diff --git a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu index 0c0024fc9ece3..8cd5598e2e92a 100644 --- a/paddle/phi/kernels/gpu/weight_quantize_kernel.cu +++ b/paddle/phi/kernels/gpu/weight_quantize_kernel.cu @@ -26,8 +26,15 @@ void WeightQuantizeKernel(const Context& dev_ctx, const DenseTensor& x, const std::string& algo, const int32_t arch, + const int32_t group_size, DenseTensor* out, DenseTensor* scale) { + PADDLE_ENFORCE_EQ( + ((group_size == -1) || (group_size == 64) || (group_size == 128)), + true, + phi::errors::InvalidArgument( + "Currently, group_size only support -1(per-channel), 64 or 128.")); + DenseTensor quanted_x; dev_ctx.template Alloc(out); dev_ctx.template Alloc(scale); diff --git a/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h index d521090816108..2905fd14e6b33 100644 --- a/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h +++ b/paddle/phi/kernels/impl/weight_quantize_kernel_impl.h @@ -56,6 +56,27 @@ void per_channel_scale( } } +template +void group_wise_scale(T* scale, + const T* input, + size_t m, + size_t n, + float bound, + size_t group_size) { + for (size_t i = 0; i < n; ++i) { + for (size_t j = 0; j < m; j += group_size) { + float max = static_cast(0.f); + for (size_t k = 0; k < group_size && j + k < m; ++k) { + max = static_cast(xabs(input[(j + k) * n + i])) > max + ? static_cast(xabs(input[(j + k) * n + i])) + : max; + } + scale[static_cast(j / group_size) * n + i] = + static_cast(max / bound); + } + } +} + template void per_channel_quant(int8_t* output, const T* input, @@ -102,6 +123,55 @@ void per_channel_quant(int8_t* output, } } +template +void group_wise_quant(int8_t* output, + const T* input, + const T* scale, + size_t num_rows, + size_t num_cols, + const int group_size) { + size_t bytes_per_out_col = num_cols * quant_bit / 8; + for (size_t ii = 0; ii < num_rows; ++ii) { + int8_t* current_quantized_weight_row = output + ii * bytes_per_out_col; + const T* current_weight_row = input + ii * num_cols; + for (size_t jj = 0; jj < bytes_per_out_col; ++jj) { + if (quant_bit == 8) { + size_t scale_cur_offset = jj + (ii / group_size) * num_cols; + const float col_scale = static_cast(scale[scale_cur_offset]); + const float weight_elt = static_cast(current_weight_row[jj]); + const float scaled_weight = round(weight_elt / col_scale); + const int8_t clipped_weight = static_cast( + std::max(-127.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } else if (quant_bit == 4) { + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) { + const size_t input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) { + size_t scale_cur_offset = input_idx + (ii / group_size) * num_cols; + const float col_scale = static_cast(scale[scale_cur_offset]); + const float weight_elt = + static_cast(current_weight_row[input_idx]); + const float scaled_weight = round(weight_elt / col_scale); + int int_weight = static_cast(scaled_weight); + const int8_t clipped_weight = std::max(-7, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to + // upper bits if packing the second int4 and or the bits into the + // final result. 
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } else { + phi::errors::Unimplemented("Unsupported quantization bits: %d", + quant_bit); + } + } + } +} + template void add_bias_and_interleave_inplace(int8_t* tensor_ptr, size_t num_elts) { const size_t num_bytes = num_elts * quant_bit / 8; diff --git a/paddle/phi/kernels/weight_dequantize_kernel.h b/paddle/phi/kernels/weight_dequantize_kernel.h index 3a0a10924b57e..59bc406d3b689 100644 --- a/paddle/phi/kernels/weight_dequantize_kernel.h +++ b/paddle/phi/kernels/weight_dequantize_kernel.h @@ -24,6 +24,7 @@ void WeightDequantizeKernel(const Context& dev_ctx, const DenseTensor& scale, const std::string& algo, DataType out_dtype, + int32_t group_size, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/weight_only_linear_grad_kernel.h b/paddle/phi/kernels/weight_only_linear_grad_kernel.h index af05059c488f3..5ac26f03b9e65 100644 --- a/paddle/phi/kernels/weight_only_linear_grad_kernel.h +++ b/paddle/phi/kernels/weight_only_linear_grad_kernel.h @@ -27,6 +27,7 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/weight_only_linear_kernel.h b/paddle/phi/kernels/weight_only_linear_kernel.h index 17037fb531f06..4ec3bbcd82ead 100644 --- a/paddle/phi/kernels/weight_only_linear_kernel.h +++ b/paddle/phi/kernels/weight_only_linear_kernel.h @@ -26,5 +26,6 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, const DenseTensor& weight_scale, const std::string& weight_dtype, const int32_t arch, + const int32_t group_size, DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/weight_quantize_kernel.h b/paddle/phi/kernels/weight_quantize_kernel.h index b906e68a40338..17adb5e21d59c 100644 --- a/paddle/phi/kernels/weight_quantize_kernel.h +++ b/paddle/phi/kernels/weight_quantize_kernel.h @@ -23,6 +23,7 @@ void WeightQuantizeKernel(const Context& dev_ctx, const DenseTensor& x, const std::string& algo, const int32_t arch, + const int32_t group_size, DenseTensor* out, DenseTensor* scale); diff --git a/python/paddle/nn/quant/quantized_linear.py b/python/paddle/nn/quant/quantized_linear.py index 155ea48c063aa..059ecc463605f 100644 --- a/python/paddle/nn/quant/quantized_linear.py +++ b/python/paddle/nn/quant/quantized_linear.py @@ -36,7 +36,7 @@ def _get_arch_info(): ) -def weight_quantize(x, algo="weight_only_int8", arch=None): +def weight_quantize(x, algo="weight_only_int8", arch=None, group_size=-1): """ Quantization function for weight_only and llm.int8's weight. @@ -45,6 +45,7 @@ def weight_quantize(x, algo="weight_only_int8", arch=None): algo (str): The algo that is x will be apply, must be one of 'weight_only_int8', 'weight_only_int4' and 'llm.int8', default: 'weight_only_int8'. arch (int): The compute arch for target device. For example, A100 is 80, v100 is 70, if you do not assign arch, we will get arch from your device, default: None. + group_size (int): The group size for weight quantization. -1 stands for default per-channel mode. Currently only support 64 or 128. Returns: out (Tensor): The Tensor which is the quantitative results, the data type is int8, the shape is transposition of x. 
@@ -71,8 +72,12 @@ def weight_quantize(x, algo="weight_only_int8", arch=None): arch == 70 or arch == 80 or arch == 86 or arch == 75 ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} " + assert ( + group_size == -1 or group_size == 64 or group_size == 128 + ), f"Currently group_size only support -1/64/128. but got {group_size} " + if in_dynamic_mode(): - return _C_ops.weight_quantize(x, algo, arch) + return _C_ops.weight_quantize(x, algo, arch, group_size) else: type = "weight_quantize" helper = LayerHelper(type, **locals()) @@ -83,12 +88,14 @@ def weight_quantize(x, algo="weight_only_int8", arch=None): type=type, inputs={"x": x}, outputs={'out': out, "scale": scale}, - attrs={"algo": algo, "arch": arch}, + attrs={"algo": algo, "arch": arch, "group_size": group_size}, ) return (out, scale) -def weight_dequantize(x, scale, algo="weight_only_int8", out_dtype='float16'): +def weight_dequantize( + x, scale, algo="weight_only_int8", out_dtype='float16', group_size=-1 +): """ Dequantization function for weight_only and llm.int8's weight. @@ -114,12 +121,16 @@ def weight_dequantize(x, scale, algo="weight_only_int8", out_dtype='float16'): >>> out, scale = weight_quantize(x, algo='weight_only_int8') >>> x_dequant = weight_dequantize(out, scale) """ + assert ( + group_size == -1 or group_size == 64 or group_size == 128 + ), f"Currently group_size only support -1/64/128. but got {group_size} " + check_dtype( out_dtype, 'out_dtype', ['float16', 'bfloat16'], 'weight_dequantize' ) out_dtype = convert_np_dtype_to_dtype_(out_dtype) if in_dynamic_mode(): - return _C_ops.weight_dequantize(x, scale, algo, out_dtype) + return _C_ops.weight_dequantize(x, scale, algo, out_dtype, group_size) else: type = "weight_dequantize" helper = LayerHelper(type, **locals()) @@ -129,13 +140,23 @@ def weight_dequantize(x, scale, algo="weight_only_int8", out_dtype='float16'): type=type, inputs={"x": x, "scale": scale}, outputs={'out': out}, - attrs={"algo": algo, "out_dtype": out_dtype}, + attrs={ + "algo": algo, + "out_dtype": out_dtype, + "group_size": group_size, + }, ) return out def weight_only_linear( - x, weight, bias=None, weight_scale=None, weight_dtype="int8", arch=None + x, + weight, + bias=None, + weight_scale=None, + weight_dtype="int8", + arch=None, + group_size=-1, ): """ Applies matrix multiplication of two tensors and then bias addition if provided. @@ -149,6 +170,7 @@ def weight_only_linear( weight_scale (Tensor|None): The input scale Tensor Provided to weight for dequantization. Its rank must be 1. weight_dtype(str): The dtype of weight Tensor, must be one of 'int8', 'int4', Defaulted to 'int8'. arch (int): The compute arch for target device. For example, A100 is 80, v100 is 70, if you do not assign arch, we will get arch from your device, default: None. + group_size (int): The group size for weight quantization. -1 stands for default per-channel mode. Currently only support 64 or 128. Returns: Tensor: the output Tensor, the data type is the same as that of x. @@ -174,10 +196,13 @@ def weight_only_linear( assert ( arch == 70 or arch == 80 or arch == 86 or arch == 75 ), f"Currently weight_quantize only support SM70/75/80/86. but got {arch} " + assert ( + group_size == -1 or group_size == 64 or group_size == 128 + ), f"Currently weight_quantize only support group size of -1, 64 or 128. 
but got {group_size} " if in_dynamic_mode(): out = _C_ops.weight_only_linear( - x, weight, bias, weight_scale, weight_dtype, arch + x, weight, bias, weight_scale, weight_dtype, arch, group_size ) return out else: @@ -195,7 +220,11 @@ def weight_only_linear( } if bias is not None: inputs["bias"] = [bias] - attrs = {'weight_dtype': weight_dtype, 'arch': arch} + attrs = { + 'weight_dtype': weight_dtype, + 'arch': arch, + 'group_size': group_size, + } out = helper.create_variable_for_type_inference(dtype) diff --git a/test/quantization/test_weight_only_linear.py b/test/quantization/test_weight_only_linear.py index c7bbc1c658267..f3749d0b4fb15 100644 --- a/test/quantization/test_weight_only_linear.py +++ b/test/quantization/test_weight_only_linear.py @@ -72,6 +72,7 @@ def config(self): self.out_features = 256 self.weight_dtype = "int8" self.static = False + self.group_size = -1 def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): for arch in [70, 75, 80, 86]: @@ -83,6 +84,7 @@ def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): if self.weight_dtype == "int8" else "weight_only_int4", arch=arch, + group_size=self.group_size, ) weight_cpu, weight_scale_cpu = Q.weight_quantize( weight_float.cpu(), @@ -90,6 +92,7 @@ def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): if self.weight_dtype == "int8" else "weight_only_int4", arch=arch, + group_size=self.group_size, ) np.testing.assert_allclose( weight_gpu.numpy(), weight_cpu.numpy(), atol=1.5 @@ -106,7 +109,7 @@ def weightQuantizeCPUGPUConsistenceCheck(self, weight_float): def setUp(self): self.config() if self.dtype == "bfloat16" or self.weight_dtype == "int4": - self.atol = 1e-1 + self.atol = 1.5e-1 x = np.random.random((self.batch, self.token, self.in_features)) self.x = paddle.to_tensor(x, dtype=self.dtype) if self.bias: @@ -136,6 +139,7 @@ def setUp(self): algo="weight_only_int8" if self.weight_dtype == "int8" else "weight_only_int4", + group_size=self.group_size, ) def get_linear_out(self): @@ -149,6 +153,7 @@ def get_weight_only_linear_out(self): bias=self.bias, weight_scale=self.weight_scale, weight_dtype=self.weight_dtype, + group_size=self.group_size, ) return out.numpy() @@ -185,6 +190,7 @@ def get_weight_only_linear_out_static(self): bias, weight_scale, self.weight_dtype, + group_size=self.group_size, ) feed_dict = { 'x': x_np, @@ -351,59 +357,188 @@ def config(self): not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", ) -class WeightOnlyLinearTestCaseStatic(WeightOnlyLinearTestCase): +class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase): def config(self): super().config() - self.static = True + self.dtype = 'float16' + self.weight_dtype = "int4" + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11020 - or paddle.device.cuda.get_device_capability()[0] < 8, + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase): 
+ def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase14(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase15(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 64 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase16(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", ) -class WeightOnlyLinearBackwardAndWeightDequantizeTestCase(unittest.TestCase): - def test_weightonly_linear_backward(self): - x = ( - paddle.rand(shape=(128, 4096), dtype='float16') - * 1 - / math.sqrt(4096) - ) - x.stop_gradient = False - quant_x = copy.deepcopy(x) - quant_x.stop_gradient = False - weight = ( - paddle.rand(shape=(4096, 12288), dtype='float16') - * 1 - / math.sqrt(4096) - ) +class WeightOnlyLinearTestCase17(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 64 - quant_weight, quant_scale = Q.weight_quantize( - x=weight.cuda(), algo='weight_only_int8' - ) - dequant_weight = Q.weight_dequantize(quant_weight.cuda(), quant_scale) - np.testing.assert_allclose(weight, dequant_weight, rtol=1e-2, atol=1e-2) - quant_out = Q.weight_only_linear( - x=quant_x, - weight=quant_weight, - weight_scale=quant_scale, - weight_dtype="int8", - ) - out = paddle.matmul(x=x, y=weight) - np.testing.assert_allclose(quant_out, out, rtol=1e-3, atol=1e-3) +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCase18(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'float16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 128 - quant_out.backward() - out.backward() - np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3) + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 
+ or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase19(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int4" + self.bias = False + self.batch = 1 + self.token = 2 + self.group_size = 128 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase20(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 64 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8 + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8 or core is not support bfloat16", +) +class WeightOnlyLinearTestCase21(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.dtype = 'bfloat16' + self.weight_dtype = "int8" + self.bias = False + self.batch = 1 + self.token = 1 + self.group_size = 128 @unittest.skipIf( not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", ) -class WeightOnlyLinearTestCase11(WeightOnlyLinearTestCase): +class WeightOnlyLinearTestCase22(WeightOnlyLinearTestCase): def config(self): super().config() self.dtype = 'float16' @@ -416,7 +551,7 @@ def config(self): not core.is_compiled_with_cuda() or get_cuda_version() < 11020, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", ) -class WeightOnlyLinearTestCase12(WeightOnlyLinearTestCase): +class WeightOnlyLinearTestCase23(WeightOnlyLinearTestCase): def config(self): super().config() self.dtype = 'float16' @@ -432,7 +567,7 @@ def config(self): or paddle.device.cuda.get_device_capability()[0] < 8, "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", ) -class WeightOnlyLinearTestCase13(WeightOnlyLinearTestCase): +class WeightOnlyLinearTestCase24(WeightOnlyLinearTestCase): def config(self): super().config() self.dtype = 'bfloat16' @@ -441,5 +576,57 @@ def config(self): self.out_features = 288 +@unittest.skipIf( + not core.is_compiled_with_cuda() or get_cuda_version() < 11020, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearTestCaseStatic(WeightOnlyLinearTestCase): + def config(self): + super().config() + self.static = True + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or get_cuda_version() < 11020 + or paddle.device.cuda.get_device_capability()[0] < 8, + "quantized_matmul requires CUDA >= 11.2 and CUDA_ARCH >= 8", +) +class WeightOnlyLinearBackwardAndWeightDequantizeTestCase(unittest.TestCase): + def test_weightonly_linear_backward(self): + x = ( + paddle.rand(shape=(128, 4096), dtype='float16') + * 1 + / math.sqrt(4096) + ) + x.stop_gradient = False + quant_x = copy.deepcopy(x) + quant_x.stop_gradient = False + weight = ( + paddle.rand(shape=(4096, 12288), dtype='float16') + * 1 + / math.sqrt(4096) + ) + + quant_weight, quant_scale = Q.weight_quantize( + x=weight.cuda(), algo='weight_only_int8' + ) 
+ dequant_weight = Q.weight_dequantize(quant_weight.cuda(), quant_scale) + np.testing.assert_allclose(weight, dequant_weight, rtol=1e-2, atol=1e-2) + + quant_out = Q.weight_only_linear( + x=quant_x, + weight=quant_weight, + weight_scale=quant_scale, + weight_dtype="int8", + ) + out = paddle.matmul(x=x, y=weight) + np.testing.assert_allclose(quant_out, out, rtol=1e-3, atol=1e-3) + + quant_out.backward() + out.backward() + np.testing.assert_allclose(quant_x.grad, x.grad, rtol=1e-3, atol=1e-3) + + if __name__ == '__main__': unittest.main()
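Note on the group-wise layout used by group_wise_scale / group_wise_quant above: one scale is kept per group_size rows of the (k, n) weight, so the scale tensor has shape [k / group_size, n] instead of the per-channel [n]. The NumPy sketch below only illustrates that bookkeeping; the bound of 127, the helper names, and the plain row-major layout are assumptions for the example, and the real kernels additionally pack int4 pairs and permute the weight into the CUTLASS/gemv layouts.

import numpy as np

def group_wise_quantize_int8(w, group_size=64):
    # w: (k, n) float weight; returns int8 weight plus [k / group_size, n] scales.
    k, n = w.shape
    assert k % group_size == 0, "k must be divisible by group_size"
    grouped = w.reshape(k // group_size, group_size, n)
    # Per (group, column) scale: max |w| in the group divided by the int8 bound.
    scale = np.maximum(np.abs(grouped).max(axis=1), 1e-8) / 127.0
    q = np.rint(grouped / scale[:, None, :]).clip(-127, 127)
    return q.reshape(k, n).astype(np.int8), scale.astype(w.dtype)

def group_wise_dequantize(q, scale, group_size=64):
    k, n = q.shape
    grouped = q.reshape(k // group_size, group_size, n).astype(scale.dtype)
    return (grouped * scale[:, None, :]).reshape(k, n)

w = np.random.randn(256, 64).astype(np.float32)
q, s = group_wise_quantize_int8(w, group_size=64)
print(q.shape, s.shape)                                  # (256, 64) (4, 64)
print(np.abs(w - group_wise_dequantize(q, s, 64)).max()) # small reconstruction error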
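For the Python API changed in quantized_linear.py, a minimal end-to-end sketch of the new group_size argument follows. It assumes a CUDA build of Paddle with cutlass on an SM80/SM86 GPU; the shapes are illustrative, and group_size=-1 keeps the previous per-channel behaviour.

import paddle
from paddle.nn.quant import weight_dequantize, weight_only_linear, weight_quantize

x = paddle.rand([2, 16, 4096], dtype='float16')
w = paddle.rand([4096, 12288], dtype='float16')

# group_size=128 stores scales of shape [k / group_size, n] = [32, 12288];
# only -1 (per-channel), 64 and 128 are accepted.
qw, scale = weight_quantize(w, algo='weight_only_int8', group_size=128)

out = weight_only_linear(
    x, qw, weight_scale=scale, weight_dtype='int8', group_size=128
)

# Optional round-trip check against the dequantized weight.
w_dq = weight_dequantize(qw, scale, algo='weight_only_int8', group_size=128)
ref = paddle.matmul(x, w_dq)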