Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): neighbor stat #4258

Merged
merged 2 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class JAXBackend(Backend):
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
| Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
Expand Down Expand Up @@ -82,7 +82,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
Expand Down
35 changes: 18 additions & 17 deletions deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -68,42 +69,42 @@ def call(
np.ndarray
The maximal number of neighbors
"""
xp = array_api_compat.array_namespace(coord, atype)
nframes = coord.shape[0]
coord = coord.reshape(nframes, -1, 3)
coord = xp.reshape(coord, (nframes, -1, 3))
nloc = coord.shape[1]
coord = coord.reshape(nframes, nloc * 3)
coord = xp.reshape(coord, (nframes, nloc * 3))
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut
)

coord1 = extend_coord.reshape(nframes, -1)
coord1 = xp.reshape(extend_coord, (nframes, -1))
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
coord1.reshape([nframes, -1, 3])[:, None, :, :]
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = np.eye(nloc, nall, dtype=bool)
diff[:, mask] = np.inf
rr2 = np.sum(np.square(diff), axis=-1)
min_rr2 = np.min(rr2, axis=-1)
mask = xp.eye(nloc, nall, dtype=xp.bool)
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
njzjz marked this conversation as resolved.
Show resolved Hide resolved
rr2 = xp.sum(xp.square(diff), axis=-1)
min_rr2 = xp.min(rr2, axis=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int)
nneis = []
for ii in range(self.ntypes):
nnei[:, :, ii] = np.sum(
mask & (extend_atype == ii)[:, None, :], axis=-1
)
nneis.append(xp.sum(mask & (extend_atype == ii)[:, None, :], axis=-1))
nnei = xp.stack(nneis, axis=-1)
else:
mask = rr2 < self.rcut**2
# virtual type (<0) are not counted
nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape(
nframes, nloc, 1
)
max_nnei = np.max(nnei, axis=1)
nnei = xp.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1)
nnei = xp.reshape(nnei, (nframes, nloc, 1))
max_nnei = xp.max(nnei, axis=1)
return min_rr2, max_nnei


Expand Down
59 changes: 59 additions & 0 deletions deepmd/jax/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import jaxlib

from deepmd.jax.env import (
jax,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
    """Auto batch size for the JAX backend.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
        is not set
    factor : float, default: 2.
        increased factor

    """

    def __init__(
        self,
        initial_batch_size: int = 1024,
        factor: float = 2.0,
    ):
        super().__init__(
            initial_batch_size=initial_batch_size,
            factor=factor,
        )

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if the first device reported by ``jax.devices()`` has the
            ``"gpu"`` platform
        """
        return jax.devices()[0].platform == "gpu"

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if ``e`` is an ``XlaRuntimeError`` or ``ValueError`` whose
            first argument is a string containing ``"RESOURCE_EXHAUSTED:"``
        """
        if not isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)):
            return False
        # An exception may be raised without arguments or with a non-string
        # first argument; such an exception is simply not an OOM error and
        # must not make this predicate raise IndexError/TypeError itself.
        message = e.args[0] if e.args else None
        return isinstance(message, str) and "RESOURCE_EXHAUSTED:" in message

Check warning on line 59 in deepmd/jax/utils/auto_batch_size.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/utils/auto_batch_size.py#L58-L59

Added lines #L58 - L59 were not covered by tests
104 changes: 104 additions & 0 deletions deepmd/jax/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Iterator,
)
from typing import (
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStatOP,
)
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.utils.auto_batch_size import (
AutoBatchSize,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


class NeighborStat(BaseNeighborStat):
    """Compute neighbor statistics with the JAX backend.

    Parameters
    ----------
    ntypes : int
        The num of atom types
    rcut : float
        The cut-off radius
    mixed_type : bool, optional, default=False
        Treat all types as a single type.
    """

    def __init__(
        self,
        ntypes: int,
        rcut: float,
        mixed_type: bool = False,
    ) -> None:
        super().__init__(ntypes, rcut, mixed_type)
        # array-API operator that does the actual counting
        self.op = NeighborStatOP(ntypes, rcut, mixed_type)
        # splits the frames of a set into batches passed to ``_execute``
        self.auto_batch_size = AutoBatchSize()

    def iterator(
        self, data: DeepmdDataSystem
    ) -> Iterator[tuple[np.ndarray, float, str]]:
        """Produce neighbor statistics for every set directory of ``data``.

        Parameters
        ----------
        data : DeepmdDataSystem
            The data system whose sets are scanned.

        Yields
        ------
        np.ndarray
            The maximal number of neighbors
        float
            The squared minimal distance between two atoms
        str
            The directory of the data system
        """
        for sys_idx in range(len(data.system_dirs)):
            for set_dir in data.data_systems[sys_idx].dirs:
                system = data.data_systems[sys_idx]
                frames = system._load_set(set_dir)
                total_frames = frames["coord"].shape[0]
                natoms = system.get_natoms()
                coord = frames["coord"]
                atype = frames["type"]
                # no cell is passed for non-periodic systems
                cell = frames["box"] if system.pbc else None
                min_rr2, max_nnei = self.auto_batch_size.execute_all(
                    self._execute,
                    total_frames,
                    natoms,
                    coord,
                    atype,
                    cell,
                )
                yield np.max(max_nnei, axis=0), np.min(min_rr2), set_dir

    def _execute(
        self,
        coord: np.ndarray,
        atype: np.ndarray,
        cell: Optional[np.ndarray],
    ):
        """Run the neighbor-stat operator on one batch.

        The inputs are converted with ``to_jax_array`` before the operator is
        called, and both results are converted back with ``to_numpy_array``.

        Parameters
        ----------
        coord
            The coordinates of atoms.
        atype
            The atom types.
        cell
            The cell.
        """
        jax_inputs = (
            to_jax_array(coord),
            to_jax_array(atype),
            to_jax_array(cell),
        )
        min_rr2, max_nnei = self.op(*jax_inputs)
        return to_numpy_array(min_rr2), to_numpy_array(max_nnei)
69 changes: 0 additions & 69 deletions source/tests/common/dpmodel/test_neighbor_stat.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from ..seed import (
GLOBAL_SEED,
)
from .common import (
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
)


def gen_sys(nframes):
Expand Down Expand Up @@ -42,7 +47,7 @@ def setUp(self):
def tearDown(self):
shutil.rmtree("system_0")

def test_neighbor_stat(self):
def run_neighbor_stat(self, backend):
for rcut in (0.0, 1.0, 2.0, 4.0):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
Expand All @@ -52,7 +57,7 @@ def test_neighbor_stat(self):
rcut=rcut,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
backend="pytorch",
backend=backend,
)
upper = np.ceil(rcut) + 1
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
Expand All @@ -67,3 +72,18 @@ def test_neighbor_stat(self):
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)

@unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
def test_neighbor_stat_tf(self):
self.run_neighbor_stat("tensorflow")

@unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
def test_neighbor_stat_pt(self):
self.run_neighbor_stat("pytorch")

def test_neighbor_stat_dp(self):
self.run_neighbor_stat("numpy")

@unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
def test_neighbor_stat_jax(self):
self.run_neighbor_stat("jax")
Loading
Loading