Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix se_e3 tabulate op #2552

Merged
merged 2 commits into from
May 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 1 addition & 27 deletions source/lib/src/cuda/tabulate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -337,20 +337,11 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
int breakpoint = nnei_j - 1;
bool unloop = false;
FPTYPE var[6];
int mark_table_idx = -1;
for (int jj = 0; jj < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (xx == ago) {
unloop = true;
breakpoint = jj;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
if (table_idx != mark_table_idx) {
Expand All @@ -363,9 +354,8 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;

sum += (nnei_j - breakpoint) * tmp * res;
sum += tmp * res;
mark_table_idx = table_idx;
if (unloop) break;
}
}
out[block_idx * last_layer_size + thread_idx] = sum;
Expand Down Expand Up @@ -399,16 +389,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
__syncthreads();

for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = warp_idx; jj < nnei_j; jj += KTILE) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE sum = (FPTYPE)0.;
Expand Down Expand Up @@ -438,7 +421,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum;
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub;
}
if (unloop) break;
}
}
}
Expand All @@ -464,10 +446,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago = __shfl_sync(
0xffffffff,
em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
int mark_table_idx = -1;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
Expand All @@ -476,9 +454,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
Expand All @@ -498,7 +473,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(

sum += (tmp * res_grad * dz_xx + dz_em * res);
mark_table_idx = table_idx;
if (unloop) break;
}
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
Expand Down
21 changes: 1 addition & 20 deletions source/lib/src/rocm/tabulate.hip.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,9 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(

FPTYPE sum = (FPTYPE)0.;
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
int breakpoint = nnei_j - 1;
bool unloop = false;
for (int jj = 0; jj < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (xx == ago) {
unloop = true;
breakpoint = jj;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE var[6];
Expand All @@ -338,8 +330,7 @@ __global__ void tabulate_fusion_se_t_fifth_order_polynomial(
(var[2] + (var[3] + (var[4] + var[5] * xx) * xx) * xx) * xx) *
xx;

sum += (nnei_j - breakpoint) * tmp * res;
if (unloop) break;
sum += tmp * res;
}
}
out[block_idx * last_layer_size + thread_idx] = sum;
Expand Down Expand Up @@ -375,13 +366,9 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = warp_idx; jj < nnei_j; jj += KTILE) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
FPTYPE sum = (FPTYPE)0.;
Expand Down Expand Up @@ -417,7 +404,6 @@ __global__ void tabulate_fusion_se_t_grad_fifth_order_polynomial(
dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = sum;
dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj] = Csub;
}
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -445,17 +431,13 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
for (int ii = 0; ii < nnei_i; ii++) {
FPTYPE ago =
__shfl(em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + nnei_j - 1], 0);
bool unloop = false;
for (int jj = 0; ii < nnei_j; jj++) {
FPTYPE xx = em_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE tmp = xx;
FPTYPE dz_xx =
dz_dy_dem_x[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE dz_em = dz_dy_dem[block_idx * nnei_i * nnei_j + ii * nnei_j + jj];
FPTYPE var[6];
if (ago == xx) {
unloop = true;
}

int table_idx = 0;
locate_xx_se_t(xx, table_idx, lower, upper, -max, max, stride0, stride1);
Expand All @@ -478,7 +460,6 @@ __global__ void tabulate_fusion_se_t_grad_grad_fifth_order_polynomial(
xx;

sum += (tmp * res_grad * dz_xx + dz_em * res);
if (unloop) break;
}
}
dz_dy[block_idx * last_layer_size + thread_idx] = sum;
Expand Down
50 changes: 9 additions & 41 deletions source/lib/src/tabulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,10 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out,
#pragma omp parallel for
for (int ii = 0; ii < nloc; ii++) {
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
// unloop not work as em_x is not sorted
for (int kk = 0; kk < nnei_j; kk++) {
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE ll = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -342,13 +338,8 @@ void deepmd::tabulate_fusion_se_t_cpu(FPTYPE* out,
FPTYPE a5 = table[table_idx * last_layer_size * 6 + 6 * mm + 5];
FPTYPE var =
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;
if (unloop) {
out[ii * last_layer_size + mm] += (nnei_j - kk) * var * ll;
} else {
out[ii * last_layer_size + mm] += var * ll;
}
out[ii * last_layer_size + mm] += var * ll;
}
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -380,15 +371,10 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE ll = (FPTYPE)0.;
FPTYPE rr = (FPTYPE)0.;
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
for (int kk = 0; kk < nnei_j; kk++) {
// construct the dy/dx
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
ll = xx;
if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -404,27 +390,15 @@ void deepmd::tabulate_fusion_se_t_grad_cpu(FPTYPE* dy_dem_x,
FPTYPE res =
a0 + (a1 + (a2 + (a3 + (a4 + a5 * xx) * xx) * xx) * xx) * xx;

if (unloop) {
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr * (nnei_j - kk);
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] +=
res * rr * (nnei_j - kk);
} else {
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr;
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr;
}
grad += (a1 + ((FPTYPE)2. * a2 +
((FPTYPE)3. * a3 +
((FPTYPE)4. * a4 + (FPTYPE)5. * a5 * xx) * xx) *
xx) *
xx) *
ll * rr;
dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk] += res * rr;
}
dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk] = grad;
if (unloop) break;
}
}
}
Expand Down Expand Up @@ -453,17 +427,12 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy,
#pragma omp parallel for
for (int ii = 0; ii < nloc; ii++) {
for (int jj = 0; jj < nnei_i; jj++) {
FPTYPE ago = em_x[ii * nnei_i * nnei_j + jj * nnei_j + nnei_j - 1];
bool unloop = false;
for (int kk = 0; kk < nnei_j; kk++) {
FPTYPE xx = em_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE tmp = xx;
FPTYPE dz_em = dz_dy_dem[ii * nnei_i * nnei_j + jj * nnei_j + kk];
FPTYPE dz_xx = dz_dy_dem_x[ii * nnei_i * nnei_j + jj * nnei_j + kk];

if (ago == xx) {
unloop = true;
}
int table_idx = 0;
locate_xx_se_t(lower, upper, -_max, _max, stride0, stride1, xx,
table_idx);
Expand All @@ -486,7 +455,6 @@ void deepmd::tabulate_fusion_se_t_grad_grad_cpu(FPTYPE* dz_dy,
dz_dy[ii * last_layer_size + mm] +=
var * dz_em + dz_xx * var_grad * tmp;
}
if (unloop) break;
}
}
}
Expand Down
66 changes: 44 additions & 22 deletions source/tests/test_model_compression_se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,28 @@ def _init_models():
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()


def tearDownModule():
    """Clean up after the module by passing each listed path to ``_file_delete``.

    The paths are processed strictly in the order given below; entries under
    ``model-compression/`` come before ``model-compression`` itself.
    """
    leftovers = (
        # Model files whose locations are held in module-level names.
        INPUT,
        FROZEN_MODEL,
        COMPRESSED_MODEL,
        # JSON outputs.
        "out.json",
        "compress.json",
        # Checkpoint files in the working directory.
        "checkpoint",
        "model.ckpt.meta",
        "model.ckpt.index",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt-100.meta",
        "model.ckpt-100.index",
        "model.ckpt-100.data-00000-of-00001",
        # Contents of model-compression, then model-compression itself.
        "model-compression/checkpoint",
        "model-compression/model.ckpt.meta",
        "model-compression/model.ckpt.index",
        "model-compression/model.ckpt.data-00000-of-00001",
        "model-compression",
        # Remaining by-products.
        "input_v2_compat.json",
        "lcurve.out",
    )
    for path in leftovers:
        _file_delete(path)


class TestDeepPotAPBC(unittest.TestCase):
@classmethod
def setUpClass(self):
Expand Down Expand Up @@ -444,28 +466,6 @@ def setUpClass(self):
self.atype = [0, 1, 1, 0, 1, 1]
self.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

@classmethod
def tearDownClass(cls):
    """Pass every listed path to ``_file_delete``, in the order given."""
    leftovers = (
        # Model files whose locations are held in module-level names.
        INPUT,
        FROZEN_MODEL,
        COMPRESSED_MODEL,
        # JSON outputs.
        "out.json",
        "compress.json",
        # Checkpoint files in the working directory.
        "checkpoint",
        "model.ckpt.meta",
        "model.ckpt.index",
        "model.ckpt.data-00000-of-00001",
        "model.ckpt-100.meta",
        "model.ckpt-100.index",
        "model.ckpt-100.data-00000-of-00001",
        # Contents of model-compression, then model-compression itself.
        "model-compression/checkpoint",
        "model-compression/model.ckpt.meta",
        "model-compression/model.ckpt.index",
        "model-compression/model.ckpt.data-00000-of-00001",
        "model-compression",
        # Remaining by-products.
        "input_v2_compat.json",
        "lcurve.out",
    )
    for path in leftovers:
        _file_delete(path)

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places)
Expand Down Expand Up @@ -558,3 +558,25 @@ def test_2frame_atm(self):
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)


class TestDeepPotAPBC2(TestDeepPotAPBC):
    """Variant of ``TestDeepPotAPBC`` with a different class-level fixture.

    Only ``setUpClass`` is overridden: it supplies nine coordinate values,
    three type-0 atoms and the same nine-element box; everything else is
    inherited from the parent class.
    """

    @classmethod
    def setUpClass(cls):
        """Load both models and install the fixture data on the class."""
        cls.dp_original = DeepPot(FROZEN_MODEL)
        cls.dp_compressed = DeepPot(COMPRESSED_MODEL)

        # Flat array of nine values, laid out here as one row per atom
        # (presumably x, y, z for each -- confirm against the DeepPot API).
        flat_coords = [
            0.0, 0.0, 0.0,
            2.0, 0.0, 0.0,
            0.0, 2.0, 0.0,
        ]  # fmt: skip
        cls.coords = np.array(flat_coords)

        cls.atype = [0, 0, 0]

        # Nine box values with 13.0 on what would be the diagonal of a 3x3 layout.
        flat_box = [
            13.0, 0.0, 0.0,
            0.0, 13.0, 0.0,
            0.0, 0.0, 13.0,
        ]  # fmt: skip
        cls.box = np.array(flat_box)