[XPU]bind xblas fc (#61340)
* run xblas by default when the device == XPU3

* run xblas by default when the device == XPU3

* reset xpu.cmake

* Add compilation control

* remove redundant code calling xdnn
runzhech authored Feb 22, 2024
1 parent 8d704b7 commit 7a2bb2f
Showing 6 changed files with 360 additions and 463 deletions.
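The commit message above says XBLAS becomes the default fc backend on XPU3 hardware, guarded by a build-time switch ("Add compilation control"). As a rough, hypothetical illustration of that kind of gating, not the actual Paddle code (the PADDLE_WITH_XPU_XBLAS macro and the helper names are invented for this sketch; the real switch presumably lives in wrapper headers not shown in this excerpt):

// Hypothetical sketch of compile-time backend selection.
#include <cstdio>

// Assumed stand-ins for the XBLAS and XDNN fc entry points.
static int fc_via_xblas() { std::puts("fc -> xblas"); return 0; }
static int fc_via_xdnn()  { std::puts("fc -> xdnn");  return 0; }

int RunFC() {
#ifdef PADDLE_WITH_XPU_XBLAS   // assumed build flag, enabled for XPU3 builds
  return fc_via_xblas();       // XBLAS path is the default when compiled in
#else
  return fc_via_xdnn();        // otherwise keep the existing XDNN path
#endif
}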
70 changes: 16 additions & 54 deletions paddle/phi/kernels/fusion/xpu/fused_gemm_epilogue_kernel.cc
@@ -45,41 +45,6 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
   const XPUType* y_ptr = reinterpret_cast<const XPUType*>(y.data<T>());
   XPUType* out_ptr = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));

-  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[5] = {
-      &xpu_fc_wrapper<XPUType, int16_t>,
-      &xpu_fc_wrapper<XPUType, int32_t>,
-      &xpu_fc_wrapper<XPUType, float>,
-      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
-      &xpu_fc_wrapper<XPUType, tfloat32>,
-  };
-
-  auto fccal_type = FCCalcType<XPUType>();
-
-  auto fc_api = fc_api_list[fccal_type];
-
-  // fc + bias + act
-  phi::XpuFcInfo fc_info;
-  auto mat_x_dims =
-      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
-  auto mat_y_dims = y.dims();
-  phi::GetFCInfo(mat_x_dims, mat_y_dims, trans_x, trans_y, &fc_info);
-  int batch_size = fc_info.bs;
-  int m = fc_info.m;
-  int n = fc_info.n;
-  int k = fc_info.k;
-  int ldx = fc_info.stride_x;
-  int ldy = fc_info.stride_y;
-  int ldout = fc_info.stride_out;
-  float* max_x = fc_info.max_x;
-  float* max_y = fc_info.max_y;
-  float* max_out = fc_info.max_out;
-
-  PADDLE_ENFORCE_LE(
-      batch_size,
-      1,
-      errors::InvalidArgument(
-          "FusedGemm do not support batched fc now, but got batch size %d.",
-          batch_size));
   const float* bias_ptr = reinterpret_cast<const float*>(bias.data<T>());
   if (!std::is_same<T, float>::value) {
     // TODO(lijin23): Now xblas and xdnn support fp32 bias only, may be removed
@@ -93,25 +58,22 @@ void FusedGemmEpilogueKernel(const Context& dev_ctx,
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
     bias_ptr = bias_tmp;
   }
-  fc_api(xpu_ctx,
-         x_ptr,
-         y_ptr,
-         out_ptr,
-         m,
-         n,
-         k,
-         trans_x,
-         trans_y,
-         max_x,
-         max_y,
-         max_out,
-         ldx,
-         ldy,
-         ldout,
-         1.0f,
-         0,
-         bias_ptr,
-         act);
+  // fc + bias + act
+  phi::XpuFcInfo fc_info;
+  fc_info.bias = bias_ptr;
+  auto mat_x_dims =
+      common::flatten_to_2d(x.dims(), trans_x ? 1 : x.dims().size() - 1);
+  auto mat_y_dims = y.dims();
+  phi::GetFCInfo(mat_x_dims, mat_y_dims, trans_x, trans_y, &fc_info);
+  int batch_size = fc_info.bs;
+  PADDLE_ENFORCE_LE(
+      batch_size,
+      1,
+      errors::InvalidArgument(
+          "FusedGemm do not support batched fc now, but got batch size %d.",
+          batch_size));
+  MatMulXPUFunction<XPUType>(
+      xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f, false, act);
 }

 }  // namespace fusion
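The net effect in fused_gemm_epilogue_kernel.cc: instead of building a local fc_api_list and unpacking m/n/k, strides, and max pointers by hand, the kernel now fills a phi::XpuFcInfo descriptor (which gains a bias field) and delegates everything to the shared MatMulXPUFunction helper. A minimal standalone sketch of that descriptor-plus-helper pattern; FcDesc, RunFc, and the dummy backend below are invented stand-ins, not Paddle's actual types:

#include <cstdio>

// Assumed descriptor bundling what a fused fc call needs, in the spirit
// of phi::XpuFcInfo (which this commit extends with a bias pointer).
struct FcDesc {
  int m = 0, n = 0, k = 0;
  bool trans_x = false, trans_y = false;
  const float* bias = nullptr;  // bias now travels inside the descriptor
};

// One shared helper replaces per-kernel dispatch tables: callers stop
// caring which backend instantiation ultimately runs.
void RunFc(const float* x, const float* y, float* out, const FcDesc& d) {
  // ... backend selection and the actual GEMM would live here ...
  std::printf("fc m=%d n=%d k=%d bias=%s\n", d.m, d.n, d.k,
              d.bias ? "yes" : "no");
  (void)x; (void)y; (void)out;
}

int main() {
  float x[6]{}, y[6]{}, out[4]{}, bias[2]{};
  FcDesc d{2, 2, 3, false, false, bias};
  RunFc(x, y, out, d);  // all plumbing hidden in the helper, as in the diff
}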
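For reference, the common::flatten_to_2d(x.dims(), num_col_dims) call in the added lines views an N-D shape as a 2-D matrix: dimensions before num_col_dims fold into rows and the rest fold into columns. A tiny sketch under that assumption (FlattenTo2D is a simplified stand-in, not the Paddle function):

#include <cstdio>
#include <utility>
#include <vector>

// Assumed simplification of common::flatten_to_2d: fold dims [0, axis)
// into rows and dims [axis, rank) into cols.
std::pair<int, int> FlattenTo2D(const std::vector<int>& dims, int axis) {
  int rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < axis ? rows : cols) *= dims[i];
  }
  return {rows, cols};
}

int main() {
  // A [4, 5, 6] tensor flattened at axis 2 (x.dims().size() - 1 in the
  // diff, for trans_x == false) becomes a 20 x 6 matrix for the fc call.
  auto [m, k] = FlattenTo2D({4, 5, 6}, 2);
  std::printf("%d x %d\n", m, k);  // prints: 20 x 6
}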
2 changes: 2 additions & 0 deletions paddle/phi/kernels/xpu/bmm_grad_kernel.cc
@@ -35,6 +35,8 @@ void MatMul(const Context& dev_ctx,
     MatMulXPUFunction<T, float>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(a, b, out, trans_a, trans_b, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) {
+    MatMulXPUFunction<T, float16>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   }
2 changes: 2 additions & 0 deletions paddle/phi/kernels/xpu/bmm_kernel.cc
@@ -70,6 +70,8 @@ void BmmKernel(const Context& dev_ctx,
     MatMulXPUFunction<T, float>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) {
     MatMulXPUFunction<T, int_with_ll_t>(x, y, out, trans_x, trans_y, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) {
+    MatMulXPUFunction<T, float16>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
     MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   }
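Both bmm_grad_kernel.cc and bmm_kernel.cc gain the same two-line branch: when the calc type resolves to FC_FLOAT16, the batched matmul template is instantiated with a float16 compute type rather than falling through to the int16_t default. A toy version of that enum-keyed template dispatch (the enum values, the Gemm stub, and the float16 struct are assumptions for illustration):

#include <cstdint>
#include <cstdio>
#include <typeinfo>

enum class CalcType { Int16, Int32, Float, Int32WithLL, TF32, Float16 };

struct float16 { uint16_t bits; };  // assumed fp16 storage type

// Stand-in for MatMulXPUFunction<T, FCT>: FCT picks the on-device
// compute/accumulate type for the batched GEMM.
template <typename FCT>
void Gemm() {
  // Prints an implementation-defined name for FCT, just to show dispatch.
  std::printf("compute type: %s\n", typeid(FCT).name());
}

void Dispatch(CalcType t) {
  if (t == CalcType::Float) {
    Gemm<float>();
  } else if (t == CalcType::Float16) {   // the branch this commit adds
    Gemm<float16>();
  } else {
    Gemm<int16_t>();                     // previous default fall-through
  }
}

int main() { Dispatch(CalcType::Float16); }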
54 changes: 35 additions & 19 deletions paddle/phi/kernels/xpu/bmm_xpu_utils.h
@@ -40,25 +40,41 @@ static void MatMulXPUFunction(const DenseTensor& x,
   int k = mat_dim_a.width_;
   int batch_size = mat_dim_a.batch_size_;
   // batch matmul
-  int r = xpu::fc_batched<XPUType, XPUType, XPUType, FCT>(
-      xpu_ctx,                                        // Context* ctx,
-      batch_size,                                     // int batch_size,
-      mat_dim_a.trans_,                               // bool x_trans,
-      mat_dim_b.trans_,                               // bool w_trans,
-      m,                                              // int m,
-      n,                                              // int n,
-      k,                                              // int k,
-      1.0,                                            // float alpha,
-      reinterpret_cast<const XPUType*>(x.data<T>()),  // const TX* x,
-      mat_dim_a.stride_,                              // int stride_a,
-      reinterpret_cast<const XPUType*>(y.data<T>()),  // const TW* w,
-      mat_dim_b.stride_,                              // int stride_b,
-      0.0,                                            // float beta,
-      reinterpret_cast<XPUType*>(data_c),             // TY* y,
-      m * n,                                          // int stride_c,
-      nullptr,                                        // const float* x_maxptr,
-      nullptr);                                       // const float* w_maxptr
+  int fccal_type = FCCalcType<XPUType>();
+  decltype(&xblas_fc_batch_wrapper<XPUType, int16_t, float>)
+      xblas_fc_batch_api_list[6] = {
+          &xblas_fc_batch_wrapper<XPUType, int16_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, int32_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, float, float>,
+          &xblas_fc_batch_wrapper<XPUType, int_with_ll_t, float>,
+          &xblas_fc_batch_wrapper<XPUType, tfloat32, float>,
+          &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, float>,
+      };
+
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_batched");
+  auto xblas_fc_batch_api = xblas_fc_batch_api_list[fccal_type];
+  if (fccal_type == XPUFCCalcType::FC_FLOAT16 &&
+      std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) {
+    xblas_fc_batch_api =
+        &xblas_fc_batch_wrapper<XPUType, XPUTypeFP16, XPUTypeFP16>;
+  }
+
+  xblas_fc_batch_api(xpu_ctx,
+                     batch_size,
+                     mat_dim_a.trans_,
+                     mat_dim_b.trans_,
+                     m,
+                     n,
+                     k,
+                     1.0,
+                     reinterpret_cast<const XPUType*>(x.data<T>()),
+                     mat_dim_a.stride_,
+                     reinterpret_cast<const XPUType*>(y.data<T>()),
+                     mat_dim_b.stride_,
+                     0.0,
+                     reinterpret_cast<XPUType*>(data_c),
+                     m * n,
+                     nullptr,
+                     nullptr);
 }

 }  // namespace phi
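Two details stand out in bmm_xpu_utils.h. First, xblas_fc_batch_api_list is indexed directly by fccal_type, so the array order has to mirror the XPUFCCalcType enum exactly, with the new fp16 entry appended last. Second, the fp16 entry defaults to float accumulation; setting the XPU_PADDLE_FC_FLOAT16 environment variable (any value) swaps in the all-fp16 wrapper, presumably trading accuracy for speed. A self-contained sketch of that env-gated function-pointer override (the wrapper names below are stand-ins for the xblas_fc_batch_wrapper instantiations):

#include <cstdio>
#include <cstdlib>

// Stand-ins for xblas_fc_batch_wrapper<T, FCT, AccT> instantiations.
static void fc_fp16_acc_fp32() { std::puts("fp16 inputs, fp32 accumulate"); }
static void fc_fp16_acc_fp16() { std::puts("fp16 inputs, fp16 accumulate"); }

int main() {
  // Default: keep fp32 accumulation for the fp16 calc type.
  void (*fc)() = &fc_fp16_acc_fp32;

  // Opt-in override, mirroring the diff's getenv check: mere presence of
  // the variable selects the all-fp16 wrapper.
  if (std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) {
    fc = &fc_fp16_acc_fp16;
  }
  fc();
}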
