[inference] Support groupwise mode of gemv kernel #60204

Merged
13 commits merged on Dec 27, 2023
@@ -96,9 +96,14 @@ class FusedWeightOnlyLinearPattern
return getSMVersion();
});

const auto &group_size_attr = res.Attr(
[](const pir::drr::MatchContext &match_ctx) -> int { return -1; });

const auto &weight_quantize =
res.Op(paddle::dialect::WeightQuantizeOp::name(),
{{"algo", weight_only_int8_attr}, {"arch", arch_attr}});
{{"algo", weight_only_int8_attr},
{"arch", arch_attr},
{"group_size", group_size_attr}});
weight_quantize({&res.Tensor("w")},
{&res.Tensor("quanted_weight_tensor"),
&res.Tensor("weight_scale_tensor")});
@@ -110,7 +115,9 @@ class FusedWeightOnlyLinearPattern

const auto &weight_only_linear =
res.Op(paddle::dialect::WeightOnlyLinearOp::name(),
{{"weight_dtype", weight_dtype_attr}, {"arch", arch_attr}});
{{"weight_dtype", weight_dtype_attr},
{"arch", arch_attr},
{"group_size", group_size_attr}});
weight_only_linear({&res.Tensor("x"),
&res.Tensor("quanted_weight_tensor"),
&res.Tensor("bias"),
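A note on the pattern rewrite above: the pass pins group_size to -1 on both weight_quantize and weight_only_linear, so the fused path keeps producing per-channel quantization even though the ops themselves now accept group-wise scales. As a purely illustrative sketch (not part of this PR), a group-wise variant would only need the attribute lambda to return a real group size:

// Hypothetical: emit group-wise quantization with 128 rows per scale group.
const auto &group_size_attr = res.Attr(
    [](const pir::drr::MatchContext &match_ctx) -> int { return 128; });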
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
@@ -2603,8 +2603,8 @@
no_need_buffer : input

- backward_op : weight_only_linear_grad
forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch)
forward : weight_only_linear(Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch, int group_size) -> Tensor(out)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, Tensor out_grad, str weight_dtype, int arch, int group_size)
output : Tensor(x_grad)
infer_meta :
func : WeightOnlyLinearGradInferMeta
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/op_version.yaml
@@ -472,6 +472,30 @@
comment : The axis to apply unique. If None, the input will be flattened.
default : std::vector<int>{}

- op : weight_dequantize
version :
- checkpoint : Upgrade weight_dequantize, add a new attribute [group_size]
action :
- add_attr : group_size
comment : The group size of the dequantization scales.
default : -1

- op : weight_only_linear
version :
- checkpoint : Upgrade weight_only_linear, add a new attribute [group_size]
action :
- add_attr : group_size
comment : The group size of the dequantization scales.
default : -1

- op : weight_quantize
version :
- checkpoint : Upgrade weight_quantize, add a new attribute [group_size]
action :
- add_attr : group_size
comment : The group size of the quantization scales.
default : -1

- op : yolo_box
version :
- checkpoint : Upgrade yolo box to add new attribute [iou_aware, iou_aware_factor].
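The three op_version checkpoints above are what keep previously saved programs loadable: an op serialized without group_size is upgraded by filling in the default of -1, which is exactly the old per-channel behaviour. A minimal sketch of that fallback semantics (an assumption about intent; the real upgrade logic lives in Paddle's compatibility pass and is not shown here):

#include <map>
#include <string>

// If a saved op carries no "group_size" attribute, treat it as per-channel.
inline int GroupSizeOrDefault(const std::map<std::string, int> &int_attrs) {
  auto it = int_attrs.find("group_size");
  return it == int_attrs.end() ? -1 : it->second;
}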
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/ops.yaml
@@ -2833,7 +2833,7 @@
backward : warprnnt_grad

- op : weight_dequantize
args : (Tensor x, Tensor scale, str algo="weight_only_int8", DataType out_dtype=DataType::FLOAT16)
args : (Tensor x, Tensor scale, str algo = "weight_only_int8", DataType out_dtype = DataType::FLOAT16, int group_size = -1)
output : Tensor(out)
infer_meta :
func : WeightDequantizeInferMeta
@@ -2842,7 +2842,7 @@
data_type : out_dtype

- op : weight_only_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80)
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype, int arch = 80, int group_size = -1)
output : Tensor(out)
infer_meta :
func : WeightOnlyLinearInferMeta
@@ -2853,7 +2853,7 @@
backward: weight_only_linear_grad

- op : weight_quantize
args : (Tensor x, str algo = "weight_only_int8", int arch = 80)
args : (Tensor x, str algo = "weight_only_int8", int arch = 80, int group_size = -1)
output : Tensor(out), Tensor(scale)
infer_meta :
func : WeightQuantizeInferMeta
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -1191,12 +1191,19 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
const int32_t group_size,
MetaTensor* x_grad) {
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 86)),
true,
phi::errors::InvalidArgument(
"Currently weightonly linear grad only support arch = 80 or 86. "));
PADDLE_ENFORCE_EQ(
group_size,
-1,
phi::errors::InvalidArgument(
"Currently weightonly linear grad only support per-channel mode. "));

x_grad->set_dims(x.dims());
x_grad->set_dtype(x.dtype());
}
1 change: 1 addition & 0 deletions paddle/phi/infermeta/backward.h
@@ -469,6 +469,7 @@ void WeightOnlyLinearGradInferMeta(const MetaTensor& x,
const MetaTensor& out_grad,
const std::string& weight_dtype,
const int32_t arch,
const int32_t group_size,
MetaTensor* x_grad);

void YoloLossGradInferMeta(const MetaTensor& x,
51 changes: 39 additions & 12 deletions paddle/phi/infermeta/binary.cc
@@ -3381,25 +3381,52 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const std::string& algo,
DataType out_dtype,
const int32_t group_size,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(x.dims().size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of dequantize op must be 2D, but got[%d]",
x.dims().size()));
PADDLE_ENFORCE_EQ(
scale.dims().size(),
1UL,
phi::errors::InvalidArgument(
"The scale tensor of dequantize op must be 1D, but got[%d]",
scale.dims().size()));
PADDLE_ENFORCE_EQ(scale.dims()[0],
x.dims()[0],
phi::errors::InvalidArgument(
"The scale tensor's shape must be equal to the x "
"tensor's shape, but got [%d] not equal to [%d]",
scale.dims()[0],
x.dims()[0]));
(group_size == -1 || group_size == 64 || group_size == 128),
true,
phi::errors::InvalidArgument("group_size must be -1, 64 or 128."));

auto dim_scale = scale.dims();

// per-channel dequantization
if (group_size == -1) {
PADDLE_ENFORCE_EQ(
dim_scale.size(),
1UL,
phi::errors::InvalidArgument("The scale tensor of dequantize op must "
"be 1D in per-channel mode, but got[%d]",
scale.dims().size()));
PADDLE_ENFORCE_EQ(dim_scale[0],
x.dims()[0],
phi::errors::InvalidArgument(
"The scale tensor's shape must be equal to the x "
"tensor's shape, but got [%d] not equal to [%d]",
scale.dims()[0],
x.dims()[0]));
} else /* groupwise dequantization */ {
PADDLE_ENFORCE_EQ(
dim_scale.size(),
2UL,
phi::errors::InvalidArgument("The scale tensor of dequantize op must "
"be 2D in group-wise mode, but got[%d]",
scale.dims().size()));
PADDLE_ENFORCE_EQ(
dim_scale[0],
(x.dims()[1] + (group_size - 1)) / group_size,
errors::InvalidArgument("The input(weight_scale) dim[0] must be equal "
"to (Input(weight).dim[1] + (group_size -1))"
" / group_size"
"But receive %d and %d",
dim_scale[0],
(x.dims()[1] + (group_size - 1)) / group_size));
}
int n = x.dims()[1];
int k = x.dims()[0];
out->set_dims(common::make_ddim({n, k}));
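To make the two branches above concrete: only the shape of scale changes with the mode. A standalone restatement of the check follows (the function name and parameters are illustrative, not Paddle code); for example, a [4096, 4096] quantized weight with group_size = 64 expects scale.dims()[0] == 64 in group-wise mode, versus a 1-D scale of length 4096 per-channel.

#include <cstdint>

// Mirrors WeightDequantizeInferMeta's scale checks for a quantized weight x
// of shape [x_d0, x_d1]: per-channel needs a 1-D scale of length x_d0;
// group-wise needs a 2-D scale whose first dim is ceil(x_d1 / group_size).
bool DequantScaleShapeOk(int64_t x_d0, int64_t x_d1, int scale_rank,
                         int64_t scale_d0, int group_size) {
  if (group_size == -1) {
    return scale_rank == 1 && scale_d0 == x_d0;
  }
  return scale_rank == 2 &&
         scale_d0 == (x_d1 + (group_size - 1)) / group_size;
}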
1 change: 1 addition & 0 deletions paddle/phi/infermeta/binary.h
@@ -539,6 +539,7 @@ void WeightDequantizeInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const std::string& algo,
DataType out_dtype,
const int32_t group_size,
MetaTensor* out);

} // namespace phi
36 changes: 31 additions & 5 deletions paddle/phi/infermeta/multiary.cc
@@ -3943,21 +3943,23 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
const int32_t group_size,
MetaTensor* out) {
PADDLE_ENFORCE((group_size == -1 || group_size == 64 || group_size == 128),
errors::InvalidArgument("group_size must be -1, 64 or 128."));

auto weight_scale_dims = weight_scale.dims();

auto x_dims = x.dims();
auto w_dims = weight.dims();
auto n = weight_scale.dims()[0];
auto n = group_size == -1 ? weight_scale_dims[0] : weight_scale_dims[1];
PADDLE_ENFORCE(
weight_dtype == "int8" || weight_dtype == "int4",
errors::InvalidArgument("quant_method must be 'int8' or 'int4'."));
PADDLE_ENFORCE_EQ(
w_dims.size(),
2UL,
errors::InvalidArgument("The input(weight) must be a 2D Tensor."));
PADDLE_ENFORCE_EQ(
weight_scale.dims().size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."));
PADDLE_ENFORCE_EQ(
w_dims[0] % 16,
0,
@@ -3978,6 +3980,30 @@
"But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)",
x_dims[x_dims.size() - 1],
w_dims[1]));

// per-channel dequantization
if (group_size == -1) {
PADDLE_ENFORCE_EQ(
weight_scale_dims.size(),
1UL,
errors::InvalidArgument("The input(weight_scale) must be a 1D Tensor."
"in per-channel mode."));
} else /* groupwise dequantization */ {
PADDLE_ENFORCE_EQ(
weight_scale_dims.size(),
2UL,
errors::InvalidArgument("The input(weight_scale) must be a 2D Tensor"
" in groupwise mode."));
PADDLE_ENFORCE_EQ(
weight_scale_dims[0],
(w_dims[1] + (group_size - 1)) / group_size,
errors::InvalidArgument("The input(weight_scale) dim[0] must be equal "
"to Input(weight) dim[1] / group_size"
"But receive %d and %d",
weight_scale_dims[0],
(w_dims[1] + (group_size - 1)) / group_size));
}

auto out_dims = x_dims;
out_dims[out_dims.size() - 1] = n;
out->set_dims(out_dims);
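The shape inference above reads the output-feature count n from weight_scale.dims()[0] in per-channel mode and from weight_scale.dims()[1] in group-wise mode, then replaces the last dimension of x with n. A compact standalone restatement (illustrative only, not the Paddle implementation):

#include <cstdint>
#include <vector>

// Output dims of weight_only_linear: all leading dims of x are kept and the
// last dim becomes n, taken from the appropriate axis of weight_scale.
std::vector<int64_t> WeightOnlyLinearOutDims(
    std::vector<int64_t> x_dims, const std::vector<int64_t> &scale_dims,
    int group_size) {
  const int64_t n = (group_size == -1) ? scale_dims[0] : scale_dims[1];
  x_dims.back() = n;
  return x_dims;
}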
1 change: 1 addition & 0 deletions paddle/phi/infermeta/multiary.h
@@ -720,6 +720,7 @@ void WeightOnlyLinearInferMeta(const MetaTensor& x,
const MetaTensor& weight_scale,
const std::string& weight_dtype,
const int32_t arch,
const int32_t group_size,
MetaTensor* out);

void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
17 changes: 16 additions & 1 deletion paddle/phi/infermeta/unary.cc
@@ -5203,6 +5203,7 @@ void UnStackInferMeta(const MetaTensor& x,
void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
const int32_t arch,
const int32_t group_size,
MetaTensor* out,
MetaTensor* scale) {
PADDLE_ENFORCE_EQ(
@@ -5229,7 +5230,21 @@
phi::errors::InvalidArgument(
"The second dimension of input must be divisible by 16, but got[%d]",
x_dims[1]));
std::vector<int64_t> dim_scale({x_dims[1]});
PADDLE_ENFORCE_EQ(
((group_size == -1) || (group_size == 64) || (group_size == 128)),
true,
phi::errors::InvalidArgument(
"Currently, group_size only support -1, 64 or 128."));

std::vector<int64_t> dim_scale;
if (group_size != -1) {
int64_t scale_dim0 = (x_dims[0] + (group_size - 1)) / group_size;
int64_t scale_dim1 = x_dims[1];
dim_scale = std::vector<int64_t>({scale_dim0, scale_dim1});
} else {
dim_scale = std::vector<int64_t>({x_dims[1]});
}

std::vector<int64_t> dim_out;
if (algo == "weight_only_int8" || algo == "llm.int8") {
dim_out = std::vector<int64_t>({x_dims[1], x_dims[0]});
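On the quantize side, the scale produced above has shape [x_dims[1]] per-channel and [(x_dims[0] + group_size - 1) / group_size, x_dims[1]] group-wise. A short sketch with a worked example, assuming the usual convention that x is a [k, n] weight (function and parameter names are illustrative):

#include <cstdint>
#include <vector>

// Scale shape emitted by weight_quantize for a [k, n] weight
// (k = x_dims[0], n = x_dims[1] in WeightQuantizeInferMeta).
std::vector<int64_t> QuantScaleDims(int64_t k, int64_t n, int group_size) {
  if (group_size == -1) {
    return {n};                                    // one scale per channel
  }
  return {(k + group_size - 1) / group_size, n};   // one scale per group, channel
}
// e.g. QuantScaleDims(4096, 11008, 128) -> {32, 11008}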
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
@@ -477,6 +477,7 @@ void QuantizeXPUInferMeta(const MetaTensor& x,
void WeightQuantizeInferMeta(const MetaTensor& x,
const std::string& algo,
const int32_t arch,
const int32_t group_size,
MetaTensor* out,
MetaTensor* scale);

29 changes: 22 additions & 7 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
@@ -28,7 +28,8 @@ void quant_compute(const DeviceContext& dev_ctx,
DenseTensor* out,
DenseTensor* scale,
const std::string& algo,
const int32_t arch) {
const int32_t arch,
const int32_t group_size) {
PADDLE_ENFORCE_EQ(
((arch == 80) || (arch == 86) || (arch == 75) || (arch == 70)),
true,
@@ -51,7 +52,8 @@

DenseTensor x_int(out->type());

if ((arch == 80) || (arch == 75) || (arch == 86)) {
if ((arch == 80) || (arch == 75) || (arch == 86) || (arch == 89) ||
(arch == 90)) {
x_int.Resize({static_cast<int64_t>(m), static_cast<int64_t>(n)});
} else {
// phi::Copy may change tensor meta info, here we transpose the quanted
@@ -71,9 +73,19 @@
int_processed_2.Resize(out->dims());
dev_ctx.template Alloc<D>(&int_processed_2);
D* int_processed_2_data = int_processed_2.data<D>();
per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f);

per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
if (group_size == -1) {
per_channel_scale(scale_data, x_data, m, n, bits == 8 ? 127.0f : 7.0f);
per_channel_quant<T, bits>(x_int_data, x_data, scale_data, m, n);
} else {
group_wise_scale(scale_data,
x_data,
m,
n,
bits == 8 ? 127.0f : 7.0f,
static_cast<size_t>(group_size));

group_wise_quant<T, bits>(x_int_data, x_data, scale_data, m, n, group_size);
}
if (algo == "llm.int8") {
std::vector<int> axis = {1, 0};
funcs::Transpose<DeviceContext, int8_t, 2> trans;
@@ -105,14 +117,17 @@ void WeightQuantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& algo,
const int32_t arch,
const int32_t group_size,
DenseTensor* out,
DenseTensor* scale) {
dev_ctx.template Alloc<int8_t>(out);
dev_ctx.template Alloc<T>(scale);
if (algo == "weight_only_int8" || algo == "llm.int8") {
quant_compute<Context, T, int8_t, 8>(dev_ctx, x, out, scale, algo, arch);
quant_compute<Context, T, int8_t, 8>(
dev_ctx, x, out, scale, algo, arch, group_size);
} else if (algo == "weight_only_int4") {
quant_compute<Context, T, int8_t, 4>(dev_ctx, x, out, scale, algo, arch);
quant_compute<Context, T, int8_t, 4>(
dev_ctx, x, out, scale, algo, arch, group_size);
} else {
phi::errors::Unimplemented(
"The algo must be in ['weight_only_int8', 'weight_only_int4', "
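The CPU kernel above dispatches to group_wise_scale and group_wise_quant, whose definitions are outside this diff. As a rough, self-contained reference for what group-wise absmax int8 quantization computes (an assumption about those helpers' semantics, written independently of Paddle; the real functions may use a different scale layout or rounding):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// For an m x n row-major weight, keep one scale per (group of rows, column):
// scale[g][j] = max_{i in group g} |w[i][j]| / 127, then quantize
// q[i][j] = round(w[i][j] / scale[group(i)][j]).
void GroupWiseAbsMaxQuant(const std::vector<float> &w, size_t m, size_t n,
                          size_t group_size, std::vector<float> *scale,
                          std::vector<int8_t> *q) {
  const size_t groups = (m + group_size - 1) / group_size;
  scale->assign(groups * n, 0.0f);
  q->assign(m * n, 0);
  // Pass 1: per-group, per-column absolute maximum.
  for (size_t i = 0; i < m; ++i) {
    const size_t g = i / group_size;
    for (size_t j = 0; j < n; ++j) {
      (*scale)[g * n + j] =
          std::max((*scale)[g * n + j], std::fabs(w[i * n + j]));
    }
  }
  for (auto &s : *scale) s /= 127.0f;
  // Pass 2: quantize each element with its group's scale.
  for (size_t i = 0; i < m; ++i) {
    const size_t g = i / group_size;
    for (size_t j = 0; j < n; ++j) {
      const float s = (*scale)[g * n + j];
      (*q)[i * n + j] = static_cast<int8_t>(
          std::lround(s == 0.0f ? 0.0f : w[i * n + j] / s));
    }
  }
}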