Skip to content

Commit

Permalink
Merge branch 'devel' into devel-loss-plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
ChiahsinChu authored Oct 24, 2024
2 parents 9f6e269 + 0f817e1 commit 868ffa4
Show file tree
Hide file tree
Showing 20 changed files with 112 additions and 85 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build_wheel.yml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ jobs:
- uses: actions/download-artifact@v4
with:
path: source/install/docker/dist
pattern: cibw-*-manylinux_x86_64-cu${{ matrix.cuda_version }}*
merge-multiple: true
- name: Log in to the Container registry
uses: docker/login-action@v3
Expand Down Expand Up @@ -180,6 +181,7 @@ jobs:
- uses: actions/download-artifact@v4
with:
path: dist/packages
pattern: cibw-*
merge-multiple: true
- uses: actions/setup-python@v5
name: Install Python
Expand Down
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ repos:
- id: ruff-format
exclude: ^source/3rdparty
types_or: [python, pyi, jupyter]
- repo: https://github.com/pycqa/flake8
# flake8 cannot autofix
rev: "7.1.1"
hooks:
- id: flake8
additional_dependencies:
- torchfix==0.6.0
- flake8-pyproject==1.2.3
# numpydoc
- repo: https://github.com/Carreau/velin
rev: 0.0.12
Expand Down
2 changes: 1 addition & 1 deletion deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def xp_take_along_axis(arr, indices, axis):
else:
indices = xp.reshape(indices, (0, 0))

offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis]
indices = xp.reshape(offset + indices, (-1,))

out = xp.take(arr, indices)
Expand Down
3 changes: 1 addition & 2 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def get_trainer(
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
local_rank = int(local_rank)
assert dist.is_nccl_available()
dist.init_process_group(backend="nccl")
dist.init_process_group(backend="cuda:nccl,cpu:gloo")

def prepare_trainer_input_single(
model_params_single, data_dict_single, rank=0, seed=None
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __init__(
super().__init__(type_map, **kwargs)
super().init_out_stat()
self.tab_file = tab_file
self.rcut = rcut
self.tab = self._set_pairtab(tab_file, rcut)
self.rcut = float(rcut)
self.tab = self._set_pairtab(tab_file, self.rcut)

self.type_map = type_map
self.ntypes = len(type_map)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,8 +605,8 @@ def __init__(
):
super().__init__()
self.epsilon = 1e-4 # protection of 1./nnei
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.ntypes = ntypes
sel = [sel] if isinstance(sel, int) else sel
self.nnei = sum(sel)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def __init__(
Random seed for parameter initialization.
"""
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.ntypes = ntypes
self.nlayers = nlayers
sel = [sel] if isinstance(sel, int) else sel
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def __init__(
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.neuron = neuron
self.filter_neuron = self.neuron
self.axis_neuron = axis_neuron
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def __init__(
"""
super().__init__()
del type
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.neuron = neuron
self.filter_neuron = self.neuron
self.axis_neuron = axis_neuron
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __init__(
**kwargs,
):
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.neuron = neuron
self.filter_neuron = self.neuron
self.set_davg_zero = set_davg_zero
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,8 @@ def __init__(
Random seed for initializing the network parameters.
"""
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.neuron = neuron
self.filter_neuron = self.neuron
self.set_davg_zero = set_davg_zero
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,8 @@ def __init__(
seed: Optional[Union[int, list[int]]] = None,
):
super().__init__()
self.rcut = rcut
self.rcut_smth = rcut_smth
self.rcut = float(rcut)
self.rcut_smth = float(rcut_smth)
self.neuron = neuron
self.filter_neuron = self.neuron
self.tebd_dim = tebd_dim
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
mixed_types: bool,
) -> None:
super().__init__()
self.rcut = rcut
self.rcut = float(rcut)
self.ntypes = ntypes
self.mixed_types = mixed_types

Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,10 @@ plugins = ["source.3rdparty.coverage_plugins.jit_plugin"]
load-plugins = "deepmd_checker"
disable = "all"
enable = "E8001,E8002"

[tool.flake8]
select = [
"TOR0",
"TOR1",
"TOR2",
]
22 changes: 10 additions & 12 deletions source/api_cc/src/DeepPotPT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
nlist_data.copy_from_nlist(lmp_list);
nlist_data.shuffle_exclude_empty(fwd_map);
nlist_data.padding();
if (do_message_passing == 1 && nghost > 0) {
if (do_message_passing == 1) {
int nswap = lmp_list.nswap;
torch::Tensor sendproc_tensor =
torch::from_blob(lmp_list.sendproc, {nswap}, int32_option);
Expand All @@ -180,10 +180,14 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
torch::from_blob(lmp_list.recvnum, {nswap}, int32_option);
torch::Tensor sendnum_tensor =
torch::from_blob(lmp_list.sendnum, {nswap}, int32_option);
torch::Tensor communicator_tensor = torch::from_blob(
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
// torch::Tensor communicator_tensor =
// torch::tensor(lmp_list.world, int32_option);
torch::Tensor communicator_tensor;
if (lmp_list.world == 0) {
communicator_tensor = torch::empty({1}, torch::kInt64);
} else {
communicator_tensor = torch::from_blob(
const_cast<void*>(lmp_list.world), {1}, torch::kInt64);
}

torch::Tensor nswap_tensor = torch::tensor(nswap, int32_option);
int total_send =
std::accumulate(lmp_list.sendnum, lmp_list.sendnum + nswap, 0);
Expand All @@ -196,12 +200,6 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
comm_dict.insert("recv_num", recvnum_tensor);
comm_dict.insert("communicator", communicator_tensor);
}
if (do_message_passing == 1 && nghost == 0) {
// for the situation that no ghost atoms (e.g. serial nopbc)
// set the mapping arange(nloc) is enough
auto option = torch::TensorOptions().device(device).dtype(torch::kInt64);
mapping_tensor = at::arange(nloc_real, option).unsqueeze(0);
}
}
at::Tensor firstneigh = createNlistTensor(nlist_data.jlist);
firstneigh_tensor = firstneigh.to(torch::kInt64).to(device);
Expand All @@ -224,7 +222,7 @@ void DeepPotPT::compute(ENERGYVTYPE& ener,
.to(device);
}
c10::Dict<c10::IValue, c10::IValue> outputs =
(do_message_passing == 1 && nghost > 0)
(do_message_passing == 1)
? module
.run_method("forward_lower", coord_wrapped_Tensor, atype_Tensor,
firstneigh_tensor, mapping_tensor, fparam_tensor,
Expand Down
2 changes: 1 addition & 1 deletion source/checker/deepmd_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def visit_call(self, node):
if (
isinstance(node.func, Attribute)
and isinstance(node.func.expr, Name)
and node.func.expr.name in {"np", "tf", "torch"}
and node.func.expr.name in {"np", "tf", "torch", "xp", "jnp"}
and node.func.attrname
in {
# https://pytorch.org/docs/stable/torch.html#creation-ops
Expand Down
8 changes: 0 additions & 8 deletions source/lmp/tests/test_lammps_dpa_pt_nopbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,14 +681,6 @@ def test_pair_deepmd_si(lammps_si):
[(["--balance"],), ([],)],
)
def test_pair_deepmd_mpi(balance_args: list):
if balance_args == []:
# python:5331 terminated with signal 11 at PC=7f3e940e3806 SP=7ffd5787edc0. Backtrace:
# /home/runner/work/deepmd-kit/deepmd-kit/dp_test/lib/libdeepmd_op_pt.so(+0x95806)[0x7f3e940e3806]
# /home/runner/work/deepmd-kit/deepmd-kit/dp_test/lib/libdeepmd_op_pt.so(+0x8f76e)[0x7f3e940dd76e]
# /home/runner/work/deepmd-kit/deepmd-kit/dp_test/lib/libdeepmd_op_pt.so(+0x9a38a)[0x7f3e940e838a]
# /home/runner/work/deepmd-kit/deepmd-kit/dp_test/lib/libdeepmd_op_pt.so(_Z9border_opRKN2at6TensorES2_S2_S2_S2_S2_S2_S2_S2_+0x8e)[0x7f3e940dda63]
# /home/runner/work/deepmd-kit/deepmd-kit/dp_test/lib/libdeepmd_op_pt.so(+0xaeac3)[0x7f3e940fcac3]
pytest.skip(reason="Known segfault, see comments for details")
with tempfile.NamedTemporaryFile() as f:
sp.check_call(
[
Expand Down
63 changes: 37 additions & 26 deletions source/op/pt/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,18 @@ class Border : public torch::autograd::Function<Border> {
int mpi_init = 0;
MPI_Initialized(&mpi_init);
int cuda_aware = 1;
int me;
int me = 0;
MPI_Comm world;
int world_size = 0;
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
if (mpi_init) {
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
}
MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
MPI_Request request;
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
if (world_size != 1) {
if (world_size >= 1) {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
Expand All @@ -120,11 +122,15 @@ class Border : public torch::autograd::Function<Border> {
for (int iswap = 0; iswap < nswap; ++iswap) {
int nrecv = recvnum[iswap];
int nsend = sendnum[iswap];
torch::Tensor isendlist =
torch::from_blob(sendlist[iswap], {nsend}, int32_options)
.to(recv_g1_tensor.device());
torch::Tensor send_g1_tensor = recv_g1_tensor.index_select(0, isendlist);
FPTYPE* send_g1 = send_g1_tensor.data_ptr<FPTYPE>();
torch::Tensor isendlist;
torch::Tensor send_g1_tensor;
FPTYPE* send_g1;
if (nsend != 0) {
isendlist = torch::from_blob(sendlist[iswap], {nsend}, int32_options)
.to(recv_g1_tensor.device());
send_g1_tensor = recv_g1_tensor.index_select(0, isendlist);
send_g1 = send_g1_tensor.data_ptr<FPTYPE>();
}
#ifdef USE_MPI
if (sendproc[iswap] != me) {
if (nrecv) {
Expand Down Expand Up @@ -207,15 +213,17 @@ class Border : public torch::autograd::Function<Border> {
MPI_Initialized(&mpi_init);
int world_size = 0;
int cuda_aware = 1;
int me = 0;
MPI_Comm world;
unpack_communicator(communicator_tensor, world);
int me;
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
if (mpi_init) {
unpack_communicator(communicator_tensor, world);
MPI_Comm_rank(world, &me);
MPI_Comm_size(world, &world_size);
}
MPI_Datatype mpi_type = get_mpi_type<FPTYPE>();
MPI_Request request;
#if defined(GOOGLE_CUDA) || defined(TENSORFLOW_USE_ROCM)
if (world_size != 1) {
if (world_size >= 1) {
int version, subversion;
MPI_Get_version(&version, &subversion);
if (version >= 4) {
Expand Down Expand Up @@ -248,17 +256,20 @@ class Border : public torch::autograd::Function<Border> {
int nlocal = nlocal_tensor.item<int>();
int nghost = nghost_tensor.item<int>();
int ntotal = nlocal + nghost;

torch::Tensor send_g1_tensor = d_local_g1_tensor;

int max_recvnum = sendnum_tensor.max().item<int>();
auto options = torch::TensorOptions()
.dtype(d_local_g1_tensor.dtype())
.device(d_local_g1_tensor.device());
torch::Tensor recv_g1_tensor =
torch::empty({max_recvnum, tensor_size}, options);
FPTYPE* recv_g1 = recv_g1_tensor.data_ptr<FPTYPE>();
FPTYPE* send_g1 = send_g1_tensor.data_ptr<FPTYPE>() + ntotal * tensor_size;
torch::Tensor send_g1_tensor;
torch::Tensor recv_g1_tensor;
FPTYPE* recv_g1;
FPTYPE* send_g1;
if (nswap != 0) {
send_g1_tensor = d_local_g1_tensor;
int max_recvnum = sendnum_tensor.max().item<int>();
auto options = torch::TensorOptions()
.dtype(d_local_g1_tensor.dtype())
.device(d_local_g1_tensor.device());
recv_g1_tensor = torch::empty({max_recvnum, tensor_size}, options);
recv_g1 = recv_g1_tensor.data_ptr<FPTYPE>();
send_g1 = send_g1_tensor.data_ptr<FPTYPE>() + ntotal * tensor_size;
}

int end = ntotal;
auto int32_options = torch::TensorOptions().dtype(torch::kInt32);
Expand Down
Loading

0 comments on commit 868ffa4

Please sign in to comment.