Skip to content

Commit 9cb7777

Browse files
committed
make AccelerateRunner a subclass of Accelerator
add TorchRunner, add DeepSpeedRunner
1 parent 2c4bcf0 commit 9cb7777

28 files changed

+1929
-991
lines changed

.github/workflows/push.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ jobs:
2424
- name: Install dependencies
2525
run: pip install -r requirements.txt && pip install -e .
2626
- name: Install dependencies for testing
27-
run: pip install pytest pytest-cov torch torcheval torchmetrics torchvision accelerate
27+
run: pip install pytest pytest-cov
2828
- name: pytest
2929
run: pytest --cov=materialx --cov-report=xml --cov-report=html .
3030
- name: Upload coverage report for documentation

danling/__init__.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from lazy_imports import try_import
2121

22-
from danling import metrics, modules, optim, registry, runner, tensors, typing, utils
22+
from danling import defaults, metrics, modules, optim, registry, runner, tensors, typing, utils
2323

2424
from .metrics import (
2525
AverageMeter,
@@ -31,7 +31,7 @@
3131
)
3232
from .optim import LRScheduler
3333
from .registry import GlobalRegistry, Registry
34-
from .runner import AccelerateRunner, BaseRunner, TorchRunner
34+
from .runner import AccelerateRunner, BaseRunner, Config, DeepSpeedRunner, Runner, TorchRunner
3535
from .tensors import NestedTensor, PNTensor, tensor
3636
from .utils import (
3737
catch,
@@ -49,6 +49,7 @@
4949
from .metrics import Metrics, MultiTaskMetrics
5050

5151
__all__ = [
52+
"defaults",
5253
"metrics",
5354
"modules",
5455
"optim",
@@ -57,9 +58,12 @@
5758
"tensors",
5859
"utils",
5960
"typing",
61+
"Config",
62+
"Runner",
6063
"BaseRunner",
61-
"AccelerateRunner",
6264
"TorchRunner",
65+
"AccelerateRunner",
66+
"DeepSpeedRunner",
6367
"LRScheduler",
6468
"Registry",
6569
"GlobalRegistry",

danling/runner/defaults.py danling/defaults.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,15 @@
1717
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
1818
# See the LICENSE file for more details.
1919

20-
DEFAULT_RUN_NAME = "Run"
21-
DEFAULT_EXPERIMENT_NAME = "DanLing"
22-
DEFAULT_EXPERIMENT_ID = "xxxxxxxxxxxxxxxx"
23-
DEFAULT_IGNORED_KEYS_IN_HASH = {
20+
RUN_NAME = "Run"
21+
EXPERIMENT_NAME = "DanLing"
22+
EXPERIMENT_ID = "xxxxxxxxxxxxxxxx"
23+
SEED = 1016
24+
IGNORED_NAMES_IN_METRICS = ("index", "epochs", "steps")
25+
IGNORED_NAMES_IN_HASH = {
2426
"timestamp",
25-
"iters",
26-
"steps",
27-
"epochs",
27+
"epoch",
28+
"step",
2829
"results",
2930
"score_split",
3031
"score",

danling/modules/mlp/dense.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def __init__(
3535
super().__init__()
3636
self.residual = residual
3737
self.linear = nn.Linear(in_features, out_features, bias=bias)
38-
self.norm = getattr(nn, norm)(out_features) if norm else nn.Identity()
39-
self.activation = getattr(nn, activation)() if activation else nn.Identity()
38+
self.norm = getattr(nn, norm)(out_features) if norm else None
39+
self.activation = getattr(nn, activation)() if activation else None
4040
self.dropout = nn.Dropout(dropout)
41-
self.pool = getattr(nn, pool)(out_features) if pool else nn.Identity() if self.residual else None
41+
self.pool = getattr(nn, pool)(out_features) if self.residual else None
4242

4343
def forward(self, x):
4444
out = self.linear(x)

danling/runner/README.md

+7-7
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@ The Runner of DanLing sets up the basic environment for running neural networks.
44

55
## Components
66

7-
For cross-platform compatibilities, DanLing features a two-level Runner + RunnerState system.
7+
For cross-platform compatibility, DanLing features a two-level Runner + Config system.
88

99
### PlatformRunner
1010

1111
PlatformRunner implements platform-specific features like `step` and `prepare`.
1212

13-
The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `RunnerState`.
13+
The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `Config`.
1414

1515
Currently, only [`AccelerateRunner`][danling.runner.AccelerateRunner] is supported.
1616

1717
### [`BaseRunner`][danling.runner.BaseRunner]
1818

19-
[`BaseRunner`](danling.runner.BaseRunner) defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`.
19+
[`BaseRunner`][danling.runner.BaseRunner] defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`.
2020

21-
### [`RunnerState`][danling.runner.RunnerState]
21+
### [`Config`][danling.runner.Config]
2222

23-
[`RunnerState`][danling.runner.RunnerState] stores the state of a run (e.g. `epochs`, `run_id`, `network`, etc.).
23+
[`Config`][danling.runner.Config] stores the state of a run (e.g. `epoch`, `run_id`, `network`, etc.).
2424

25-
With `RunnerState` and corresponding weights, you can resume a run from any point.
26-
Therefore, all members in `RunnerState` will be saved in the checkpoint, and thus should be json serialisable.
25+
With `Config` and corresponding weights, you can resume a run from any point.
26+
Therefore, all members in `Config` will be saved in the checkpoint, and thus should be JSON serialisable.
2727

2828
## Experiments Management
2929

danling/runner/__init__.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,22 @@
1717
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
1818
# See the LICENSE file for more details.
1919

20-
from . import defaults
2120
from .accelerate_runner import AccelerateRunner
2221
from .base_runner import BaseRunner
23-
from .state import RunnerState
22+
from .config import Config
23+
from .deepspeed_runner import DeepSpeedRunner
24+
from .runner import Runner
2425
from .torch_runner import TorchRunner
2526
from .utils import on_local_main_process, on_main_process
2627

2728
__all__ = [
28-
"RunnerState",
29+
"Config",
30+
"Runner",
2931
"BaseRunner",
32+
"TorchRunner",
3033
"AccelerateRunner",
34+
"DeepSpeedRunner",
3135
"TorchRunner",
3236
"on_main_process",
3337
"on_local_main_process",
34-
"defaults",
3538
]

0 commit comments

Comments
 (0)