Merge pull request PaddlePaddle#28 from xymyeah/cvrq_nan_fix
fix cvrq check nan
xymyeah authored Nov 28, 2023
2 parents 1034434 + 810b0bf commit 5cdc82d
Showing 1 changed file with 204 additions and 18 deletions.
222 changes: 204 additions & 18 deletions paddle/fluid/framework/fleet/box_wrapper_kernel.kps
@@ -30,7 +30,7 @@ limitations under the License. */
#include "xpu/kernel/xtdk_simd.h"

#ifdef TRACE_PROFILE
// #include "xpu/kernel/xtdk_io.h"
#include "xpu/kernel/xtdk_io.h"
#include <fstream>

// The producer side.
@@ -70,6 +70,15 @@ struct ExpandPushGetOp {
}
};

struct ExpandPushEmdGetOp {
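// Reads one expand-embedding value from a source row laid out as
// [hidden | expand_dim] floats.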
__device__ float get(float* expand, const int& row,
const int& expand_id,
const int& hidden,
const int& expand_dim) const {
return expand[row * (hidden + expand_dim) + hidden + expand_id];
}
};

template <typename T>
__device__ void set_byfloat(float* dest, const T& val) {
(*reinterpret_cast<T*>(dest)) = val;
@@ -340,6 +349,152 @@ __global__ void PullCopyNNCross(const TEmbedxOp* op,
}
}

template <typename TEmbedxOp>
__global__ void PullCopyNNCrossWithEmb(const TEmbedxOp* op,
const float scale,
const boxps::FeaturePullOffset* info,
int* total_dims,
unsigned long long* dst_vals,
const int* key2slot,
float* total_values,
const uint32_t* restore_idx,
const int total_length,
const int max_cols_num,
const int hidden_size,
const int expand_embed_dim,
const int pull_float_num,
const int skip_offset,
const int cvm_offset,
const int slot_num) {
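// Pull-copy variant that materializes the expand embedding as well: for each
// row it writes hidden_size floats of base embedding plus a combined row of
// hidden_size + expand_embed_dim floats.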
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
return;
}
int thread_id = cluster_id() * ncores + cid;
int nthreads = cluster_num() * ncores;

const int buf_length = 5;
int per_thread_len = roundup_div(total_length, nthreads);
int per_thread_loop_count = roundup_div(per_thread_len, buf_length);
int per_thread_per_loop_len = roundup_div(per_thread_len, per_thread_loop_count);

__local__ float lm_total_values[buf_length * pull_float_num];
__local__ float lm_dst_vals[buf_length * hidden_size];
__local__ float lm_dst_expand_vals[buf_length * (hidden_size + expand_embed_dim)];
__local__ int lm_key2slot[buf_length];
__local__ int lm_total_dims[buf_length];
__local__ uint32_t lm_restore_idx[buf_length];
__local__ boxps::FeaturePullOffset lm_info[1];
__local__ TEmbedxOp lm_op[1];

const int max_slot_num = 1000;
int sm_slot_len = min(max_slot_num, slot_num);
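// dst_vals holds 2 * slot_num pointers: the first slot_num entries are the per-slot
// embed outputs, the next slot_num are the per-slot expand outputs. Stage both tables
// into shared memory so every core can look them up by slot id.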
__shared__ uint64_t sm_dst_vals_ptr[max_slot_num];
__shared__ uint64_t sm_dst_expand_vals_ptr[max_slot_num];
for (int i = cid; i < sm_slot_len; i += ncores) {
GM2SM(dst_vals + i, sm_dst_vals_ptr + i, sizeof(uint64_t));
GM2SM(dst_vals + slot_num + i, sm_dst_expand_vals_ptr + i, sizeof(uint64_t));
}
mfence();
xpu_sync_all();

__local__ uint64_t lm_dst_vals_ptr[1];
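// Cache the first non-null slot pointer; both the embed rows and the expand rows
// are written back through this single base pointer below.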
for (int i = 0; i < slot_num; i++) {
if (sm_dst_vals_ptr[i] != 0) {
lm_dst_vals_ptr[0] = sm_dst_vals_ptr[i];
break;
}
}

GM2LM(info, lm_info, sizeof(boxps::FeaturePullOffset));
GM2LM(op, lm_op, sizeof(TEmbedxOp));
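// Walk this thread's rows in chunks of at most buf_length, staged through local memory.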
for (int i = thread_id; i < per_thread_loop_count * nthreads; i += nthreads) {
int gm_offset = i * per_thread_per_loop_len;
if (gm_offset >= total_length) {
return;
}

int len = min(per_thread_per_loop_len, total_length - gm_offset);
if(restore_idx != nullptr) {
GM2LM(restore_idx + gm_offset, lm_restore_idx, len * sizeof(uint32_t));
}
int pos = (restore_idx != nullptr) ? lm_restore_idx[0] : gm_offset;
GM2LM(total_values + pos * pull_float_num, lm_total_values, len * pull_float_num * sizeof(float));
GM2LM(total_dims + gm_offset, lm_total_dims, len * sizeof(int));
GM2LM(key2slot + gm_offset, lm_key2slot, len * sizeof(int));

for (int j = 0; j < len; j++) {
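// Per-row copy: CVM fields first, then the embedx section, then the expand section.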
// mfence();
// cvm offset
for (int k = 0; k < cvm_offset; ++k) {
//TODO:consider xpu_value[slot_id]==nullptr?
if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
lm_dst_vals[j * hidden_size + k] = lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
}
if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = lm_total_values[j * pull_float_num + lm_info[0].show + skip_offset + k];
}
}

// embedx
// embedx flags + expand flags && *(keys[x] + y) != 0 && *(keys[x] + y)
int embedx_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].embedx_size]));
// int embedx_size = 0;
// TODO: expand_size = expand_embed_dim?
int expand_size = *((int *)&(lm_total_values[j * pull_float_num + lm_info[0].expand_size]));
lm_total_dims[j] = static_cast<int>(embedx_size > 0) | static_cast<int>((expand_size > 0) << 1);

if (sm_dst_vals_ptr[lm_key2slot[j]] != 0) {
for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
lm_op[0].copy(lm_dst_vals + j * hidden_size + k,
lm_total_values + j * pull_float_num + lm_info[0].embedx,
k - cvm_offset,
scale);
}

for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
lm_dst_vals[j * hidden_size + k] = 0;
}
}

if (sm_dst_expand_vals_ptr[lm_key2slot[j]] != 0) {
for (int k = cvm_offset; k < cvm_offset + embedx_size; ++k) {
lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
lm_total_values + j * pull_float_num + lm_info[0].embedx,
k - cvm_offset,
scale);
}

for (int k = cvm_offset + embedx_size; k < hidden_size; ++k) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
}
}

// expand
if (sm_dst_expand_vals_ptr[lm_key2slot[j]] == 0) {
continue;
}

for (int k = hidden_size; k < hidden_size + expand_size; ++k) {
lm_op[0].copy(lm_dst_expand_vals + j * (hidden_size + expand_embed_dim) + k,
lm_total_values + j * pull_float_num + lm_info[0].expand,
k - hidden_size,
scale);
}
for (int k = hidden_size + expand_size; k < max_cols_num; ++k) {
lm_dst_expand_vals[j * (hidden_size + expand_embed_dim) + k] = 0;
}
}
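// Flush this chunk's dims flags, embed rows, and expand rows back to global memory.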
mfence();

LM2GM(lm_total_dims, total_dims + gm_offset, len * sizeof(int));
LM2GM(lm_dst_vals, ((__global_ptr__ float*)lm_dst_vals_ptr[0] + gm_offset * hidden_size), len * hidden_size * sizeof(float));
LM2GM(lm_dst_expand_vals, ((__global_ptr__ float*)lm_dst_vals_ptr[0] + total_length * hidden_size + gm_offset * (hidden_size + expand_embed_dim)), len * (hidden_size + expand_embed_dim) * sizeof(float));
mfence();
}
}

template <typename TEmbedxOp>
inline void FeaturePullCopyNNCross(
const paddle::platform::Place& place,
@@ -405,9 +560,22 @@ inline void FeaturePullCopyNNCross(
cvm_offset,
slot_num);
} else {
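// Pull path that also copies the expand embedding alongside the base embedding.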
// PullCopyNNCrossWithEmb
// TODO:
CHECK(false) << "PullCopyNNCrossWithEmb not implement";
PullCopyNNCrossWithEmb<TEmbedxOp><<<8, 64, stream>>>(d_op,
scale,
info,
total_dims,
reinterpret_cast<unsigned long long*>(d_xpu_values),
key2slot,
total_values_xpu,
xpu_restore_idx,
total_length,
(hidden_size + expand_embed_dim),
hidden_size,
expand_embed_dim,
pull_float_num,
skip_offset,
cvm_offset,
slot_num);
}
xpu_free(d_xpu_values);
xpu_wait(stream);
@@ -816,21 +984,18 @@ inline void FeaturePushCopyNNCross(
auto ctx_xpu = static_cast<platform::XPUDeviceContext*>(dev_ctx)->x_context();
auto stream = ctx_xpu->xpu_stream;

auto d_op_tmp = memory::Alloc(place, sizeof(TExpandPushGetOp));
TExpandPushGetOp* d_op = reinterpret_cast<TExpandPushGetOp*>(d_op_tmp->ptr());
memory::Copy(place,
d_op,
platform::CPUPlace(),
op,
sizeof(TExpandPushGetOp));

#ifdef TRACE_PROFILE
TRACE_SCOPE_START("PushCopyNNCross", xpu_wait(stream));
#endif
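// When expand_only, gradients are read with ExpandPushGetOp; otherwise
// ExpandPushEmdGetOp reads the expand section out of a combined [hidden | expand] row.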
if (expand_only) {
// TODO:
// if (d_sort_idx != nullptr){
// }
ExpandPushGetOp op;
auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushGetOp));
ExpandPushGetOp* d_op = reinterpret_cast<ExpandPushGetOp*>(d_op_tmp->ptr());
memory::Copy(place,
d_op,
platform::CPUPlace(),
&op,
sizeof(ExpandPushGetOp));
PushCopyNNCross<ExpandPushGetOp><<<8, 64, stream>>>(d_op,
info,
reinterpret_cast<unsigned long long*>(gm_src),//src
@@ -848,9 +1013,30 @@
skip_offset,
bs);
} else {
// PullCopyNNCrossWithEmb
// TODO:
CHECK(false) << "PullCopyNNCrossWithEmb not implement";
ExpandPushEmdGetOp op;
auto d_op_tmp = memory::Alloc(place, sizeof(ExpandPushEmdGetOp));
ExpandPushEmdGetOp* d_op = reinterpret_cast<ExpandPushEmdGetOp*>(d_op_tmp->ptr());
memory::Copy(place,
d_op,
platform::CPUPlace(),
&op,
sizeof(ExpandPushEmdGetOp));
PushCopyNNCross<ExpandPushEmdGetOp><<<8, 64, stream>>>(d_op,
info,
reinterpret_cast<unsigned long long*>(gm_src),//src
total_dims,
key2slot,
slot_vector,
slot_inner_offset,
push_grad_values,//dst
total_length,
hidden_size,
expand_embed_dim,
slot_num,
push_float_num,
cvm_offset,
skip_offset,
bs);
}
#ifdef TRACE_PROFILE
xpu_wait(stream);
