diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 75d3d10d66f7d9..e67dfa5adf910e 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -292,6 +292,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(gather_squeeze_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(layer_norm_act_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(squeeze_excitation_fuse_pass inference DIR xpu DEPS
diff --git a/paddle/fluid/framework/ir/xpu/layer_norm_act_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/layer_norm_act_xpu_fuse_pass.cc
new file mode 100644
index 00000000000000..654deca45eb7ee
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/layer_norm_act_xpu_fuse_pass.cc
@@ -0,0 +1,216 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+/*
+Fuse layer_norm and act op into a layer_norm_act_xpu op.
+For example:
+graph:
+          x
+          |
+      layer_norm
+          |
+      leaky_relu
+          |
+        output
+------------------------------------------------------
+After the pass is applied:
+          x
+          |
+  layer_norm_act_xpu
+          |
+        output
+*/
+
+struct LayerNormActXPUPattern : public PatternBase {
+  LayerNormActXPUPattern(PDPattern* pattern,
+                         const std::string& name_scope,
+                         const std::string& act_type);
+  // declare operator node's name
+  PATTERN_DECL_NODE(layer_norm);
+  PATTERN_DECL_NODE(act);
+  // declare variable node's name
+  PATTERN_DECL_NODE(layer_norm_input);
+  PATTERN_DECL_NODE(layer_norm_bias);
+  PATTERN_DECL_NODE(layer_norm_scale);
+  PATTERN_DECL_NODE(layer_norm_out);
+  PATTERN_DECL_NODE(act_out);
+
+ private:
+  std::string act_type_;
+};
+
+LayerNormActXPUPattern::LayerNormActXPUPattern(PDPattern* pattern,
+                                               const std::string& name_scope,
+                                               const std::string& act_type)
+    : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
+  auto layer_norm =
+      pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
+  auto layer_norm_input = pattern->NewNode(layer_norm_input_repr())
+                              ->AsInput()
+                              ->assert_is_op_input("layer_norm", "X");
+  auto layer_norm_bias = pattern->NewNode(layer_norm_bias_repr())
+                             ->AsInput()
+                             ->assert_is_persistable_var()
+                             ->assert_is_op_input("layer_norm", "Bias");
+  auto layer_norm_scale = pattern->NewNode(layer_norm_scale_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input("layer_norm", "Scale");
+  auto layer_norm_out = pattern->NewNode(layer_norm_out_repr())
+                            ->AsOutput()
+                            ->assert_is_op_output("layer_norm", "Y")
+                            ->assert_has_n_outputs(1);
+  layer_norm_out->assert_is_op_input(act_type_, "X");
+  layer_norm->LinksFrom({layer_norm_input, layer_norm_bias, layer_norm_scale})
+      .LinksTo({layer_norm_out});
+
+  auto act = pattern->NewNode(act_repr())->assert_is_op(act_type_);
+  auto act_out = pattern->NewNode(act_out_repr())
+                     ->assert_is_op_output(act_type_, "Out")
+                     ->AsOutput();
+  act->LinksFrom({layer_norm_out}).LinksTo({act_out});
+}
+
+}  // namespace patterns
+
+class LayerNormActXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  int ApplyImpl(ir::Graph* graph, const std::string& act_type) const;
+
+  const std::string name_scope_{"layer_norm_act_xpu_fuse_pass"};
+};
+
+void LayerNormActXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  int found_subgraph_count = 0;
+  for (auto act_type : {"leaky_relu"}) {
+    found_subgraph_count += ApplyImpl(graph, act_type);
+  }
+  AddStatis(found_subgraph_count);
+}
+
+int LayerNormActXPUFusePass::ApplyImpl(ir::Graph* graph,
+                                       const std::string& act_type) const {
+  GraphPatternDetector gpd;
+  patterns::LayerNormActXPUPattern pattern(
+      gpd.mutable_pattern(), name_scope_, act_type);
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle LayerNormActXPUFusePass fuse";
+    // declare operator node's name
+    GET_IR_NODE(layer_norm);
+    GET_IR_NODE(act);
+    // declare variable node's name
+    GET_IR_NODE(layer_norm_input);
+    GET_IR_NODE(layer_norm_bias);
+    GET_IR_NODE(layer_norm_scale);
+    GET_IR_NODE(layer_norm_out);
+    GET_IR_NODE(act_out);
+    auto* block = layer_norm->Op()->Block();
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+    // nodes to be removed after fusion
+    std::unordered_set<const Node*> delete_nodes;
+
+    float eps = PADDLE_GET_CONST(float, layer_norm->Op()->GetAttr("epsilon"));
+    int begin_norm_axis =
+        PADDLE_GET_CONST(int, layer_norm->Op()->GetAttr("begin_norm_axis"));
+
+    std::string fused_op_out_name;
+    fused_op_out_name = act_out->Name();
+    float act_param_ = 0.0f;
+    int act_type_ = 0;
+    if (!act_type.empty()) {
+      if (act_type == "leaky_relu") {
+        act_param_ = PADDLE_GET_CONST(float, act->Op()->GetAttr("alpha"));
+        act_type_ = static_cast<int>(xpu::Activation_t::LEAKY_RELU);
+      }
+    }
+
+    // Generate fused op
+    framework::OpDesc fused_op_desc(block);
+    fused_op_desc.SetType("layer_norm_act_xpu");
+    // set attrs for fused op
+    fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
+    fused_op_desc.SetAttr("epsilon", eps);
+    fused_op_desc.SetAttr("act_param", act_param_);
+    fused_op_desc.SetAttr("act_type", act_type_);
+
+    fused_op_desc.SetInput("x", {layer_norm_input->Name()});
+    fused_op_desc.SetInput("bias", {layer_norm_bias->Name()});
+    fused_op_desc.SetInput("scale", {layer_norm_scale->Name()});
+    fused_op_desc.SetOutput("out", {fused_op_out_name});
+    // relink fused op
+    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+    IR_NODE_LINK_TO(layer_norm_input, fused_op);
+    IR_NODE_LINK_TO(layer_norm_bias, fused_op);
+    IR_NODE_LINK_TO(layer_norm_scale, fused_op);
+    IR_NODE_LINK_TO(fused_op, act_out);
+
+    delete_nodes.insert({layer_norm, act, layer_norm_out});
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+
+  return found_subgraph_count;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(layer_norm_act_xpu_fuse_pass,
+              paddle::framework::ir::LayerNormActXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(layer_norm_act_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "layer_norm_act_xpu", 0));
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index c41b59b6e3e03a..c7f3f87a4d192d 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -552,6 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "squeeze_excitation_fuse_pass",
       "add_activation_xpu_fuse_pass",
       "add_layernorm_xpu_fuse_pass",
+      "layer_norm_act_xpu_fuse_pass",
       "fast_layernorm_xpu_fuse_pass",
       "bn_act_xpu_fuse_pass",
       "yolo_box_xpu_fuse_pass",
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 16840bf471cf0e..226c87b35d4584 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -197,6 +197,15 @@
     func : generate_sequence_xpu
     data_type : dtype
 
+- op : layer_norm_act_xpu
+  args : (Tensor x, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon, int act_type, float act_param)
+  output : Tensor(out)
+  infer_meta :
+    func : LayerNormActXPUInferMeta
+  kernel :
+    func : layer_norm_act_xpu
+    data_type : x
+
 - op : multi_encoder_xpu
   args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
   output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index b20e6e242bfacc..39defa8bdddd78 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -312,6 +312,8 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::INT32,
                    phi::DataType::FLOAT32,
                    phi::DataType::FLOAT16})},
+    {"layer_norm_act_xpu",
+     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fast_layernorm_xpu",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"fc_xpu",
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index 606fdf98727a51..8dfdf7f89fde77 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -1109,6 +1109,20 @@ void AddCMulXPUInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void LayerNormActXPUInferMeta(const MetaTensor& x,
+                              const MetaTensor& scale,
+                              const MetaTensor& bias,
+                              int begin_norm_axis,
+                              float epsilon,
+                              int act_type,
+                              float act_param,
+                              MetaTensor* y) {
+  y->set_dims(x.dims());
+  // y->share_lod(x);
+  y->set_dtype(x.dtype());
+  y->set_layout(x.layout());
+}
+
 void FusedScaleBiasReluConvBnstatsInferMeta(
     const MetaTensor& x,
     const MetaTensor& w,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index f6a672467932ee..ecda5cb9c88182 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -238,6 +238,15 @@ void AddCMulXPUInferMeta(const MetaTensor& x,
                          const MetaTensor& w,
                          MetaTensor* out);
 
+void LayerNormActXPUInferMeta(const MetaTensor& x,
+                              const MetaTensor& scale,
+                              const MetaTensor& bias,
+                              int begin_norm_axis,
+                              float epsilon,
+                              int act_type,
+                              float act_param,
+                              MetaTensor* y);
+
 void FusedScaleBiasReluConvBnstatsInferMeta(
     const MetaTensor& x,
     const MetaTensor& w,
diff --git a/paddle/phi/kernels/fusion/xpu/layer_norm_act_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/layer_norm_act_xpu_kernel.cc
new file mode 100644
index 00000000000000..ead6959ba6debc
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/layer_norm_act_xpu_kernel.cc
@@ -0,0 +1,132 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/norm_utils.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void LayerNormActXPUKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const paddle::optional<DenseTensor>& scale,
+                           const paddle::optional<DenseTensor>& bias,
+                           int begin_norm_axis,
+                           float epsilon,
+                           int act_type,
+                           float act_param,
+                           DenseTensor* y) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
+  const auto& x_dims = x.dims();
+  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
+  int left = static_cast<int>(matrix_dim[0]);
+  int right = static_cast<int>(matrix_dim[1]);
+  const auto* x_data = x.data<T>();
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+
+  // scale
+  const float* scale_data_fp32 = nullptr;
+  const auto* scale_ptr = scale.get_ptr();
+  if (scale_ptr == nullptr) {
+    // no scale, do nothing
+  } else if (scale_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* scale_data_temp =
+        RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
+    int r = xpu::cast<XPUTypeFP16, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUTypeFP16*>(
+            scale_ptr->data<phi::dtype::float16>()),
+        scale_data_temp,
+        scale_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    scale_data_fp32 = scale_data_temp;
+  } else {
+    // no need to cast
+    scale_data_fp32 = scale_ptr->data<float>();
+  }
+
+  // bias
+  const float* bias_data_fp32 = nullptr;
+  const auto* bias_ptr = bias.get_ptr();
+  if (bias_ptr == nullptr) {
+    // no bias, do nothing
+  } else if (bias_ptr->dtype() ==
+             phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
+    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
+    int r = xpu::cast<XPUTypeFP16, float>(
+        ctx.x_context(),
+        reinterpret_cast<const XPUTypeFP16*>(
+            bias_ptr->data<phi::dtype::float16>()),
+        bias_data_temp,
+        bias_ptr->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+    bias_data_fp32 = bias_data_temp;
+  } else {
+    // no need to cast
+    bias_data_fp32 = bias_ptr->data<float>();
+  }
+
+  auto* out_data = ctx.template Alloc<T>(y);
+
+  xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
+  if (act_type == xpu::Activation_t::LEAKY_RELU) {
+    act.leaky_alpha = act_param;
+  } else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
+    act.hard_sigmoid_slope = act_param;
+  }
+#ifdef PADDLE_WITH_XPU_PLUGIN
+  int r = xpu::plugin::layer_norm_act_fusion(
+      ctx.x_context(),
+      reinterpret_cast<const XPUType*>(x_data),
+      reinterpret_cast<XPUType*>(out_data),
+      left,
+      right,
+      epsilon,
+      scale_data_fp32,
+      bias_data_fp32,
+      act);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm_act_fusion");
+#else
+  int r = xpu::layer_norm(ctx.x_context(),
+                          reinterpret_cast<const XPUType*>(x_data),
+                          reinterpret_cast<XPUType*>(out_data),
+                          left,
+                          right,
+                          epsilon,
+                          scale_data_fp32,
+                          bias_data_fp32,
+                          nullptr,
+                          nullptr);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
+  r = xpu::leaky_relu(ctx.x_context(),
+                      reinterpret_cast<const XPUType*>(out_data),
+                      reinterpret_cast<XPUType*>(out_data),
+                      left * right,
+                      act_param,
+                      NULL,
+                      NULL);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "leaky_relu");
+#endif
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(layer_norm_act_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::LayerNormActXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
index 4e7e9b876af54c..fc2e437294f9c6 100644
--- a/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+++ b/paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
@@ -134,6 +134,17 @@ DLL_EXPORT int bn_act_fusion_infer(Context* ctx,
                                    bool is_nchw,
                                    int act_type);
 
+template <typename T>
+DLL_EXPORT int layer_norm_act_fusion(Context* ctx,
+                                     const T* x,
+                                     T* y,
+                                     int64_t m,
+                                     int64_t n,
+                                     float eps,
+                                     const float* scale,
+                                     const float* bias,
+                                     const Activation_t& act);
+
 }  // namespace plugin
 }  // namespace api
 }  // namespace xpu
diff --git a/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/layer_norm_act_fusion.xpu b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/layer_norm_act_fusion.xpu
new file mode 100644
index 00000000000000..71261f785fbba6
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/layer_norm_act_fusion.xpu
@@ -0,0 +1,256 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/kernel/cluster.h"
+#include "xpu/kernel/cluster_partition.h"
+#include "xpu/kernel/cluster_primitive.h"
+
+namespace xpu2 {
+namespace plugin {
+
+static inline __device__ float sum16(const float* ptr) {
+  float s0 = ptr[0] + ptr[8];
+  float s1 = ptr[1] + ptr[9];
+  float s2 = ptr[2] + ptr[10];
+  float s3 = ptr[3] + ptr[11];
+  float s4 = ptr[4] + ptr[12];
+  float s5 = ptr[5] + ptr[13];
+  float s6 = ptr[6] + ptr[14];
+  float s7 = ptr[7] + ptr[15];
+  s0 = s0 + s1;
+  s2 = s2 + s3;
+  s4 = s4 + s5;
+  s6 = s6 + s7;
+  s0 = s0 + s2;
+  s4 = s4 + s6;
+  return s0 + s4;
+}
+
+template <typename T>
+static __device__ void update_sum_and_squaresum(T* a,
+                                                int size,
+                                                float* sum,
+                                                float* squaresum) {
+  __simd__ float sum_buf[16];
+  __simd__ float squaresum_buf[16];
+  float32x16_t al;
+  float32x16_t ah;
+  int rounddown_size = rounddown32(size - 1);
+  unsigned int mask = -1;
+  if ((size % 32) != 0) {
+    mask = ~(-1 << (size % 32));
+  }
+  vload2_lm_mz(a + rounddown_size, al, ah, mask);
+  float32x16_t vsum = vvadd_float32x16(al, ah);
+  al = vvmul_float32x16(al, al);
+  ah = vvmul_float32x16(ah, ah);
+  float32x16_t vsquaresum = vvadd_float32x16(al, ah);
+  for (int i = 0; i < rounddown_size; i += 32) {
+    vload2_lm(a + i, al, ah);
+    vsum = vvadd_float32x16(vsum, al);
+    vsum = vvadd_float32x16(vsum, ah);
+    al = vvmul_float32x16(al, al);
+    ah = vvmul_float32x16(ah, ah);
+    vsquaresum = vvadd_float32x16(vsquaresum, al);
+    vsquaresum = vvadd_float32x16(vsquaresum, ah);
+  }
+  vstore_lm_float32x16(sum_buf, vsum);
+  vstore_lm_float32x16(squaresum_buf, vsquaresum);
+  mfence_lm();
+  *sum = sum16(sum_buf);
+  *squaresum = sum16(squaresum_buf);
+}
+
+template <typename T>
+static __device__ void vector_scale_and_bias_and_act_align32(
+    T* a,
+    int size,
+    float mean,
+    float var,
+    _shared_ptr_ const float* scale_sm,
+    _shared_ptr_ const float* bias_sm,
+    bool do_scale_bias,
+    float act_param) {
+  float32x16_t al;
+  float32x16_t ah;
+  float32x16_t bl;
+  float32x16_t bh;
+  mean = 0.0f - mean;
+  if (do_scale_bias) {
+    // (a - mean) * var * scale + bias, then leaky_relu
+    for (int i = 0; i < size; i += 32) {
+      vload2_lm(a + i, al, ah);
+      vload2_sm(scale_sm + i, bl, bh);
+      al = svadd_float32x16(mean, al);
+      ah = svadd_float32x16(mean, ah);
+      al = svmul_float32x16(var, al);
+      ah = svmul_float32x16(var, ah);
+      al = vvmul_float32x16(bl, al);
+      ah = vvmul_float32x16(bh, ah);
+      vload2_sm(bias_sm + i, bl, bh);
+      al = vvadd_float32x16(bl, al);
+      ah = vvadd_float32x16(bh, ah);
+      bl = svmul_float32x16(act_param, al);
+      bh = svmul_float32x16(act_param, ah);
+      al = vvmax_float32x16(al, bl);
+      ah = vvmax_float32x16(ah, bh);
+      vstore2_lm(a + i, al, ah);
+    }
+  } else {
+    // (a - mean) * var, then leaky_relu
+    for (int i = 0; i < size; i += 32) {
+      vload2_lm(a + i, al, ah);
+      al = svadd_float32x16(mean, al);
+      ah = svadd_float32x16(mean, ah);
+      al = svmul_float32x16(var, al);
+      ah = svmul_float32x16(var, ah);
+      bl = svmul_float32x16(act_param, al);
+      bh = svmul_float32x16(act_param, ah);
+      al = vvmax_float32x16(al, bl);
+      ah = vvmax_float32x16(ah, bh);
+      vstore2_lm(a + i, al, ah);
+    }
+  }
+  mfence_lm();
+}
+
+template <typename T>
+__global__ void fast_layer_norm_act_tiny_align32(float epsilon,
+                                                 int64_t m,
+                                                 int64_t n,
+                                                 const T* x,
+                                                 T* y,
+                                                 const float* scale,
+                                                 const float* bias,
+                                                 float act_param) {
+  int cid = core_id();
+  int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = ncores * cluster_num();
+  int64_t mstart = 0;
+  int64_t mend = 0;
+  partition(tid, nthreads, m, 1, &mstart, &mend);
+  if (mstart >= mend) {
+    return;
+  }
+
+  float one_div_n = 1.0f / n;
+  constexpr int lm_buffer_size = 1664 * sizeof(float) / sizeof(T);
+  constexpr int sm_buffer_size = 1664 * 16;
+  __simd__ T xlm[lm_buffer_size];
+  __simd__ __shared__ float scale_sm[sm_buffer_size];
+  __simd__ __shared__ float bias_sm[sm_buffer_size];
+  int block_cnt = lm_buffer_size / n;
+  float sum = 0.0f;
+  float squaresum = 0.0f;
+  bool do_scale_bias = false;
+  if (scale != nullptr && bias != nullptr) {
+    do_scale_bias = true;
+  }
+  if (cid == 0 && do_scale_bias) {
+    GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
+    GM2SM(bias, bias_sm, n * sizeof(float));
+  }
+  sync_all();
+  for (int64_t i = mstart; i < mend; i += block_cnt) {
+    int readlen = min((mend - i) * n, block_cnt * n);
+    GM2LM(x + i * n, xlm, readlen * sizeof(T));
+    for (int64_t j = 0; j < readlen; j += n) {
+      update_sum_and_squaresum<T>(xlm + j, n, &sum, &squaresum);
+      float sample_mean = sum * one_div_n;
+      float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
+      float rstd = 1.0f / sqrt(sample_var + epsilon);
+      vector_scale_and_bias_and_act_align32<T>(xlm + j,
+                                               n,
+                                               sample_mean,
+                                               rstd,
+                                               scale_sm,
+                                               bias_sm,
+                                               do_scale_bias,
+                                               act_param);
+    }
+    LM2GM(xlm, y + i * n, readlen * sizeof(T));
+  }
+}
+
+template <typename T>
+__global__ void fast_layer_norm_act_tiny_common(float epsilon,
+                                                int64_t m,
+                                                int64_t n,
+                                                const T* x,
+                                                T* y,
+                                                const float* scale,
+                                                const float* bias,
+                                                float act_param) {
+  int cid = core_id();
+  int ncores = core_num();
+  int tid = cid * cluster_num() + cluster_id();
+  int nthreads = ncores * cluster_num();
+
+  float one_div_n = 1.0f / n;
+  constexpr int lm_buffer_size = 1664 * sizeof(float) / sizeof(T);
+  constexpr int sm_buffer_size = 1664 * 16;
+  __simd__ T xlm[lm_buffer_size];
+  __simd__ __shared__ float scale_sm[sm_buffer_size];
+  __simd__ __shared__ float bias_sm[sm_buffer_size];
+  float sum = 0.0f;
+  float squaresum = 0.0f;
+  bool do_scale_bias = false;
+  if (scale != nullptr && bias != nullptr) {
+    do_scale_bias = true;
+  }
+  if (cid == 0 && do_scale_bias) {
+    GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
+    GM2SM(bias, bias_sm, n * sizeof(float));
+  }
+  sync_all();
+  for (int64_t i = tid; i < m; i += nthreads) {
+    GM2LM(x + i * n, xlm, n * sizeof(T));
+    update_sum_and_squaresum<T>(xlm, n, &sum, &squaresum);
+    float sample_mean = sum * one_div_n;
+    float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
+    float rstd = 1.0f / sqrt(sample_var + epsilon);
+    vector_scale_and_bias_and_act_align32<T>(
+        xlm, n, sample_mean, rstd, scale_sm, bias_sm, do_scale_bias, act_param);
+    LM2GM(xlm, y + i * n, n * sizeof(T));
+  }
+}
+
+#define _XPU_DEF__FAST_LAYER_NORM_TINY_(DTYPE)                       \
+  template __global__ void fast_layer_norm_act_tiny_common<DTYPE>(  \
+      float epsilon,                                                 \
+      int64_t m,                                                     \
+      int64_t n,                                                     \
+      const DTYPE* x,                                                \
+      DTYPE* y,                                                      \
+      const float* scale,                                            \
+      const float* bias,                                             \
+      float act_param);                                              \
+  template __global__ void fast_layer_norm_act_tiny_align32<DTYPE>( \
+      float epsilon,                                                 \
+      int64_t m,                                                     \
+      int64_t n,                                                     \
+      const DTYPE* x,                                                \
+      DTYPE* y,                                                      \
+      const float* scale,                                            \
+      const float* bias,                                             \
+      float act_param);
+_XPU_DEF__FAST_LAYER_NORM_TINY_(float16);
+_XPU_DEF__FAST_LAYER_NORM_TINY_(float);
+
+}  // namespace plugin
+}  // namespace xpu2
diff --git a/paddle/phi/kernels/xpu/plugin/src/wrapper/layer_norm_act_fusion.cpp b/paddle/phi/kernels/xpu/plugin/src/wrapper/layer_norm_act_fusion.cpp
new file mode 100644
index 00000000000000..d63002d4fc9a7b
--- /dev/null
+++ b/paddle/phi/kernels/xpu/plugin/src/wrapper/layer_norm_act_fusion.cpp
@@ -0,0 +1,176 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+/*
+ * copyright (C) 2022 KUNLUNXIN, Inc
+ */
+
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu2 {
+namespace plugin {
+template <typename T>
+__attribute__((global)) void fast_layer_norm_act_tiny_common(
+    float epsilon,
+    int64_t m,
+    int64_t n,
+    const T* x,
+    T* y,
+    const float* scale,
+    const float* bias,
+    float act_param);
+template <typename T>
+__attribute__((global)) void fast_layer_norm_act_tiny_align32(
+    float epsilon,
+    int64_t m,
+    int64_t n,
+    const T* x,
+    T* y,
+    const float* scale,
+    const float* bias,
+    float act_param);
+
+}  // namespace plugin
+}  // namespace xpu2
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+template <typename T>
+static int cpu_wrapper(Context* ctx,
+                       const T* x,
+                       T* y,
+                       int64_t m,
+                       int64_t n,
+                       float eps,
+                       const float* scale,
+                       const float* bias,
+                       const Activation_t& act) {
+  for (int64_t i = 0; i < m; i++) {
+    float sum = 0.0f;
+    float square_sum = 0.0f;
+    for (int64_t j = 0; j < n; j++) {
+      float v = static_cast<float>(x[i * n + j]);
+      sum += v;
+      square_sum += v * v;
+    }
+    float mean_value = sum / n;
+    float var_value = square_sum / n - mean_value * mean_value;
+    float rstd = 1.0f / std::sqrt(var_value + eps);
+    for (int64_t j = 0; j < n; j++) {
+      float v = static_cast<float>(x[i * n + j]);
+      float scale_value = ((scale == nullptr) ? 1.0f : scale[j]);
+      float bias_value = ((bias == nullptr) ? 0.0f : bias[j]);
+      float out = (v - mean_value) * rstd * scale_value + bias_value;
+      y[i * n + j] = static_cast<T>(out);
+    }
+  }
+
+  float act_param = 0.f;
+  if (act.type == api::Activation_t::LEAKY_RELU) {
+    act_param = act.leaky_alpha;
+  }
+  int64_t mxn = m * n;
+  for (int64_t i = 0; i < mxn; i++) {
+    y[i] = fmax(y[i], y[i] * act_param);
+  }
+  return SUCCESS;
+}
+
+template <typename T>
+static int xpu2_wrapper(Context* ctx,
+                        const T* x,
+                        T* y,
+                        int64_t m,
+                        int64_t n,
+                        float eps,
+                        const float* scale,
+                        const float* bias,
+                        const Activation_t& act) {
+  float act_param = 0.f;
+  if (act.type == api::Activation_t::LEAKY_RELU) {
+    act_param = act.leaky_alpha;
+  }
+  if (n <= 832) {
+    if (n % 32 == 0 && n < 128) {
+      xpu2::plugin::fast_layer_norm_act_tiny_align32<T>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+              eps, m, n, x, y, scale, bias, act_param);
+    } else {
+      xpu2::plugin::fast_layer_norm_act_tiny_common<T>
+          <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+              eps, m, n, x, y, scale, bias, act_param);
+    }
+  } else {
+    layer_norm(ctx, x, y, m, n, eps, scale, bias, nullptr, nullptr);
+    leaky_relu(ctx, y, y, m * n, act_param, NULL, NULL);
+  }
+
+  return SUCCESS;
+}
+
+template <typename T>
+int layer_norm_act_fusion(Context* ctx,
+                          const T* x,
+                          T* y,
+                          int64_t m,
+                          int64_t n,
+                          float eps,
+                          const float* scale,
+                          const float* bias,
+                          const Activation_t& act) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "layer_norm_act_fusion", T);
+  WRAPPER_DUMP_PARAM5(ctx, x, y, m, n, eps);
+  WRAPPER_DUMP_PARAM2(ctx, scale, bias);
+  WRAPPER_DUMP(ctx);
+  int64_t xylen = -1;
+  WRAPPER_CHECK_SHAPE(ctx, &xylen, {m, n});
+  WRAPPER_CHECK_2PTRS(ctx, T, xylen, x, y);
+  WRAPPER_ASSERT_GE(ctx, eps, 0);
+  WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, scale);
+  WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, bias);
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(ctx, x, y, m, n, eps, scale, bias, act);
+  }
+  if (ctx->dev().type() == api::kXPU2) {
+    return xpu2_wrapper(ctx, x, y, m, n, eps, scale, bias, act);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+template int layer_norm_act_fusion(Context*,
+                                   const float*,
+                                   float*,
+                                   int64_t,
+                                   int64_t,
+                                   float,
+                                   const float*,
+                                   const float*,
+                                   const Activation_t& act);
+template int layer_norm_act_fusion(Context*,
+                                   const float16*,
+                                   float16*,
+                                   int64_t,
+                                   int64_t,
+                                   float,
+                                   const float*,
+                                   const float*,
+                                   const Activation_t& act);
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/test/ir/inference/test_xpu_layer_norm_act_fuse_pass.py b/test/ir/inference/test_xpu_layer_norm_act_fuse_pass.py
new file mode 100644
index 00000000000000..141b5d786691f4
--- /dev/null
+++ b/test/ir/inference/test_xpu_layer_norm_act_fuse_pass.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestXpuLayerNormActFusePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ["layer_norm_act_xpu"], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        batch_size = draw(st.integers(min_value=1, max_value=50))
+        x_shape = [batch_size, 16, 128]
+        y_shape = x_shape
+
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        begin_norm_axis = 2
+        layer_norm_op = OpConfig(
+            "layer_norm",
+            inputs={
+                "X": ["x"],
+                "Scale": ["layer_norm_scale"],
+                "Bias": ["layer_norm_bias"],
+            },
+            outputs={
+                "Y": ["layer_norm_out"],
+                "Mean": ["layer_norm_mean"],
+                "Variance": ["layer_norm_var"],
+            },
+            begin_norm_axis=begin_norm_axis,
+            epsilon=epsilon,
+        )
+
+        alpha = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        relu_op = OpConfig(
+            "leaky_relu",
+            inputs={
+                "X": ["layer_norm_out"],
+            },
+            outputs={
+                "Out": ["relu_out"],
+            },
+            alpha=alpha,
+        )
+
+        sub_graph = [layer_norm_op, relu_op]
+
+        program_config = ProgramConfig(
+            ops=sub_graph,
+            weights={
+                "layer_norm_scale": TensorConfig(shape=[x_shape[2]]),
+                "layer_norm_bias": TensorConfig(shape=[x_shape[2]]),
+            },
+            inputs={
+                "x": TensorConfig(shape=x_shape),
+            },
+            outputs=["relu_out"],
+        )
+
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["layer_norm_act_xpu_fuse_pass"],
+        )
+
+
+if __name__ == "__main__":
+    np.random.seed(200)
+    unittest.main()
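For reference, the fused computation introduced by this patch is equivalent to the short NumPy sketch below, which mirrors the cpu_wrapper reference path (layer normalization over the axes at and after begin_norm_axis, followed by leaky_relu). It is illustrative only and not part of the diff; the function name, shapes, and default parameters are assumptions chosen to match the unit test above.

# Illustrative NumPy reference for layer_norm_act_xpu (leaky_relu case); not part of the PR.
import numpy as np


def layer_norm_leaky_relu_ref(x, scale, bias, begin_norm_axis=2, epsilon=1e-5, alpha=0.02):
    # Flatten to (m, n) exactly as the kernel does with flatten_to_2d.
    m = int(np.prod(x.shape[:begin_norm_axis]))
    n = int(np.prod(x.shape[begin_norm_axis:]))
    x2d = x.reshape(m, n).astype(np.float32)
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)           # E[x^2] - (E[x])^2, as in the kernel
    y = (x2d - mean) / np.sqrt(var + epsilon)      # normalize
    y = y * scale.reshape(1, n) + bias.reshape(1, n)
    y = np.where(y >= 0.0, y, alpha * y)           # leaky_relu
    return y.reshape(x.shape)


x = np.random.randn(4, 16, 128).astype(np.float32)
scale = np.random.randn(128).astype(np.float32)
bias = np.random.randn(128).astype(np.float32)
print(layer_norm_leaky_relu_ref(x, scale, bias).shape)  # (4, 16, 128)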