fix: make the se attn v2 descriptor energy conservative. #2905

Merged · 15 commits · Oct 11, 2023
28 changes: 24 additions & 4 deletions deepmd/descriptor/se_atten.py
@@ -564,6 +564,8 @@ def build(
self.filter_precision,
)
self.negative_mask = -(2 << 32) * (1.0 - self.nmask)
# hard coding the magnitude of attention weight shift
self.smth_attn_w_shift = 20.0
# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
tf.summary.histogram("rij", self.rij)
@@ -599,7 +601,9 @@ def build(
)
self.recovered_r = (
tf.reshape(
tf.slice(tf.reshape(self.descrpt, [-1, 4]), [0, 0], [-1, 1]),
tf.slice(
tf.reshape(self.descrpt_reshape, [-1, 4]), [0, 0], [-1, 1]
),
[-1, natoms[0], self.sel_all_a[0]],
)
* self.std_looked_up
@@ -865,10 +869,26 @@ def _scaled_dot_attn(
save_weights=True,
):
attn = tf.matmul(Q / temperature, K, transpose_b=True)
attn *= self.nmask
attn += self.negative_mask
if self.smooth:
# (nb x nloc) x nsel
nsel = self.sel_all_a[0]
attn = (attn + self.smth_attn_w_shift) * tf.reshape(
self.recovered_switch, [-1, 1, nsel]
) * tf.reshape(
self.recovered_switch, [-1, nsel, 1]
) - self.smth_attn_w_shift
else:
attn *= self.nmask
attn += self.negative_mask
attn = tf.nn.softmax(attn, axis=-1)
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if self.smooth:
attn = (
attn
* tf.reshape(self.recovered_switch, [-1, 1, nsel])
* tf.reshape(self.recovered_switch, [-1, nsel, 1])
)
else:
attn *= tf.reshape(self.nmask, [-1, attn.shape[-1], 1])
if save_weights:
self.attn_weight[layer] = attn[0] # atom 0
if dotr:
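For intuition, here is a minimal NumPy sketch of the smooth masking used above (the function name, shapes, and the 20.0 shift are illustrative; `sw` stands in for `recovered_switch`, a per-neighbor switch that decays smoothly to zero at the cutoff). Switched-off pairs are pushed to a large negative logit continuously, and the switch is re-applied after the softmax, so every attention weight vanishes smoothly instead of being hard-masked:

```python
import numpy as np

def smooth_masked_softmax(attn, sw, shift=20.0):
    """attn: (nframes * nloc, nsel, nsel) raw logits; sw: (nframes * nloc, nsel) smooth switch."""
    sw_k = sw[:, None, :]  # applied along the key axis
    sw_q = sw[:, :, None]  # applied along the query axis
    # Shift-and-rescale: as sw -> 0 the logit goes to -shift continuously,
    # instead of jumping to the hard -(2 << 32) mask of the non-smooth branch.
    logits = (attn + shift) * sw_k * sw_q - shift
    w = np.exp(logits - logits.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    # Re-apply the switch so each weight itself reaches exactly zero at the cutoff.
    return w * sw_k * sw_q

rng = np.random.default_rng(0)
attn = rng.normal(size=(1, 4, 4))
sw = np.array([[1.0, 1.0, 0.3, 0.0]])  # last neighbor sits exactly at the cutoff
print(smooth_masked_softmax(attn, sw)[0])  # its row and column are exactly zero
```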
3 changes: 2 additions & 1 deletion deepmd/op/_tabulate_grad.py
@@ -55,7 +55,7 @@ def _tabulate_fusion_se_atten_grad_cc(op, dy):
op.outputs[0],
is_sorted=op.get_attr("is_sorted"),
)
return [None, None, dy_dx, dy_df, None]
return [None, None, dy_dx, dy_df, dy_dtwo]


@ops.RegisterGradient("TabulateFusionSeAttenGrad")
@@ -68,6 +68,7 @@ def _tabulate_fusion_se_atten_grad_grad_cc(op, dy, dy_, dy_dtwo):
op.inputs[4],
dy,
dy_,
dy_dtwo,
op.inputs[6],
is_sorted=op.get_attr("is_sorted"),
)
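Sketch of why the gradient for `two_embed` can no longer be `None` (the symbols below are ad hoc, not the code's names). Assuming the fused op computes, per atom $i$, output row $c$ and last-layer channel $k$,

$$
o_{ick} \;=\; \sum_{j} R_{ijc}\, G_{ijk}\,\bigl(1 + t_{ijk}\bigr),
$$

with $R$ the four-column environment matrix (`em`), $G$ the tabulated embedding-net value, and $t$ the attention-side embedding (`two_embed`), the chain rule gives

$$
\frac{\partial L}{\partial t_{ijk}} \;=\; G_{ijk} \sum_{c} R_{ijc}\,\frac{\partial L}{\partial o_{ick}},
$$

which is exactly the `resold * dot(ll, rr)` term the CPU kernel further down accumulates into `dy_dtwo`. The previous `None` dropped that term.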
4 changes: 4 additions & 0 deletions source/lib/include/tabulate.h
@@ -18,6 +18,7 @@ void tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -38,6 +39,7 @@ void tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -125,6 +127,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -145,6 +148,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
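For orientation, the layout the new pointers are expected to carry, inferred from the indexing in the kernels below (a sketch with illustrative sizes, not the documented API): `dy_dtwo` and `dz_dy_dtwo` mirror `two_embed`, one value per local atom, neighbor, and last-layer channel.

```python
import numpy as np

# Illustrative shapes only; the C++ side passes flat FPTYPE* buffers with this
# logical (nloc, nnei, last_layer_size) layout, indexed as
# i * nnei * last_layer_size + j * last_layer_size + k.
nloc, nnei, last_layer_size = 192, 138, 128
two_embed = np.zeros((nloc, nnei, last_layer_size))   # attention-side embedding t
dy_dtwo = np.zeros_like(two_embed)      # dL/dt, produced by the *_grad_* kernels
dz_dy_dtwo = np.zeros_like(two_embed)   # tangent of t, consumed by *_grad_grad_*
```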
34 changes: 28 additions & 6 deletions source/lib/src/gpu/tabulate.cu
@@ -253,6 +253,7 @@ template <typename FPTYPE, int MTILE, int KTILE>
__global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* em_x,
const FPTYPE* em,
@@ -307,6 +308,7 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
(var[1] +
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;
FPTYPE oldres = res;
FPTYPE t;
if (enable_se_atten) {
t = two_embed[block_idx * nnei * last_layer_size +
@@ -330,6 +332,13 @@ __global__ void tabulate_fusion_se_a_grad_fifth_order_polynomial(
xx) *
xx) *
(enable_se_atten ? res * t + res : res);
if (enable_se_atten) {
// from ii to ii + (nnei - breakpoint)
for (int ii2 = ii; ii2 < ii + nnei - breakpoint; ii2++) {
dy_dtwo[block_idx * nnei * last_layer_size + ii2 * last_layer_size +
jj] = oldres * res;
}
}
}
GpuSyncThreads();
for (int kk = 0; kk < MTILE; kk++) {
@@ -357,6 +366,7 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const FPTYPE lower,
const FPTYPE upper,
const FPTYPE max,
@@ -404,9 +414,15 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
((FPTYPE)4. * var[4] + (FPTYPE)5. * var[5] * xx) * xx) *
xx) *
xx;
FPTYPE two_grad = 0.;
if (enable_se_atten) {
FPTYPE t = two_embed[block_idx * nnei * last_layer_size +
ii * last_layer_size + thread_idx];
// dz_dy_dtwo * res * em
// res above should be used instead of res + res * t below
two_grad = dz_dy_dtwo[block_idx * nnei * last_layer_size +
ii * last_layer_size + thread_idx] *
res;
res += res * t;
res_grad += res_grad * t;
}
@@ -434,8 +450,8 @@ __global__ void tabulate_fusion_se_a_grad_grad_fifth_order_polynomial(
for (int kk = 0; kk < MTILE; kk++) {
int em_index = block_idx * nnei * MTILE + ii * MTILE + kk;
iteratorC[kk * last_layer_size + thread_idx] +=
(nnei - breakpoint) *
(em[em_index] * res_grad * dz_xx + dz_dy_dem[em_index] * res);
(nnei - breakpoint) * (em[em_index] * (res_grad * dz_xx + two_grad) +
dz_dy_dem[em_index] * res);
}
mark_table_idx = table_idx;
if (unloop) {
@@ -764,6 +780,7 @@ void tabulate_fusion_se_a_gpu(FPTYPE* out,
template <typename FPTYPE>
void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -784,9 +801,9 @@ void tabulate_fusion_se_a_grad_gpu(FPTYPE* dy_dem_x,

tabulate_fusion_se_a_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, KK * WARP_SIZE, sizeof(FPTYPE) * MM * last_layer_size>>>(
dy_dem_x, dy_dem, table, em_x, em, two_embed, dy, table_info[0],
table_info[1], table_info[2], table_info[3], table_info[4], nnei,
last_layer_size, is_sorted);
dy_dem_x, dy_dem, dy_dtwo, table, em_x, em, two_embed, dy,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
DPErrcheck(gpuDeviceSynchronize());
}
@@ -800,6 +817,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -812,7 +830,7 @@ void tabulate_fusion_se_a_grad_grad_gpu(FPTYPE* dz_dy,
DPErrcheck(gpuMemset(dz_dy, 0, sizeof(FPTYPE) * nloc * 4 * last_layer_size));
tabulate_fusion_se_a_grad_grad_fifth_order_polynomial<FPTYPE, MM, KK>
<<<nloc, last_layer_size, sizeof(FPTYPE) * MM * last_layer_size>>>(
dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem,
dz_dy, table, em_x, em, two_embed, dz_dy_dem_x, dz_dy_dem, dz_dy_dtwo,
table_info[0], table_info[1], table_info[2], table_info[3],
table_info[4], nnei, last_layer_size, is_sorted);
DPErrcheck(gpuGetLastError());
@@ -990,6 +1008,7 @@ template void tabulate_fusion_se_a_gpu<double>(double* out,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -1002,6 +1021,7 @@ template void tabulate_fusion_se_a_grad_gpu<float>(float* dy_dem_x,
const bool is_sorted);
template void tabulate_fusion_se_a_grad_gpu<double>(double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
@@ -1021,6 +1041,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<float>(
const float* two_embed,
const float* dz_dy_dem_x,
const float* dz_dy_dem,
const float* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -1034,6 +1055,7 @@ template void tabulate_fusion_se_a_grad_grad_gpu<double>(
const double* two_embed,
const double* dz_dy_dem_x,
const double* dz_dy_dem,
const double* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
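A quick way to sanity-check the new `dy_dtwo` path is a finite-difference test against a simplified NumPy model of the fused op (assumptions: a single atom, the table replaced by a precomputed `G`, and the `is_sorted`/`breakpoint` bookkeeping omitted; all names are illustrative). The analytic formula below is the `resold * dot(ll, rr)` term accumulated in the CPU kernel further down.

```python
import numpy as np

rng = np.random.default_rng(0)
nnei, nlast = 5, 8
em = rng.normal(size=(nnei, 4))      # per-neighbor environment rows (R)
G = rng.normal(size=(nnei, nlast))   # tabulated embedding-net values
t = rng.normal(size=(nnei, nlast))   # two-side (attention) embedding

def forward(t):
    # out[c, k] = sum_j em[j, c] * G[j, k] * (1 + t[j, k])
    return np.einsum("jc,jk->ck", em, G * (1.0 + t))

dy = rng.normal(size=(4, nlast))     # upstream gradient dL/d(out)

# analytic gradient: dy_dtwo[j, k] = G[j, k] * <em[j, :], dy[:, k]>
dy_dtwo = G * np.einsum("jc,ck->jk", em, dy)

# central finite differences on each t[j, k]
eps, num = 1e-6, np.zeros_like(t)
for j in range(nnei):
    for k in range(nlast):
        tp, tm = t.copy(), t.copy()
        tp[j, k] += eps
        tm[j, k] -= eps
        num[j, k] = np.sum(dy * (forward(tp) - forward(tm))) / (2 * eps)

print(np.abs(num - dy_dtwo).max())   # agrees to round-off, ~1e-9
```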
52 changes: 42 additions & 10 deletions source/lib/src/tabulate.cc
@@ -158,6 +158,7 @@ void deepmd::tabulate_fusion_se_a_cpu(FPTYPE* out,
template <typename FPTYPE>
void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE* dy_dem,
FPTYPE* dy_dtwo,
const FPTYPE* table,
const FPTYPE* table_info,
const FPTYPE* em_x,
@@ -171,6 +172,9 @@ void deepmd::tabulate_fusion_se_a_grad_cpu(FPTYPE* dy_dem_x,
bool enable_se_atten = two_embed != nullptr;
memset(dy_dem_x, 0, sizeof(FPTYPE) * nloc * nnei);
memset(dy_dem, 0, sizeof(FPTYPE) * nloc * nnei * 4);
if (enable_se_atten) {
memset(dy_dtwo, 0, sizeof(FPTYPE) * nloc * nnei * last_layer_size);
}
FPTYPE const lower = table_info[0];
FPTYPE const upper = table_info[1];
FPTYPE const _max = table_info[2];
@@ -212,25 +216,38 @@
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
FPTYPE g =
(a1 + (2 * a2 + (3 * a3 + (4 * a4 + 5 * a5 * xx) * xx) * xx) * xx);
FPTYPE resold = res;
if (enable_se_atten) {
FPTYPE t = two_embed[ii * nnei * last_layer_size +
jj * last_layer_size + kk];
res = res * t + res;
g += t * g;
}

FPTYPE dotllrr = dot(ll, rr);
if (unloop) {
grad += g * dot(ll, rr) * (nnei - jj);
grad += g * dotllrr * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2] * (nnei - jj);
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3] * (nnei - jj);
if (enable_se_atten) {
// fill from jj to nnei
for (int jj2 = jj; jj2 < nnei; jj2++) {
dy_dtwo[ii * nnei * last_layer_size + jj2 * last_layer_size +
kk] += resold * dotllrr;
}
}
} else {
grad += g * dot(ll, rr);
grad += g * dotllrr;
dy_dem[ii * nnei * 4 + jj * 4 + 0] += res * rr[0];
dy_dem[ii * nnei * 4 + jj * 4 + 1] += res * rr[1];
dy_dem[ii * nnei * 4 + jj * 4 + 2] += res * rr[2];
dy_dem[ii * nnei * 4 + jj * 4 + 3] += res * rr[3];
if (enable_se_atten) {
dy_dtwo[ii * nnei * last_layer_size + jj * last_layer_size + kk] +=
resold * dotllrr;
}
}
}
dy_dem_x[ii * nnei + jj] = grad;
@@ -250,6 +267,7 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
const FPTYPE* two_embed,
const FPTYPE* dz_dy_dem_x,
const FPTYPE* dz_dy_dem,
const FPTYPE* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -300,9 +318,15 @@ void deepmd::tabulate_fusion_se_a_grad_grad_cpu(FPTYPE* dz_dy,
((FPTYPE)3. * a3 + ((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx;
FPTYPE two_grad = 0.;
if (enable_se_atten) {
FPTYPE t = two_embed[ii * nnei * last_layer_size +
jj * last_layer_size + kk];
// dz_dy_dtwo * var * ll
// var above should be used instead of var + var * t below
two_grad = dz_dy_dtwo[ii * nnei * last_layer_size +
jj * last_layer_size + kk] *
var;
var += var * t;
var_grad += var_grad * t;
}
@@ -329,22 +353,26 @@
*/
if (unloop) {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
(nnei - jj) * (var * hh[0] + dz_xx * var_grad * ll[0]);
(nnei - jj) *
(var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0]);
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
(nnei - jj) * (var * hh[1] + dz_xx * var_grad * ll[1]);
(nnei - jj) *
(var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1]);
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
(nnei - jj) * (var * hh[2] + dz_xx * var_grad * ll[2]);
(nnei - jj) *
(var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2]);
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
(nnei - jj) * (var * hh[3] + dz_xx * var_grad * ll[3]);
(nnei - jj) *
(var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3]);
} else {
dz_dy[ii * last_layer_size * 4 + 0 * last_layer_size + kk] +=
var * hh[0] + dz_xx * var_grad * ll[0];
var * hh[0] + (dz_xx * var_grad + two_grad) * ll[0];
dz_dy[ii * last_layer_size * 4 + 1 * last_layer_size + kk] +=
var * hh[1] + dz_xx * var_grad * ll[1];
var * hh[1] + (dz_xx * var_grad + two_grad) * ll[1];
dz_dy[ii * last_layer_size * 4 + 2 * last_layer_size + kk] +=
var * hh[2] + dz_xx * var_grad * ll[2];
var * hh[2] + (dz_xx * var_grad + two_grad) * ll[2];
dz_dy[ii * last_layer_size * 4 + 3 * last_layer_size + kk] +=
var * hh[3] + dz_xx * var_grad * ll[3];
var * hh[3] + (dz_xx * var_grad + two_grad) * ll[3];
}
}
if (unloop) {
@@ -660,6 +688,7 @@ template void deepmd::tabulate_fusion_se_a_cpu<double>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
float* dy_dem_x,
float* dy_dem,
float* dy_dtwo,
const float* table,
const float* table_info,
const float* em_x,
@@ -673,6 +702,7 @@ template void deepmd::tabulate_fusion_se_a_grad_cpu<float>(
template void deepmd::tabulate_fusion_se_a_grad_cpu<double>(
double* dy_dem_x,
double* dy_dem,
double* dy_dtwo,
const double* table,
const double* table_info,
const double* em_x,
@@ -692,6 +722,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<float>(
const float* two_embed,
const float* dz_dy_dem_x,
const float* dz_dy_dem,
const float* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
@@ -705,6 +736,7 @@ template void deepmd::tabulate_fusion_se_a_grad_grad_cpu<double>(
const double* two_embed,
const double* dz_dy_dem_x,
const double* dz_dy_dem,
const double* dz_dy_dtwo,
const int nloc,
const int nnei,
const int last_layer_size,
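To close the loop, a hedged note on the grad-grad change (same ad-hoc symbols as in the note above): the `*_grad_grad_*` kernels assemble the forward-mode tangent of the output, which now also has to carry the perturbation of `two_embed`,

$$
\dot o_{ick} \;=\; \sum_j \Bigl[\, \dot R_{ijc}\, G_{ijk}\,(1+t_{ijk}) \;+\; R_{ijc}\, G'_{ijk}\,(1+t_{ijk})\,\dot x_{ij} \;+\; R_{ijc}\, G_{ijk}\,\dot t_{ijk} \,\Bigr],
$$

where $\dot R$, $\dot x$, and $\dot t$ correspond to `dz_dy_dem`, `dz_dy_dem_x`, and the new `dz_dy_dtwo`. The last term is the `two_grad` contribution added above; it multiplies the embedding value before the $(1+t)$ scaling, which is what the in-code comments ("var above should be used instead of var + var * t below") are pointing at.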