Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(cc): fix message passing when nloc is 0 #4021

Merged
merged 20 commits into from
Jul 26, 2024
Merged
43 changes: 6 additions & 37 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
row_tensors.push_back(row_tensor);
}

torch::Tensor tensor = torch::cat(row_tensors, 0).unsqueeze(0);
torch::Tensor tensor;
if (row_tensors.size() > 0) {
tensor = torch::cat(row_tensors, 0).unsqueeze(0);
} else {
tensor = torch::empty({1, 0, 0}, torch::kInt32);

Check warning on line 45 in source/api_cc/src/DeepPotPT.cc

View check run for this annotation

Codecov / codecov/patch

source/api_cc/src/DeepPotPT.cc#L45

Added line #L45 was not covered by tests
}
return tensor;
}
DeepPotPT::DeepPotPT() : inited(false) {}
Expand Down Expand Up @@ -152,25 +157,6 @@
nghost, ntypes, 1, daparam, nall, aparam_nall);
int nloc = nall_real - nghost_real;
int nframes = 1;
// TODO: dpa2 model may need a fake communication op to deal with nloc == 0.
// this should be fixed after wrapping comm op as a pure c++ implementation.
if (nloc == 0) {
// no backward map needed
ener.resize(nframes);
// dforce of size nall * 3
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
fill(force.begin(), force.end(), (VALUETYPE)0.0);
// dvirial of size 9
virial.resize(static_cast<size_t>(nframes) * 9);
fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
// datom_energy_ of size nall
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
// datom_virial_ of size nall * 9
atom_virial.resize(static_cast<size_t>(nframes) * fwd_map.size() * 9);
fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
return;
}
std::vector<VALUETYPE> coord_wrapped = dcoord;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, nall_real, 3}, options)
Expand Down Expand Up @@ -345,23 +331,6 @@
}
auto int_options = torch::TensorOptions().dtype(torch::kInt64);
int nframes = 1;
if (natoms == 0) {
// no backward map needed
ener.resize(nframes);
// dforce of size nall * 3
force.resize(static_cast<size_t>(nframes) * natoms * 3);
fill(force.begin(), force.end(), (VALUETYPE)0.0);
// dvirial of size 9
virial.resize(static_cast<size_t>(nframes) * 9);
fill(virial.begin(), virial.end(), (VALUETYPE)0.0);
// datom_energy_ of size nall
atom_energy.resize(static_cast<size_t>(nframes) * natoms);
fill(atom_energy.begin(), atom_energy.end(), (VALUETYPE)0.0);
// datom_virial_ of size nall * 9
atom_virial.resize(static_cast<size_t>(nframes) * natoms * 9);
fill(atom_virial.begin(), atom_virial.end(), (VALUETYPE)0.0);
return;
}
std::vector<torch::jit::IValue> inputs;
at::Tensor coord_wrapped_Tensor =
torch::from_blob(coord_wrapped.data(), {1, natoms, 3}, options)
Expand Down
2 changes: 1 addition & 1 deletion source/lmp/tests/test_lammps_dpa_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def test_pair_deepmd_si(lammps_si):
)
@pytest.mark.parametrize(
("balance_args",),
[(["--balance"],)],
[(["--balance"],), ([],)],
)
def test_pair_deepmd_mpi(balance_args: list):
with tempfile.NamedTemporaryFile() as f:
Expand Down
2 changes: 1 addition & 1 deletion source/lmp/tests/test_lammps_dpa_sel_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def test_pair_deepmd_si(lammps_si):
)
@pytest.mark.parametrize(
("balance_args",),
[(["--balance"],)],
[(["--balance"],), ([],)],
)
def test_pair_deepmd_mpi(balance_args: list):
with tempfile.NamedTemporaryFile() as f:
Expand Down
Binary file modified source/tests/infer/deeppot_dpa.pth
Binary file not shown.
Binary file modified source/tests/infer/deeppot_dpa_sel.pth
Binary file not shown.
Binary file modified source/tests/infer/deeppot_sea.pth
Binary file not shown.
Binary file modified source/tests/infer/fparam_aparam.pth
Binary file not shown.