From 7a2bb2fd93723e9f5f02ac3b4030b939dd9f65a7 Mon Sep 17 00:00:00 2001
From: enzodechine
Date: Thu, 22 Feb 2024 14:33:55 +0800
Subject: [PATCH] [XPU]bind xblas fc (#61340)

* run xblas by default when the device == XPU3

* run xblas by default when the device == XPU3

* reset xpu.cmake

* Add compilation control

* remove redundant code calling xdnn

---
 .../fusion/xpu/fused_gemm_epilogue_kernel.cc |  70 +-
 paddle/phi/kernels/xpu/bmm_grad_kernel.cc    |   2 +
 paddle/phi/kernels/xpu/bmm_kernel.cc         |   2 +
 paddle/phi/kernels/xpu/bmm_xpu_utils.h       |  54 +-
 paddle/phi/kernels/xpu/xpu_api_wrapper.h     | 672 ++++++++----------
 test/xpu/test_matmul_v2_op_xpu.py            |  23 +-
 6 files changed, 360 insertions(+), 463 deletions(-)

diff --git a/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
index 2b090136815204..981b11822eb0cc 100644
--- a/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
@@ -45,41 +45,6 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
   const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
   XPUType* out_ptr =
       reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));
-  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[5] = {
-      &xpu_fc_wrapper<XPUType, int16_t>,
-      &xpu_fc_wrapper<XPUType, int32_t>,
-      &xpu_fc_wrapper<XPUType, float>,
-      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
-      &xpu_fc_wrapper<XPUType, tfloat32>,
-  };
-
-  auto fccal_type = FCCalcType<XPUType>();
-
-  auto fc_api = fc_api_list[fccal_type];
-
-  // fc + bias + act
-  phi::XpuFcInfo fc_info;
-  auto mat_x_dims =
-      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
-  auto mat_y_dims = y.dims();
-  phi::GetFCInfo(mat_x_dims, mat_y_dims, trans_x, trans_y, &fc_info);
-  int batch_size = fc_info.bs;
-  int m = fc_info.m;
-  int n = fc_info.n;
-  int k = fc_info.k;
-  int ldx = fc_info.stride_x;
-  int ldy = fc_info.stride_y;
-  int ldout = fc_info.stride_out;
-  float* max_x = fc_info.max_x;
-  float* max_y = fc_info.max_y;
-  float* max_out = fc_info.max_out;
-
-  PADDLE_ENFORCE_LE(
-      batch_size,
-      1,
-      errors::InvalidArgument(
-          "FusedGemm do not support batched fc now, but got batch size %d.",
-          batch_size));
   const float* bias_ptr = reinterpret_cast<const float*>(bias.data<T>());
   if (!std::is_same<T, float>::value) {
     // TODO(lijin23): Now xblas and xdnn support fp32 bias only, may be removed
@@ -93,25 +58,22 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
     bias_ptr = bias_tmp;
   }
-  fc_api(xpu_ctx,
-         x_ptr,
-         y_ptr,
-         out_ptr,
-         m,
-         n,
-         k,
-         trans_x,
-         trans_y,
-         max_x,
-         max_y,
-         max_out,
-         ldx,
-         ldy,
-         ldout,
-         1.0f,
-         0,
-         bias_ptr,
-         act);
+  // fc + bias + act
+  phi::XpuFcInfo fc_info;
+  fc_info.bias = bias_ptr;
+  auto mat_x_dims =
+      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+  auto mat_y_dims = y.dims();
+  phi::GetFCInfo(mat_x_dims, mat_y_dims, trans_x, trans_y, &fc_info);
+  int batch_size = fc_info.bs;
+  PADDLE_ENFORCE_LE(
+      batch_size,
+      1,
+      errors::InvalidArgument(
+          "FusedGemm do not support batched fc now, but got batch size %d.",
+          batch_size));
+  MatMulXPUFunction<T>(
+      xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f, false, act);
 }
 
 }  // namespace fusion

diff --git a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
index 4c3b3dcd2c9b08..cbc98dd7ad9ace 100644
--- a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
@@ -35,6 +35,8 @@ void MatMul(const Context& dev_ctx,
     MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(a, b, out, trans_a, trans_b, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) {
+    MatMulXPUFunction<T, XPUTypeFP16>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   }

diff --git a/paddle/phi/kernels/xpu/bmm_kernel.cc b/paddle/phi/kernels/xpu/bmm_kernel.cc
index 04afc3ef1b007c..ae80f12747ac15 100644
--- a/paddle/phi/kernels/xpu/bmm_kernel.cc
+++ b/paddle/phi/kernels/xpu/bmm_kernel.cc
@@ -70,6 +70,8 @@ void BmmKernel(const Context& dev_ctx,
     MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(x, y, out, trans_x, trans_y, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) {
+    MatMulXPUFunction<T, XPUTypeFP16>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   }

diff --git a/paddle/phi/kernels/xpu/bmm_xpu_utils.h b/paddle/phi/kernels/xpu/bmm_xpu_utils.h
index f0ac5c7e14ea1e..90d5b519739576 100644
--- a/paddle/phi/kernels/xpu/bmm_xpu_utils.h
+++ b/paddle/phi/kernels/xpu/bmm_xpu_utils.h
@@ -40,25 +40,41 @@ static void MatMulXPUFunction(const DenseTensor& x,
   int k = mat_dim_a.width_;
   int batch_size = mat_dim_a.batch_size_;
   // batch matmul
-  int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
-      xpu_ctx,                                        // Context* ctx,
-      batch_size,                                     // int batch_size,
-      mat_dim_a.trans_,                               // bool x_trans,
-      mat_dim_b.trans_,                               // bool w_trans,
-      m,                                              // int m,
-      n,                                              // int n,
-      k,                                              // int k,
-      1.0,                                            // float alpha,
-      reinterpret_cast<const XPUType*>(x.data<T>()),  // const TX* x,
-      mat_dim_a.stride_,                              // int stride_a,
-      reinterpret_cast<const XPUType*>(y.data<T>()),  // const TW* w,
-      mat_dim_b.stride_,                              // int stride_b,
-      0.0,                                            // float beta,
-      reinterpret_cast<XPUType*>(data_c),             // TY* y,
-      m * n,                                          // int stride_c,
-      nullptr,                                        // const float* x_maxptr,
-      nullptr);                                       // const float* w_maxptr
+  int fccal_type = FCCalcType<XPUType>();
+  decltype(&xblas_fc_batch_wrapper<XPUType, int16_t, float>)
+      xblas_fc_batch_api_list[6] = {
+          &xblas_fc_batch_wrapper<XPUType, int16_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, int32_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, float, float>,
+          &xblas_fc_batch_wrapper<XPUType, int_with_ll_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, tfloat32, float>,
+          &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, float>,
+      };
 
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
+  auto xblas_fc_batch_api = xblas_fc_batch_api_list[fccal_type];
+  if (fccal_type == XPUFCCalcType::FC_FLOAT16 &&
+      std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) {
+    xblas_fc_batch_api =
+        &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, XPUTypeFP16>;
+  }
+
+  xblas_fc_batch_api(xpu_ctx,
+                     batch_size,
+                     mat_dim_a.trans_,
+                     mat_dim_b.trans_,
+                     m,
+                     n,
+                     k,
+                     1.0,
+                     reinterpret_cast<const XPUType*>(x.data<T>()),
+                     mat_dim_a.stride_,
+                     reinterpret_cast<const XPUType*>(y.data<T>()),
+                     mat_dim_b.stride_,
+                     0.0,
+                     reinterpret_cast<XPUType*>(data_c),
+                     m * n,
+                     nullptr,
+                     nullptr);
 }
+
 }  // namespace phi
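Note on the bmm_xpu_utils.h hunk above: the direct xpu::fc_batched call is replaced by a table of xblas_fc_batch_wrapper instantiations indexed by the FCCalcType result, with an env-var override for the FP16 entry. The following standalone sketch shows that dispatch-table pattern in miniature; every name in it (DEMO_FC_FLOAT16, PickCalcType, GemmImpl) is an illustrative stand-in, not the Paddle or xblas API:

// sketch.cc -- the dispatch-table pattern from the diff above, in miniature.
// All names here are illustrative stand-ins, not the Paddle/xblas API.
#include <cstdio>
#include <cstdlib>

enum CalcType : int {
  kInt16 = 0, kInt32, kFloat, kInt32WithLL, kTF32, kFloat16, kNumCalcTypes
};

// Mirrors the FCCalcType<T>() idea: the FP16 path is opt-in via an env var,
// everything else falls back to a default path.
template <typename T>
CalcType PickCalcType() {
  if (std::getenv("DEMO_FC_FLOAT16") != nullptr) return kFloat16;
  return kInt16;  // default, as in the real code's final fallthrough
}

template <int kPath>
void GemmImpl(int m, int n, int k) {
  std::printf("gemm path %d: %d x %d x %d\n", kPath, m, n, k);
}

using GemmFn = void (*)(int, int, int);

// One entry per CalcType, in enum order -- the xblas_fc_api_list layout.
static const GemmFn kGemmTable[kNumCalcTypes] = {
    &GemmImpl<kInt16>,       &GemmImpl<kInt32>, &GemmImpl<kFloat>,
    &GemmImpl<kInt32WithLL>, &GemmImpl<kTF32>,  &GemmImpl<kFloat16>,
};

int main() {
  GemmFn fn = kGemmTable[PickCalcType<float>()];
  fn(64, 64, 64);  // resolved once, then called like a plain function
  return 0;
}

Keeping the table in enum order makes adding a calc type a two-line change (one enum entry, one table entry); an out-of-order table would silently dispatch to the wrong kernel.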
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index d4d29aa7a4ad7f..5d6006b7a69bd1 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -22,6 +22,9 @@
 #include "paddle/phi/backends/xpu/xpu_info.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
+#include "xblas/cublasLt.h"
+namespace xblas = baidu::xpu::xblas;
+
 namespace phi {
 
 using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
@@ -33,12 +36,17 @@ enum XPUFCCalcType {
   FC_FLOAT,
   FC_INT32_WITH_LL,
   FC_TF32,
+  FC_FLOAT16,
 };
 
 template <typename T>
 XPUFCCalcType FCCalcType() {
-  if (std::is_same<phi::dtype::float16, T>::value ||
-      std::is_same<XPUTypeFP16, T>::value) {
+  if (std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr &&
+      (std::is_same<phi::dtype::float16, T>::value ||
+       std::is_same<XPUTypeFP16, T>::value || std::is_same<float, T>::value)) {
+    return XPUFCCalcType::FC_FLOAT16;
+  } else if (std::is_same<phi::dtype::float16, T>::value ||
+             std::is_same<XPUTypeFP16, T>::value) {
     return XPUFCCalcType::FC_INT16;
   } else if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {
     return XPUFCCalcType::FC_INT32;
@@ -50,7 +58,6 @@ XPUFCCalcType FCCalcType() {
              std::is_same<XPUTypeBF16, T>::value) {
     return XPUFCCalcType::FC_TF32;
   }
-
   return XPUFCCalcType::FC_INT16;
 }
@@ -67,7 +74,13 @@ struct XpuFcInfo {
   float* max_x;
   float* max_y;
   float* max_out;
+  const float* bias;
   bool is_x_need_broadcast;
+  const float* scale_x;
+  const float* scale_y;
+  int scale_x_mode;
+  int scale_y_mode;
+
   XpuFcInfo()
       : bs(0),
         m(0),
@@ -81,7 +94,12 @@ struct XpuFcInfo {
         max_x(nullptr),
         max_y(nullptr),
         max_out(nullptr),
-        is_x_need_broadcast(false) {}
+        bias(nullptr),
+        is_x_need_broadcast(false),
+        scale_x(nullptr),
+        scale_y(nullptr),
+        scale_x_mode(0),
+        scale_y_mode(0) {}
   void InitFcInfo(int bs,
                   int m,
                   int n,
@@ -202,25 +220,29 @@ static void GetFCInfo(const phi::DDim& x_dims,
 }
 
 template <typename XPUType, typename FCT>
-static void xpu_fc_wrapper(xpu::Context* ctx,
-                           const XPUType* x,
-                           const XPUType* w,
-                           XPUType* y,
-                           int m,
-                           int n,
-                           int k,
-                           bool x_trans,
-                           bool w_trans,
-                           const float* x_maxptr,
-                           const float* w_maxptr,
-                           float* y_maxptr,
-                           int ldx,
-                           int ldw,
-                           int ldy,
-                           float alpha,
-                           float beta,
-                           const float* bias,
-                           const xpu::Activation_t& act) {
+static void xblas_fc_wrapper(xpu::Context* ctx,
+                             const XPUType* x,
+                             const XPUType* w,
+                             XPUType* y,
+                             int m,
+                             int n,
+                             int k,
+                             bool x_trans,
+                             bool w_trans,
+                             const float* x_maxptr,
+                             const float* w_maxptr,
+                             float* y_maxptr,
+                             int ldx,
+                             int ldw,
+                             int ldy,
+                             float alpha,
+                             float beta,
+                             const float* bias,
+                             const xpu::Activation_t& act,
+                             const float* scale_x,
+                             const float* scale_w,
+                             int scale_x_mode,
+                             int scale_w_mode) {
   int r = 0;
   if (x_trans && std::getenv("XPU_PADDLE_FC_TRANS_A") != nullptr &&
       std::is_same<float, XPUType>::value) {
@@ -234,352 +256,222 @@ static void xpu_fc_wrapper(xpu::Context* ctx,
     XPUType* l3_addr = nullptr;
     xpu::ctx_guard RAII_GUARD(ctx);
     l3_addr = RAII_GUARD.alloc_l3_or_gm<XPUType>(m * k);
     PADDLE_ENFORCE_XDNN_NOT_NULL(l3_addr);
 
     std::vector<int> shape = {k, m};
     std::vector<int> axis = {1, 0};
     r = xpu::transpose<XPUType>(ctx, x, l3_addr, shape, axis);
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
 
-    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
-                                                       l3_addr,
-                                                       w,
-                                                       y,
-                                                       m,
-                                                       n,
-                                                       k,
-                                                       false,
-                                                       w_trans,
-                                                       x_maxptr,
-                                                       w_maxptr,
-                                                       y_maxptr,
-                                                       k,
-                                                       ldw,
-                                                       ldy,
-                                                       alpha,
-                                                       beta,
-                                                       bias,
-                                                       act);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
+    r = xblas::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
+                                                         l3_addr,
+                                                         w,
+                                                         y,
+                                                         m,
+                                                         n,
+                                                         k,
+                                                         false,
+                                                         w_trans,
+                                                         x_maxptr,
+                                                         w_maxptr,
+                                                         y_maxptr,
+                                                         k,
+                                                         ldw,
+                                                         ldy,
+                                                         alpha,
+                                                         beta,
+                                                         bias,
+                                                         act,
+                                                         scale_x,
+                                                         scale_w,
+                                                         scale_x_mode,
+                                                         scale_w_mode);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion");
   } else {
-    r = xpu::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
-                                                       x,
-                                                       w,
-                                                       y,
-                                                       m,
-                                                       n,
-                                                       k,
-                                                       x_trans,
-                                                       w_trans,
-                                                       x_maxptr,
-                                                       w_maxptr,
-                                                       y_maxptr,
-                                                       ldx,
-                                                       ldw,
-                                                       ldy,
-                                                       alpha,
-                                                       beta,
-                                                       bias,
-                                                       act);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_fusion");
+    r = xblas::fc_fusion<XPUType, XPUType, XPUType, FCT>(ctx,
+                                                         x,
+                                                         w,
+                                                         y,
+                                                         m,
+                                                         n,
+                                                         k,
+                                                         x_trans,
+                                                         w_trans,
+                                                         x_maxptr,
+                                                         w_maxptr,
+                                                         y_maxptr,
+                                                         ldx,
+                                                         ldw,
+                                                         ldy,
+                                                         alpha,
+                                                         beta,
+                                                         bias,
+                                                         act,
+                                                         scale_x,
+                                                         scale_w,
+                                                         scale_x_mode,
+                                                         scale_w_mode);
+
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion");
   }
 }
 
-template <>
-void xpu_fc_wrapper<XPUTypeBF16, int_with_ll_t>(xpu::Context* ctx,
-                                                const XPUTypeBF16* x,
-                                                const XPUTypeBF16* w,
-                                                XPUTypeBF16* y,
-                                                int m,
-                                                int n,
-                                                int k,
-                                                bool x_trans,
-                                                bool w_trans,
-                                                const float* x_maxptr,
-                                                const float* w_maxptr,
-                                                float* y_maxptr,
-                                                int ldx,
-                                                int ldw,
-                                                int ldy,
-                                                float alpha,
-                                                float beta,
-                                                const float* bias,
-                                                const xpu::Activation_t& act) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
-}
-
-template <>
-void xpu_fc_wrapper<XPUTypeBF16, int16_t>(xpu::Context* ctx,
-                                          const XPUTypeBF16* x,
-                                          const XPUTypeBF16* w,
-                                          XPUTypeBF16* y,
-                                          int m,
-                                          int n,
-                                          int k,
-                                          bool x_trans,
-                                          bool w_trans,
-                                          const float* x_maxptr,
-                                          const float* w_maxptr,
-                                          float* y_maxptr,
-                                          int ldx,
-                                          int ldw,
-                                          int ldy,
-                                          float alpha,
-                                          float beta,
-                                          const float* bias,
-                                          const xpu::Activation_t& act) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
-}
-
-template <>
-void xpu_fc_wrapper<XPUTypeBF16, int32_t>(xpu::Context* ctx,
-                                          const XPUTypeBF16* x,
-                                          const XPUTypeBF16* w,
-                                          XPUTypeBF16* y,
-                                          int m,
-                                          int n,
-                                          int k,
-                                          bool x_trans,
-                                          bool w_trans,
-                                          const float* x_maxptr,
-                                          const float* w_maxptr,
-                                          float* y_maxptr,
-                                          int ldx,
-                                          int ldw,
-                                          int ldy,
-                                          float alpha,
-                                          float beta,
-                                          const float* bias,
-                                          const xpu::Activation_t& act) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
-}
-
-template <>
-void xpu_fc_wrapper<XPUTypeFP16, int32_t>(xpu::Context* ctx,
-                                          const XPUTypeFP16* x,
-                                          const XPUTypeFP16* w,
-                                          XPUTypeFP16* y,
-                                          int m,
-                                          int n,
-                                          int k,
-                                          bool x_trans,
-                                          bool w_trans,
-                                          const float* x_maxptr,
-                                          const float* w_maxptr,
-                                          float* y_maxptr,
-                                          int ldx,
-                                          int ldw,
-                                          int ldy,
-                                          float alpha,
-                                          float beta,
-                                          const float* bias,
-                                          const xpu::Activation_t& act) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_wrapper");
-}
-
-template <typename XPUType, typename FCT>
-static void xpu_fc_batch_wrapper(xpu::Context* xpu_ctx,
-                                 int bs,
-                                 bool trans_x,
-                                 bool trans_w,
-                                 int m,
-                                 int n,
-                                 int k,
-                                 float alpha,
-                                 const XPUType* x,
-                                 int stride_x,
-                                 const XPUType* w,
-                                 int stride_w,
-                                 float beta,
-                                 XPUType* y,
-                                 int stride_y,
-                                 const float* x_maxptr,
-                                 const float* w_maxptr) {
-  int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
-      xpu_ctx,                              // Context* ctx,
-      bs,                                   // int batch_size,
-      trans_x,                              // bool x_trans,
-      trans_w,                              // bool w_trans,
-      m,                                    // int m,
-      n,                                    // int n,
-      k,                                    // int k,
-      alpha,                                // float alpha,
-      reinterpret_cast<const XPUType*>(x),  // const TX* x,
-      stride_x,                             // int stride_a,
-      reinterpret_cast<const XPUType*>(w),  // const TW* w,
-      stride_w,                             // int stride_b,
-      0.0,                                  // float beta,
-      reinterpret_cast<XPUType*>(y),        // TY* y,
-      stride_y,                             // int stride_c,
-      x_maxptr,                             // const float* x_maxptr,
-      w_maxptr);                            // const float* w_maxptr
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
-}
-
-template <>
-void xpu_fc_batch_wrapper<XPUTypeBF16, int_with_ll_t>(xpu::Context* xpu_ctx,
-                                                      int bs,
-                                                      bool trans_x,
-                                                      bool trans_w,
-                                                      int m,
-                                                      int n,
-                                                      int k,
-                                                      float alpha,
-                                                      const XPUTypeBF16* x,
-                                                      int stride_x,
-                                                      const XPUTypeBF16* w,
-                                                      int stride_w,
-                                                      float beta,
-                                                      XPUTypeBF16* y,
-                                                      int stride_y,
-                                                      const float* x_maxptr,
-                                                      const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
-
-template <>
-void xpu_fc_batch_wrapper<XPUTypeBF16, float>(xpu::Context* xpu_ctx,
-                                              int bs,
-                                              bool trans_x,
-                                              bool trans_w,
-                                              int m,
-                                              int n,
-                                              int k,
-                                              float alpha,
-                                              const XPUTypeBF16* x,
-                                              int stride_x,
-                                              const XPUTypeBF16* w,
-                                              int stride_w,
-                                              float beta,
-                                              XPUTypeBF16* y,
-                                              int stride_y,
-                                              const float* x_maxptr,
-                                              const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
-
-template <>
-void xpu_fc_batch_wrapper<XPUTypeBF16, int32_t>(xpu::Context* xpu_ctx,
-                                                int bs,
-                                                bool trans_x,
-                                                bool trans_w,
-                                                int m,
-                                                int n,
-                                                int k,
-                                                float alpha,
-                                                const XPUTypeBF16* x,
-                                                int stride_x,
-                                                const XPUTypeBF16* w,
-                                                int stride_w,
-                                                float beta,
-                                                XPUTypeBF16* y,
-                                                int stride_y,
-                                                const float* x_maxptr,
-                                                const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
+#define DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUType, FCT)       \
+  template <>                                                    \
+  void xblas_fc_wrapper<XPUType, FCT>(xpu::Context * ctx,        \
+                                      const XPUType* x,          \
+                                      const XPUType* w,          \
+                                      XPUType* y,                \
+                                      int m,                     \
+                                      int n,                     \
+                                      int k,                     \
+                                      bool x_trans,              \
+                                      bool w_trans,              \
+                                      const float* x_maxptr,     \
+                                      const float* w_maxptr,     \
+                                      float* y_maxptr,           \
+                                      int ldx,                   \
+                                      int ldw,                   \
+                                      int ldy,                   \
+                                      float alpha,               \
+                                      float beta,                \
+                                      const float* bias,         \
+                                      const xpu::Activation_t& act, \
+                                      const float* scale_x,      \
+                                      const float* scale_w,      \
+                                      int scale_x_mode,          \
+                                      int scale_w_mode) {        \
+    int r = xpu::Error_t::INVALID_PARAM;                         \
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_wrapper");          \
+  }
 
-template <>
-void xpu_fc_batch_wrapper<XPUTypeBF16, int16_t>(xpu::Context* xpu_ctx,
-                                                int bs,
-                                                bool trans_x,
-                                                bool trans_w,
-                                                int m,
-                                                int n,
-                                                int k,
-                                                float alpha,
-                                                const XPUTypeBF16* x,
-                                                int stride_x,
-                                                const XPUTypeBF16* w,
-                                                int stride_w,
-                                                float beta,
-                                                XPUTypeBF16* y,
-                                                int stride_y,
-                                                const float* x_maxptr,
-                                                const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int_with_ll_t)
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int16_t)
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int32_t)
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, int32_t)
+DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, tfloat32)
+
+template <typename XPUType, typename FCT, typename TGEMM_OUT>
+static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx,
+                                   int bs,
+                                   bool trans_x,
+                                   bool trans_w,
+                                   int m,
+                                   int n,
+                                   int k,
+                                   float alpha,
+                                   const XPUType* x,
+                                   int stride_x,
+                                   const XPUType* w,
+                                   int stride_w,
+                                   float beta,
+                                   XPUType* y,
+                                   int stride_y,
+                                   const float* x_maxptr,
+                                   const float* w_maxptr) {
+  int r = xblas::fc_batched<XPUType, XPUType, XPUType, FCT, TGEMM_OUT>(
+      xpu_ctx,
+      bs,
+      trans_x,
+      trans_w,
+      m,
+      n,
+      k,
+      alpha,
+      reinterpret_cast<const XPUType*>(x),
+      stride_x,
+      reinterpret_cast<const XPUType*>(w),
+      stride_w,
+      0.0,
+      reinterpret_cast<XPUType*>(y),
+      stride_y,
+      x_maxptr,
+      w_maxptr);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_batch_wrapper");
+}
 
-template <>
-void xpu_fc_batch_wrapper<XPUTypeFP16, int32_t>(xpu::Context* xpu_ctx,
-                                                int bs,
-                                                bool trans_x,
-                                                bool trans_w,
-                                                int m,
-                                                int n,
-                                                int k,
-                                                float alpha,
-                                                const XPUTypeFP16* x,
-                                                int stride_x,
-                                                const XPUTypeFP16* w,
-                                                int stride_w,
-                                                float beta,
-                                                XPUTypeFP16* y,
-                                                int stride_y,
-                                                const float* x_maxptr,
-                                                const float* w_maxptr) {
-  int r = xpu::Error_t::INVALID_PARAM;
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_fc_batch_wrapper");
-}
+#define DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUType, FCT, TGEMM_OUT) \
+  template <>                                                               \
+  void xblas_fc_batch_wrapper<XPUType, FCT, TGEMM_OUT>(                     \
+      xpu::Context * xpu_ctx,                                               \
+      int bs,                                                               \
+      bool trans_x,                                                         \
+      bool trans_w,                                                         \
+      int m,                                                                \
+      int n,                                                                \
+      int k,                                                                \
+      float alpha,                                                          \
+      const XPUType* x,                                                     \
+      int stride_x,                                                         \
+      const XPUType* w,                                                     \
+      int stride_w,                                                         \
+      float beta,                                                           \
+      XPUType* y,                                                           \
+      int stride_y,                                                         \
+      const float* x_maxptr,                                                \
+      const float* w_maxptr) {                                              \
+    int r = xpu::Error_t::INVALID_PARAM;                                    \
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_batched");                     \
+  }
 
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16,
+                                           int_with_ll_t,
+                                           XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, float, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16,
+                                           XPUTypeFP16,
+                                           XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, XPUTypeFP16)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int_with_ll_t, float)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, XPUTypeFP16, float)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, float)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, float)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, float)
+DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, float)
 
 template <typename T>
-static void MatMulXPUFunction(xpu::Context* xpu_ctx,
-                              const T* x,
-                              const T* y,
-                              T* out,
-                              const XpuFcInfo& fcinfo,
-                              float alpha,
-                              bool is_grad = false) {
+static void MatMulXPUFunction(
+    xpu::Context* xpu_ctx,
+    const T* x,
+    const T* y,
+    T* out,
+    const XpuFcInfo& fcinfo,
+    float alpha,
+    bool is_grad = false,
+    xpu::Activation_t act = xpu::Activation_t::LINEAR) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   int fccal_type = FCCalcType<XPUType>();
-  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[5] = {
-      &xpu_fc_wrapper<XPUType, int16_t>,
-      &xpu_fc_wrapper<XPUType, int32_t>,
-      &xpu_fc_wrapper<XPUType, float>,
-      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
-      &xpu_fc_wrapper<XPUType, tfloat32>,
-  };
-  decltype(&xpu_fc_batch_wrapper<XPUType, int16_t>) fc_batch_api_list[5] = {
-      &xpu_fc_batch_wrapper<XPUType, int16_t>,
-      &xpu_fc_batch_wrapper<XPUType, int32_t>,
-      &xpu_fc_batch_wrapper<XPUType, float>,
-      &xpu_fc_batch_wrapper<XPUType, int_with_ll_t>,
-      &xpu_fc_batch_wrapper<XPUType, tfloat32>,
+  decltype(&xblas_fc_wrapper<XPUType, int16_t>) xblas_fc_api_list[6] = {
+      &xblas_fc_wrapper<XPUType, int16_t>,
+      &xblas_fc_wrapper<XPUType, int32_t>,
+      &xblas_fc_wrapper<XPUType, float>,
+      &xblas_fc_wrapper<XPUType, int_with_ll_t>,
+      &xblas_fc_wrapper<XPUType, tfloat32>,
+      &xblas_fc_wrapper<XPUType, XPUTypeFP16>,
   };
-
-  auto fc_api = fc_api_list[fccal_type];
+  decltype(&xblas_fc_batch_wrapper<XPUType, int16_t, float>)
+      xblas_fc_batch_api_list[6] = {
+          &xblas_fc_batch_wrapper<XPUType, int16_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, int32_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, float, float>,
+          &xblas_fc_batch_wrapper<XPUType, int_with_ll_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, tfloat32, float>,
+          &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, float>,
+      };
+
+  auto xblas_fc_api = xblas_fc_api_list[fccal_type];
+
   if (std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr) {
     if (is_grad) {
-      fc_api = fc_api_list[2];
+      xblas_fc_api = xblas_fc_api_list[2];
     }
   }
+  auto xblas_fc_batch_api = xblas_fc_batch_api_list[fccal_type];
 
-  auto fc_batch_api = fc_batch_api_list[fccal_type];
-
+  if (fccal_type == XPUFCCalcType::FC_FLOAT16 &&
+      std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) {
+    xblas_fc_batch_api =
+        &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, XPUTypeFP16>;
+  }
   int m = fcinfo.m;
   int n = fcinfo.n;
   int k = fcinfo.k;
@@ -593,27 +485,35 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   float* max_y = fcinfo.max_y;
   float* max_out = fcinfo.max_out;
   bool is_x_need_broadcast = fcinfo.is_x_need_broadcast;
-
+  const float* bias = fcinfo.bias;
+  const float* scale_x = fcinfo.scale_x;
+  const float* scale_y = fcinfo.scale_y;
+  int scale_x_mode = fcinfo.scale_x_mode;
+  int scale_y_mode = fcinfo.scale_y_mode;
   if (batch_size <= 1) {
-    fc_api(xpu_ctx,
-           reinterpret_cast<const XPUType*>(x),
-           reinterpret_cast<const XPUType*>(y),
-           reinterpret_cast<XPUType*>(out),
-           m,
-           n,
-           k,
-           trans_x,
-           trans_y,
-           max_x,
-           max_y,
-           max_out,
-           ldx,
-           ldy,
-           ldout,
-           alpha,
-           0,
-           nullptr,
-           xpu::Activation_t::LINEAR);
+    xblas_fc_api(xpu_ctx,
+                 reinterpret_cast<const XPUType*>(x),
+                 reinterpret_cast<const XPUType*>(y),
+                 reinterpret_cast<XPUType*>(out),
+                 m,
+                 n,
+                 k,
+                 trans_x,
+                 trans_y,
+                 max_x,
+                 max_y,
+                 max_out,
+                 ldx,
+                 ldy,
+                 ldout,
+                 alpha,
+                 0,
+                 bias,
+                 act,
+                 scale_x,
+                 scale_y,
+                 scale_x_mode,
+                 scale_y_mode);
   } else {
     const XPUType* x_data = reinterpret_cast<const XPUType*>(x);
     if (is_x_need_broadcast) {
@@ -629,23 +529,23 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
       x_data = x_broadcast_data;
     }
     // batch matmul
-    fc_batch_api(xpu_ctx,                              // Context* ctx,
-                 batch_size,                           // int batch_size,
-                 trans_x,                              // bool x_trans,
-                 trans_y,                              // bool w_trans,
-                 m,                                    // int m,
-                 n,                                    // int n,
-                 k,                                    // int k,
-                 alpha,                                // float alpha,
-                 x_data,                               // const TX* x,
-                 ldx,                                  // int stride_a,
-                 reinterpret_cast<const XPUType*>(y),  // const TW* w,
-                 ldy,                                  // int stride_b,
-                 0.0,                                  // float beta,
-                 reinterpret_cast<XPUType*>(out),      // TY* y,
-                 ldout,                                // int stride_c,
-                 max_x,                                // const float* x_maxptr,
-                 max_y);                               // const float* w_maxptr
+    xblas_fc_batch_api(xpu_ctx,                              // Context* ctx,
+                       batch_size,                           // int batch_size,
+                       trans_x,                              // bool x_trans,
+                       trans_y,                              // bool w_trans,
+                       m,                                    // int m,
+                       n,                                    // int n,
+                       k,                                    // int k,
+                       alpha,                                // float alpha,
+                       x_data,                               // const TX* x,
+                       ldx,                                  // int stride_a,
+                       reinterpret_cast<const XPUType*>(y),  // const TW* w,
+                       ldy,                                  // int stride_b,
+                       0.0,                                  // float beta,
+                       reinterpret_cast<XPUType*>(out),      // TY* y,
+                       ldout,                                // int stride_c,
+                       max_x,                                // const float* x_maxptr,
+                       max_y);                               // const float* w_maxptr
   }
 }

diff --git a/test/xpu/test_matmul_v2_op_xpu.py b/test/xpu/test_matmul_v2_op_xpu.py
index 76909c5fd83253..0fae09badb44c8 100644
--- a/test/xpu/test_matmul_v2_op_xpu.py
+++ b/test/xpu/test_matmul_v2_op_xpu.py
@@ -20,6 +20,10 @@
     create_test_class,
     get_xpu_op_support_types,
 )
+from op_test import (
+    convert_float_to_uint16,
+    skip_check_grad_ci,
+)
 from op_test_xpu import XPUOpTest
 
 import paddle
@@ -69,15 +73,23 @@ def setUp(self):
         self.dtype = self.in_type
         self.config()
         self.op_type = "matmul_v2"
-        if self.dtype == np.float16 or self.dtype == "float16":
-            self.__class__.no_need_check_grad = True
-        x = np.random.random(self.x_shape).astype(self.dtype)
-        y = np.random.random(self.y_shape).astype(self.dtype)
+
+        x = np.random.random(self.x_shape)
+        y = np.random.random(self.y_shape)
         # -0.1 ~ 0.1
         x = -0.1 + 0.2 * x
         y = -0.1 + 0.2 * y
         result = reference_matmul(x, y, self.trans_x, self.trans_y)
+        if self.dtype == np.uint16:
+            x = convert_float_to_uint16(x)
+            y = convert_float_to_uint16(y)
+        else:
+            x = x.astype(self.dtype)
+            y = y.astype(self.dtype)
+            result = result.astype(self.dtype)
+
         self.inputs = {
             'X': x,
             'Y': y,
@@ -263,6 +275,9 @@ def config(self):
         self.trans_x = False
         self.trans_y = False
 
+    @skip_check_grad_ci(
+        reason="[skip shape check] Use y_shape(17) to test case in ppyoloe."
+    )
     class TestMatMulOp18(TestMatMulV2Op):
         """
         case 18 : for ppyoloe model
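A note on the call-shape change threaded through the kernels above: the bias pointer no longer travels as one of many positional fc arguments; it rides inside XpuFcInfo, and the activation becomes a defaulted trailing parameter of MatMulXPUFunction. The following compressed model shows only that signature shape; FcDesc, Act, MatMul, and FusedFc are hypothetical stand-ins, not the phi API:

// model.cc -- shape of the MatMulXPUFunction change above, in miniature.
// FcDesc/Act/MatMul/FusedFc are simplified stand-ins, not the phi API.
struct FcDesc {  // stands in for phi::XpuFcInfo
  int m = 0, n = 0, k = 0;
  const float* bias = nullptr;  // new: bias rides in the descriptor
};

enum class Act { kLinear, kRelu };  // stands in for xpu::Activation_t

template <typename T>
void MatMul(const T* x, const T* y, T* out, const FcDesc& desc,
            float alpha = 1.0f, bool is_grad = false,
            Act act = Act::kLinear) {
  // ... would dispatch through the fc wrapper table, forwarding desc.bias
  // and act to the selected kernel ...
  (void)x; (void)y; (void)out; (void)desc;
  (void)alpha; (void)is_grad; (void)act;
}

// Caller pattern after the patch (cf. FusedGemmEpilogueKernel): fill the
// descriptor, stash the bias pointer, and pass the activation last.
void FusedFc(const float* x, const float* y, float* out, const float* bias) {
  FcDesc desc{/*m=*/4, /*n=*/8, /*k=*/16, /*bias=*/bias};
  MatMul(x, y, out, desc, 1.0f, /*is_grad=*/false, Act::kRelu);
}

This keeps existing callers source-compatible: they already pass fcinfo, and the two new behaviors (bias and activation) default to "absent" unless a caller such as the fused GEMM epilogue kernel opts in.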