Skip to content

Commit

Permalink
add weight_dequantize python api (PaddlePaddle#57844)
Browse files Browse the repository at this point in the history
* add weight_dequantize python api

* fix comment

* fix doctest

* update

* update
  • Loading branch information
yuanlehome authored Oct 10, 2023
1 parent 60d3cb5 commit 17d7383
Show file tree
Hide file tree
Showing 17 changed files with 337 additions and 131 deletions.
11 changes: 10 additions & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,15 @@
intermediate: warprnntgrad
backward : warprnnt_grad

- op : weight_dequantize
args : (Tensor x, Tensor scale, str algo="weight_only_int8", DataType out_dtype=DataType::FLOAT16)
output : Tensor(out)
infer_meta :
func : WeightDequantizeInferMeta
kernel :
func : weight_dequantize
data_type : out_dtype

- op : weight_only_linear
args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, str weight_dtype)
output : Tensor(out)
Expand All @@ -2798,7 +2807,7 @@
backward: weight_only_linear_grad

- op : weight_quantize
args : (Tensor x, str algo = "weight_only_int8")
args : (Tensor x, str algo="weight_only_int8")
output : Tensor(out), Tensor(scale)
infer_meta :
func : WeightQuantizeInferMeta
Expand Down
29 changes: 29 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3227,6 +3227,35 @@ void Unpool3dInferMeta(const MetaTensor& x,
}
}

void WeightDequantizeInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const std::string& algo,
DataType out_dtype,
MetaTensor* out) {
PADDLE_ENFORCE_EQ(x.dims().size(),
2UL,
phi::errors::InvalidArgument(
"The x tensor of dequantize op must be 2D, but got[%d]",
x.dims().size()));
PADDLE_ENFORCE_EQ(
scale.dims().size(),
1UL,
phi::errors::InvalidArgument(
"The scale tensor of dequantize op must be 1D, but got[%d]",
scale.dims().size()));
PADDLE_ENFORCE_EQ(scale.dims()[0],
x.dims()[0],
phi::errors::InvalidArgument(
"The scale tensor's shape must be equal to the x "
"tensor's shape, but got [%d] not equal to [%d]",
scale.dims()[0],
x.dims()[0]));
int n = x.dims()[1];
int k = x.dims()[0];
out->set_dims(phi::make_ddim({n, k}));
out->set_dtype(out_dtype);
}

} // namespace phi

PD_REGISTER_INFER_META_FN(add_raw, phi::ElementwiseRawInferMeta);
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,4 +493,10 @@ void Unpool3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void WeightDequantizeInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const std::string& algo,
DataType out_dtype,
MetaTensor* out);

} // namespace phi
26 changes: 13 additions & 13 deletions paddle/phi/kernels/cpu/weight_quantize_kernel.cc
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/weight_quantize_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
Expand Down
40 changes: 12 additions & 28 deletions paddle/phi/kernels/funcs/weight_dequant_functor.h
Original file line number Diff line number Diff line change
@@ -1,32 +1,16 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

Expand Down
26 changes: 13 additions & 13 deletions paddle/phi/kernels/funcs/weight_only_gemv.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

Expand Down
53 changes: 53 additions & 0 deletions paddle/phi/kernels/gpu/weight_dequantize_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/weight_dequantize_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/transpose_kernel.h"

#if defined(PADDLE_WITH_CUTLASS)
#include "paddle/phi/kernels/funcs/weight_dequant_functor.h"
#endif

namespace phi {

template <typename T, typename Context>
void WeightDequantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const std::string& algo,
DataType out_dtype,
DenseTensor* out) {
#if defined(PADDLE_WITH_CUTLASS)
auto out_dims = out->dims();
dev_ctx.template Alloc<T>(out);
WeightDequantize<T, Context>(dev_ctx, x, scale, algo, true, out);
out->Resize({{out_dims[1], out_dims[0]}});
auto out_tmp = Transpose<T, Context>(dev_ctx, *out, {1, 0});
out->ShareDataWith(out_tmp);
#else
PADDLE_THROW(
phi::errors::PreconditionNotMet("Not compiled with WITH_CUTLASS=ON"));
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(weight_dequantize,
GPU,
ALL_LAYOUT,
phi::WeightDequantizeKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
44 changes: 16 additions & 28 deletions paddle/phi/kernels/gpu/weight_only_linear_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,32 +1,16 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/weight_only_linear_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
Expand Down Expand Up @@ -62,8 +46,12 @@ void WeightOnlyLinearGradKernel(const Context& dev_ctx,
dev_ctx, weight, weight_scale, algo, true, &weight_dequantized);
MatmulKernel<T, Context>(
dev_ctx, out_grad, weight_dequantized, false, false, x_grad);
#else
PADDLE_THROW(
phi::errors::PreconditionNotMet("Not compiled with WITH_CUTLASS=ON"));
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(weight_only_linear_grad,
Expand Down
27 changes: 14 additions & 13 deletions paddle/phi/kernels/gpu/weight_only_linear_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/kernels/weight_only_linear_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
Expand Down
29 changes: 29 additions & 0 deletions paddle/phi/kernels/weight_dequantize_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

template <typename T, typename Context>
void WeightDequantizeKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scale,
const std::string& algo,
DataType out_dtype,
DenseTensor* out);

} // namespace phi
3 changes: 3 additions & 0 deletions paddle/phi/kernels/weight_only_linear_grad_kernel.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/kernels/weight_only_linear_kernel.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/kernels/weight_quantize_kernel.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand All @@ -21,4 +24,5 @@ void WeightQuantizeKernel(const Context& dev_ctx,
const std::string& algo,
DenseTensor* out,
DenseTensor* scale);

} // namespace phi
Loading

0 comments on commit 17d7383

Please sign in to comment.