Skip to content

Commit

Permalink
add apply_per_channel_scale op
Browse files Browse the repository at this point in the history
  • Loading branch information
freeliuzc committed Jan 3, 2024
1 parent 8d5f154 commit dd9346f
Show file tree
Hide file tree
Showing 8 changed files with 383 additions and 0 deletions.
9 changes: 9 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,15 @@
func : angle
backward : angle_grad

- op : apply_per_channel_scale
args: (Tensor x, Tensor scales)
output: Tensor(out)
infer_meta :
func : ApplyPerChannelScaleInferMeta
kernel :
func : apply_per_channel_scale
data_type : x

- op : argmax
args : (Tensor x, Scalar(int64_t) axis, bool keepdims = false, bool flatten = false, DataType dtype = DataType::INT64)
output : Tensor(out)
Expand Down
27 changes: 27 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2607,6 +2607,33 @@ void PReluInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void ApplyPerChannelScaleInferMeta(const MetaTensor& x,
const MetaTensor& scales,
MetaTensor* out) {
auto x_dim = x.dims();
auto scales_dim = scales.dims();
PADDLE_ENFORCE_EQ(
x_dim.size(),
2,
phi::errors::InvalidArgument(
"The rank of Input(x) must be 2, but received %d.", x_dim.size()));

PADDLE_ENFORCE_EQ(scales_dim.size(),
1,
phi::errors::InvalidArgument(
"The rank of Input(scales) must be 1, but received %d.",
scales_dim.size()));

PADDLE_ENFORCE_EQ(
x_dim[1],
scales_dim[0],
phi::errors::InvalidArgument(
"The second dim of Input(x) must be equal to the first dim of scales,"
"but received %d and %d.",
x_dim[2],
scales_dim[1]));
}

inline void ExpandAspectRatios(const std::vector<float>& input_aspect_ratior,
bool flip,
std::vector<float>* output_aspect_ratior) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,10 @@ void RowConvInferMeta(const MetaTensor& x,
const MetaTensor& filter,
MetaTensor* out);

void ApplyPerChannelScaleInferMeta(const MetaTensor& x,
const MetaTensor& scales,
MetaTensor* out);

void PriorBoxInferMeta(const MetaTensor& input,
const MetaTensor& image,
const std::vector<float>& min_sizes,
Expand Down
30 changes: 30 additions & 0 deletions paddle/phi/kernels/apply_per_channel_scale_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void ApplyPerChannelScaleKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scales,
DenseTensor* out);
} // namespace phi
187 changes: 187 additions & 0 deletions paddle/phi/kernels/gpu/apply_per_channel_scale_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/apply_per_channel_scale_kernel.h"

#include <assert.h>
#include <stdint.h>
#include <cmath>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/datatype_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

namespace {
#ifdef PADDLE_WITH_CUDA
template <typename T>
struct CUDA_HALF_2_TYPE_TARIS {};

template <>
struct CUDA_HALF_2_TYPE_TARIS<half> {
using type = half2;
};

#ifdef PADDLE_CUDA_BF16
template <>
struct CUDA_HALF_2_TYPE_TARIS<__nv_bfloat16> {
using type = __nv_bfloat162;
};
#endif

template <typename T>
struct HalfMul2 {};

template <>
struct HalfMul2<half2> {
static __device__ __forceinline__ half2 apply(const half2& x,
const half2& y) {
return __hmul2(x, y);
}
};

#ifdef PADDLE_CUDA_BF16
template <>
struct HalfMul2<__nv_bfloat162> {
static __device__ __forceinline__ __nv_bfloat162
apply(const __nv_bfloat162& x, const __nv_bfloat162& y) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
return __hmul2(x, y);
#else
float fxl, fxh, fyl, fyh;
fxl = __low2float(x);
fxh = __high2float(x);
fyl = __low2float(y);
fyh = __high2float(y);
return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#endif
}
};
#endif

template <typename T, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(
const T* act, const T* scales, int rows, int cols, T* out) {
using HALF_2_TYPE = typename CUDA_HALF_2_TYPE_TARIS<T>::type;
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
T scale[kElems], act_vec[kElems];
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
int row_offset = blockIdx.y;
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows) return;
act += row_offset * kProcessRows * cols;
out += row_offset * kProcessRows * cols;
*reinterpret_cast<AccessType*>(scale) =
reinterpret_cast<const AccessType*>(scales)[col_offset];
#pragma unroll
for (int i = 0; i < kProcessRows; ++i) {
*reinterpret_cast<AccessType*>(act_vec) =
reinterpret_cast<const AccessType*>(act + i * cols)[col_offset];
if constexpr (kElems % 2 == 0 && (std::is_same_v<T, half> ||
std::is_same_v<T, __nv_bfloat16>)) {
#pragma unroll
for (int j = 0; j < kElems; j += 2) {
*reinterpret_cast<HALF_2_TYPE*>(act_vec + j) =
HalfMul2<HALF_2_TYPE>::apply(
*reinterpret_cast<HALF_2_TYPE*>(act_vec + j),
*reinterpret_cast<HALF_2_TYPE*>(scale + j));
}
} else {
#pragma unroll
for (int j = 0; j < kElems; ++j) {
act_vec[j] *= scale[j];
}
}
reinterpret_cast<AccessType*>(out + i * cols)[col_offset] =
*reinterpret_cast<AccessType*>(act_vec);
}
}

template <typename T, int kProcessRows, typename AccessType = float4>
void apply_per_channel_scale_launcher(const T* act,
const T* scales,
int rows,
int cols,
T* out,
cudaStream_t stream = 0) {
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
dim3 block(128);
dim3 grid((cols / kElems + block.x - 1) / block.x,
(rows + kProcessRows - 1) / kProcessRows);
apply_per_channel_scale<T, kProcessRows, AccessType>
<<<grid, block, 0, stream>>>(act, scales, rows, cols, out);
}

} // namespace
#endif

template <typename T, typename Context>
void ApplyPerChannelScaleKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& scales,
DenseTensor* out) {
#ifdef PADDLE_WITH_CUDA
using DataType = typename PDDataTypeTraits<T>::DataType;
int rows = x.dims()[0];
int cols = x.dims()[1];
int elems = rows * cols;
const T* x_data = x.data<T>();
const T* scales_data = scales.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
if (elems < 2048 * 2048) {
apply_per_channel_scale_launcher<DataType, 1, float4>(
reinterpret_cast<const DataType*>(x_data),
reinterpret_cast<const DataType*>(scales_data),
rows,
cols,
reinterpret_cast<DataType*>(out_data),
dev_ctx.stream());
} else if (elems < 4096 * 4096) {
apply_per_channel_scale_launcher<DataType, 4, float4>(
reinterpret_cast<const DataType*>(x_data),
reinterpret_cast<const DataType*>(scales_data),
rows,
cols,
reinterpret_cast<DataType*>(out_data),
dev_ctx.stream());
} else if (elems < 8192 * 8192) {
apply_per_channel_scale_launcher<DataType, 8, float4>(
reinterpret_cast<const DataType*>(x_data),
reinterpret_cast<const DataType*>(scales_data),
rows,
cols,
reinterpret_cast<DataType*>(out_data),
dev_ctx.stream());
} else {
apply_per_channel_scale_launcher<DataType, 16, float4>(
reinterpret_cast<const DataType*>(x_data),
reinterpret_cast<const DataType*>(scales_data),
rows,
cols,
reinterpret_cast<DataType*>(out_data),
dev_ctx.stream());
}
#endif
}

} // namespace phi

PD_REGISTER_KERNEL(apply_per_channel_scale,
GPU,
ALL_LAYOUT,
phi::ApplyPerChannelScaleKernel,
phi::dtype::float16,
phi::dtype::bfloat16) {}
2 changes: 2 additions & 0 deletions python/paddle/nn/quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from .quant_layers import QuantStub # noqa: F401
from .quantized_linear import (
apply_per_channel_scale,
llm_int8_linear,
weight_dequantize,
weight_only_linear,
Expand All @@ -40,4 +41,5 @@
"llm_int8_linear",
"weight_quantize",
"weight_dequantize",
"apply_per_channel_scale",
]
44 changes: 44 additions & 0 deletions python/paddle/nn/quant/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,47 @@ def llm_int8_linear(
attrs=attrs,
)
return out


def apply_per_channel_scale(x, scales):
"""
Apply pre-quant per channel scale on activations
Args:
x (Tensor): Input tensor representing the activations, the data type can be float16 or bfloat16.
scales(Tensor): Per-channel scale factors for pre-quantization. Data type should be compatible with x.
Returns:
out (Tensor): The Tensor which is the pre-quant results, the data type is compatible with x.
Examples:
.. code-block:: python
>>> # doctest: +SKIP('No testing required')
>>> import paddle
>>> from paddle.nn.quant import apply_per_channel_scale
>>> paddle.seed(2023)
>>> x = paddle.rand(shape=[64, 32], dtype=paddle.float16)
>>> scales = paddle.rand(shape=[32], dtype=paddle.float16)
>>> out = apply_per_channel_scale(x, scales)
"""
arch = _get_arch_info()

assert (
arch >= 80
), f"Currently pre_quant_scale only support SM >= 80 but got {arch} "

if in_dynamic_mode:
return _C_ops.apply_per_channel_scale(x, scales)
else:
type = "apply_per_channel_scale"
helper = LayerHelper(type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)

helper.append_op(
type=type,
inputs={"x": x, "scales": scales},
outputs={"out": out},
)
return out
Loading

0 comments on commit dd9346f

Please sign in to comment.