Skip to content

Commit

Permalink
4.16 update
Browse files Browse the repository at this point in the history
  • Loading branch information
WorgenZhang committed Apr 16, 2022
1 parent e4513bf commit 3f4fe14
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 19 deletions.
33 changes: 15 additions & 18 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_kernel.kps
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ __device__ void update_lr(float& w, float& g2sum, float g, // NOLINT

double add_g2sum = 0;
double ratio = local_learning_rate *
- sqrt(local_initial_g2sum / (local_initial_g2sum + *g2sum));
+ sqrt(local_initial_g2sum / (local_initial_g2sum + g2sum));
double scaled_grad = g / scale;

w += scaled_grad * ratio;

- if (*w < local_min_bound) *w = local_min_bound;
- if (*w > local_max_bound) *w = local_max_bound;
+ if (w < local_min_bound) w = local_min_bound;
+ if (w > local_max_bound) w = local_max_bound;

add_g2sum += scaled_grad * scaled_grad;

Expand All @@ -82,7 +82,7 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g,
double add_g2sum = 0;
double ratio =
local_mf_learning_rate *
- sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + *g2sum));
+ sqrt(local_mf_initial_g2sum / (local_mf_initial_g2sum + g2sum));
for (int i = 0; i < n; ++i) {
double scaled_grad = g[i] / scale;
w[i] += scaled_grad * ratio;
Expand All @@ -95,7 +95,7 @@ __device__ void update_mf(int n, float* w, float& g2sum, const float* g,
g2sum += add_g2sum / n;
}

- __device__ void xpu_rand_uniform(float* ret_val) { *ret_val = 0.1; }
+ __device__ float xpu_rand_uniform() { return 0.1; }

template <typename ValType, typename GradType>
__device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
Expand All @@ -114,32 +114,29 @@ __device__ void update_value(ValType& val, const GradType& grad) { // NOLINT
GM2LM(optimizer_config::mf_create_thresholds, &local_mf_create_thresholds,
sizeof(float));

- (*val).delta_score += local_nonclk_coeff * ((*grad).show - (*grad).clk) +
- local_clk_coeff * (*grad).clk;
+ val.delta_score +=
+ local_nonclk_coeff * (grad.show - grad.clk) + local_clk_coeff * grad.clk;

update_lr(val.lr, val.lr_g2sum, grad.lr_g, grad.show);

- if ((*val).mf_size == 0) {
+ if (val.mf_size == 0) {
if (local_mf_create_thresholds <=
- local_nonclk_coeff * ((*val).show - (*val).clk) +
- local_clk_coeff * (*val).clk) {
- (*val).mf_size = MF_DIM + 1;
- (*val).mf[0] = 0;
+ local_nonclk_coeff * (val.show - val.clk) + local_clk_coeff * val.clk) {
+ val.mf_size = MF_DIM + 1;
+ val.mf[0] = 0;

- float ret_val;
- xpu_rand_uniform(&ret_val);
for (int i = 0; i < MF_DIM; ++i) {
- (*val).mf[i + 1] = (ret_val)*local_mf_initial_range;
+ val.mf[i + 1] = (xpu_rand_uniform()) * local_mf_initial_range;
}
}
} else {
- update_mf(MF_DIM, &(*val).mf[1], &(*val).mf[0], (*grad).mf_g, (*grad).show);
+ update_mf(MF_DIM, &val.mf[1], val.mf[0], grad.mf_g, grad.show);
}
}

template <typename KeyType, typename ValType, typename Table>
- __global__ void insert_kernel(Table* table, const KeyType* keys,
- const ValType* vals, long long len) {
+ __global__ void insert_kernel(Table* table, const KeyType* const keys,
+ const ValType* const vals, long long len) {
int cid = core_id();
int ncores = core_num();
if (cid >= ncores) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ void HeterCommKernel::calc_shard_index(KeyType* d_keys, long long len,
T* shard_index, int total_devs,
const StreamType& stream) {
calc_shard_index_kernel<KeyType, T><<<4, 64, stream>>>(
- d_keys, len, shard_index, total_xpu);
+ d_keys, len, shard_index, total_devs);
}

template <typename KeyType, typename T, typename StreamType>
Expand Down

1 comment on commit 3f4fe14

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 3f4fe14 Apr 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #40991 Commit ID: 3f4fe14 contains failed CI.

🔹 Failed: PR-CI-APPROVAL

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Static-Check

Unknown Failed
Unknown Failed

Please sign in to comment.