Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pluggable backend #3294

Merged
merged 4 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions deepmd/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Backends.

Avoid directly importing third-party libraries in this module for performance.
"""
# copy from dpdata
import importlib
Fixed Show fixed Hide fixed
from importlib import (
metadata,
)
from pathlib import (
Path,
)

PACKAGE_BASE = "deepmd.backend"
NOT_LOADABLE = ("__init__.py",)

for module_file in Path(__file__).parent.glob("*.py"):
if module_file.name not in NOT_LOADABLE:
module_name = f".{module_file.stem}"
importlib.import_module(module_name, PACKAGE_BASE)

# https://setuptools.readthedocs.io/en/latest/userguide/entry_point.html
eps = metadata.entry_points(group="deepmd.backend")
for ep in eps:
plugin = ep.load()

Check warning on line 26 in deepmd/backend/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/__init__.py#L26

Added line #L26 was not covered by tests
200 changes: 200 additions & 0 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
abstractmethod,
)
from enum import (
Flag,
auto,
)
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
Dict,
List,
Type,
)

from deepmd.utils.plugin import (
Plugin,
PluginVariant,
)

if TYPE_CHECKING:
from argparse import (

Check warning on line 24 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L24

Added line #L24 was not covered by tests
Namespace,
)

from deepmd.infer.deep_eval import (

Check warning on line 28 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L28

Added line #L28 was not covered by tests
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (

Check warning on line 31 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L31

Added line #L31 was not covered by tests
NeighborStat,
)


class Backend(PluginVariant):
r"""General backend class.

Examples
--------
>>> @Backend.register("tf")
>>> @Backend.register("tensorflow")
>>> class TensorFlowBackend(Backend):
... pass
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a backend plugin.

Parameters
----------
key : str
the key of a backend

Returns
-------
Callable[[object], object]
the decorator to register backend
"""
return Backend.__plugins.register(key.lower())

@staticmethod
def get_backend(key: str) -> Type["Backend"]:
"""Get the backend by key.

Parameters
----------
key : str
the key of a backend

Returns
-------
Backend
the backend
"""
try:
backend = Backend.__plugins.get_plugin(key.lower())
except KeyError:
raise KeyError(f"Backend {key} is not registered.")
assert isinstance(backend, type)
return backend

Check warning on line 84 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L79-L84

Added lines #L79 - L84 were not covered by tests

@staticmethod
def get_backends() -> Dict[str, Type["Backend"]]:
"""Get all the registered backend names.

Returns
-------
list
all the registered backends
"""
return Backend.__plugins.plugins

Check warning on line 95 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L95

Added line #L95 was not covered by tests

@staticmethod
def get_backends_by_feature(
feature: "Backend.Feature",
) -> Dict[str, Type["Backend"]]:
"""Get all the registered backend names with a specific feature.

Parameters
----------
feature : Backend.Feature
the feature flag

Returns
-------
list
all the registered backends with the feature
"""
return {

Check warning on line 113 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L113

Added line #L113 was not covered by tests
key: backend
for key, backend in Backend.__plugins.plugins.items()
if backend.features & feature
}

@staticmethod
def detect_backend_by_model(filename: str) -> Type["Backend"]:
"""Detect the backend of the given model file.

Parameters
----------
filename : str
The model file name
"""
filename = str(filename).lower()
for backend in Backend.get_backends().values():
if filename.endswith(backend.suffixes[0]):
return backend
raise ValueError(f"Cannot detect the backend of the model file {filename}.")

Check warning on line 132 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L128-L132

Added lines #L128 - L132 were not covered by tests

class Feature(Flag):
"""Feature flag to indicate whether the backend supports certain features."""

ENTRY_POINT = auto()
"""Support entry point hook."""
DEEP_EVAL = auto()
"""Support Deep Eval backend."""
NEIGHBOR_STAT = auto()
"""Support neighbor statistics."""

name: ClassVar[str] = "Unknown"
"""The formal name of the backend.

To be consistent, this name should be also registered in the plugin system."""

features: ClassVar[Feature] = Feature(0)
"""The features of the backend."""
suffixes: ClassVar[List[str]] = []
"""The supported suffixes of the saved model.

The first element is considered as the default suffix."""

@abstractmethod
def is_available(self) -> bool:
"""Check if the backend is available.

Returns
-------
bool
Whether the backend is available.
"""

@property
@abstractmethod
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.

Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
pass

Check warning on line 176 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L176

Added line #L176 was not covered by tests

@property
@abstractmethod
def deep_eval(self) -> Type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.

Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
pass

Check warning on line 188 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L188

Added line #L188 was not covered by tests

@property
@abstractmethod
def neighbor_stat(self) -> Type["NeighborStat"]:
"""The neighbor statistics of the backend.

Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
pass

Check warning on line 200 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L200

Added line #L200 was not covered by tests
86 changes: 86 additions & 0 deletions deepmd/backend/dpmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
TYPE_CHECKING,
Callable,
ClassVar,
List,
Type,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (

Check warning on line 15 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L15

Added line #L15 was not covered by tests
Namespace,
)

from deepmd.infer.deep_eval import (

Check warning on line 19 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L19

Added line #L19 was not covered by tests
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (

Check warning on line 22 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L22

Added line #L22 was not covered by tests
NeighborStat,
)


@Backend.register("dp")
@Backend.register("dpmodel")
@Backend.register("np")
@Backend.register("numpy")
class DPModelBackend(Backend):
"""DPModel backend that uses NumPy as the reference implementation."""

name = "DPModel"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = Backend.Feature.NEIGHBOR_STAT
"""The features of the backend."""
suffixes: ClassVar[List[str]] = ["dp"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
"""Check if the backend is available.

Returns
-------
bool
Whether the backend is available.
"""
return True

Check warning on line 49 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L49

Added line #L49 was not covered by tests

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
"""The entry point hook of the backend.

Returns
-------
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")

Check warning on line 60 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L60

Added line #L60 was not covered by tests

@property
def deep_eval(self) -> Type["DeepEvalBackend"]:
"""The Deep Eval backend of the backend.

Returns
-------
type[DeepEvalBackend]
The Deep Eval backend of the backend.
"""
raise NotImplementedError(f"Unsupported backend: {self.name}")

Check warning on line 71 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L71

Added line #L71 was not covered by tests

@property
def neighbor_stat(self) -> Type["NeighborStat"]:
"""The neighbor statistics of the backend.

Returns
-------
type[NeighborStat]
The neighbor statistics of the backend.
"""
from deepmd.dpmodel.utils.neighbor_stat import (

Check warning on line 82 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L82

Added line #L82 was not covered by tests
NeighborStat,
)

return NeighborStat

Check warning on line 86 in deepmd/backend/dpmodel.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/dpmodel.py#L86

Added line #L86 was not covered by tests
Loading
Loading