Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mfence for XPU2 KP #44258

Merged
merged 4 commits into from
Jul 19, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src,
int num) {
if (num > 0) {
mfence_local();
LM2GM(src, dst, num * sizeof(T));
}
}
Expand Down Expand Up @@ -370,6 +371,7 @@ __device__ __inline__ void ReadData(Ty* dst,
__local__ Tx in_temp[1];
// Each branch is added for better performance
if (NX == 1 && NY == 1) { // for NX == 1 and NY == 1
mfence_local();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个mfence不需要

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

if (IsBoundary) {
if (left_size_nx > 0) {
GM2LM(src + thread_offset, in_temp, sizeof(Tx));
Expand All @@ -387,6 +389,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
dst[idy] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -398,6 +401,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
dst[idx] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -412,6 +416,7 @@ __device__ __inline__ void ReadData(Ty* dst,
}
}
int fix = thread_offset + idx * stride_nx + idy * stride_ny;
mfence_local();
GM2LM(src + fix, in_temp, sizeof(Tx));
dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
}
Expand Down Expand Up @@ -484,14 +489,13 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
const T _global_ptr_* src,
int num) {
mfence_local();
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

402行:in_temp在for循环中,而且从代码看NX有可能>1,那么在403行的scalar read之后,下一次循环又会发生GM2LM,所以是否应该在402行之前加一个mfence?

Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

391、416、494、515、571、627、720行的in_temp看起来也是同样的情况,最好用模拟器的mfence检查工具跑一下,这样最保险。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

}
} else { // core_num() * NX < num
Expand All @@ -505,13 +509,12 @@ __device__ __inline__ void ReadData(T* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * read_lens < num
Expand Down Expand Up @@ -607,8 +610,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
int stride_ny) {
uint32_t thread_offset = block_offset + core_id();
uint32_t index_src = 0;
__local__ T in_temp[1];

mfence_local();
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
Expand All @@ -621,8 +623,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
index_src = config(index_output);
GM2LM(src + index_src, in_temp, sizeof(T));
dst[nx + ny * NX] = in_temp[0];
GM2LM(src + index_src, dst + nx + ny * Nx, sizeof(T));
}
}
}
Expand Down Expand Up @@ -698,8 +699,10 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[ny] = static_cast<Ty>(func(in_temp[0]));

thread_offset += stride_ny;
}
} else {
Expand All @@ -714,6 +717,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny;
Expand Down Expand Up @@ -749,37 +753,34 @@ __device__ void WriteData(T _global_ptr_* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
}
}

template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * NX < num
mfence_local();
LM2GM(src, dst + thread_offset, NX * sizeof(T));
}
}
Expand Down Expand Up @@ -831,10 +832,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
if (IsBoundary) {
if (left_size_nx > 0) {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else if (NX == 1) {
Expand All @@ -847,6 +850,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idy]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
}
} else if (NY == 1) { // for NY == 1 and NX != 1
Expand All @@ -859,6 +863,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idx]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
}
} else { // for NX != 1 and NY != 1
Expand All @@ -877,6 +882,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
}
in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
mfence_local();
LM2GM(in_temp,
dst + thread_offset + idx * stride_nx + idy * stride_ny,
sizeof(Ty));
Expand Down Expand Up @@ -1029,6 +1035,7 @@ __device__ __inline__ void ReadDataBc1NMn(
for (int i = 0; i < last_col; i++) {
dst[i] = in_temp;
}
mfence_local();
GM2LM(src + index_base + 1, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down Expand Up @@ -1083,6 +1090,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
} else {
next_part_index = 0;
}
mfence_local();
GM2LM(src + next_part_index, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down Expand Up @@ -1169,6 +1177,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
if (index_src >= index_base && index_src < index_base + cache_size) {
in_temp = src_temp[index_src - index_base];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在1040行对in_temp进行了scalar read;请注意1042行,这里在GM2LM之前需要加mfence。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

} else {
mfence_local();
GM2LM(src + index_src, &in_temp, sizeof(T));
}
dst[nx] = in_temp;
Expand Down