
Commit 29d2c50

committed Mar 24, 2025

make AccelerateRunner a subclass of Accelerator
add TorchRunner and DeepSpeedRunner
1 parent 2c4bcf0 commit 29d2c50

28 files changed, +1929 −991 lines
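
In short: after this commit the package exposes the new `Config` and a dispatching `Runner` at the top level (see the `danling/__init__.py` and `danling/runner/runner.py` diffs below). A minimal usage sketch, not part of the commit itself:

    import danling as dl

    config = dl.Config({"platform": "auto"})  # "torch", "accelerate" and "deepspeed" are also accepted
    runner = dl.Runner(config)                # swaps in TorchRunner / AccelerateRunner / DeepSpeedRunner

A real run would also populate the config with model, dataset and optimiser settings, as the demos further down do.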
 

‎.github/workflows/push.yaml

+1-1
@@ -24,7 +24,7 @@ jobs:
       - name: Install dependencies
         run: pip install -r requirements.txt && pip install -e .
       - name: Install dependencies for testing
-        run: pip install pytest pytest-cov torch torcheval torchmetrics torchvision accelerate
+        run: pip install pytest pytest-cov
       - name: pytest
         run: pytest --cov=materialx --cov-report=xml --cov-report=html .
      - name: Upload coverage report for documentation

‎danling/__init__.py

+7-3
@@ -19,7 +19,7 @@
 
 from lazy_imports import try_import
 
-from danling import metrics, modules, optim, registry, runner, tensors, typing, utils
+from danling import defaults, metrics, modules, optim, registry, runner, tensors, typing, utils
 
 from .metrics import (
     AverageMeter,
@@ -31,7 +31,7 @@
 )
 from .optim import LRScheduler
 from .registry import GlobalRegistry, Registry
-from .runner import AccelerateRunner, BaseRunner, TorchRunner
+from .runner import AccelerateRunner, BaseRunner, Config, DeepSpeedRunner, Runner, TorchRunner
 from .tensors import NestedTensor, PNTensor, tensor
 from .utils import (
     catch,
@@ -49,6 +49,7 @@
 from .metrics import Metrics, MultiTaskMetrics
 
 __all__ = [
+    "defaults",
     "metrics",
     "modules",
     "optim",
@@ -57,9 +58,12 @@
     "tensors",
     "utils",
     "typing",
+    "Config",
+    "Runner",
     "BaseRunner",
-    "AccelerateRunner",
     "TorchRunner",
+    "AccelerateRunner",
+    "DeepSpeedRunner",
     "LRScheduler",
     "Registry",
     "GlobalRegistry",

danling/runner/defaults.py → danling/defaults.py  (+8 −7)

@@ -17,14 +17,15 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 # See the LICENSE file for more details.
 
-DEFAULT_RUN_NAME = "Run"
-DEFAULT_EXPERIMENT_NAME = "DanLing"
-DEFAULT_EXPERIMENT_ID = "xxxxxxxxxxxxxxxx"
-DEFAULT_IGNORED_KEYS_IN_HASH = {
+RUN_NAME = "Run"
+EXPERIMENT_NAME = "DanLing"
+EXPERIMENT_ID = "xxxxxxxxxxxxxxxx"
+SEED = 1016
+IGNORED_NAMES_IN_METRICS = ("index", "epochs", "steps")
+IGNORED_NAMES_IN_HASH = {
     "timestamp",
-    "iters",
-    "steps",
-    "epochs",
+    "epoch",
+    "step",
     "results",
     "score_split",
     "score",

‎danling/modules/mlp/dense.py

+3-3
@@ -35,10 +35,10 @@ def __init__(
         super().__init__()
         self.residual = residual
         self.linear = nn.Linear(in_features, out_features, bias=bias)
-        self.norm = getattr(nn, norm)(out_features) if norm else nn.Identity()
-        self.activation = getattr(nn, activation)() if activation else nn.Identity()
+        self.norm = getattr(nn, norm)(out_features) if norm else None
+        self.activation = getattr(nn, activation)() if activation else None
         self.dropout = nn.Dropout(dropout)
-        self.pool = getattr(nn, pool)(out_features) if pool else nn.Identity() if self.residual else None
+        self.pool = getattr(nn, pool)(out_features) if self.residual else None
 
     def forward(self, x):
         out = self.linear(x)

‎danling/runner/README.md

+7-7
@@ -4,26 +4,26 @@ The Runner of DanLing sets up the basic environment for running neural networks.
 
 ## Components
 
-For cross-platform compatibilities, DanLing features a two-level Runner + RunnerState system.
+For cross-platform compatibilities, DanLing features a two-level Runner + Config system.
 
 ### PlatformRunner
 
 PlatformRunner implements platform-specific features like `step` and `prepare`.
 
-The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `RunnerState`.
+The Runner contains all runtime information that is irrelevant to the checkpoint (e.g. `world_size`, `rank`, etc.). All other information should be saved in `Config`.
 
 Currently, only [`AccelerateRunner`][danling.runner.AccelerateRunner] is supported.
 
 ### [`BaseRunner`][danling.runner.BaseRunner]
 
-[`BaseRunner`](danling.runner.BaseRunner) defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`.
+[`BaseRunner`][danling.runner.BaseRunner] defines shared attributes and implements platform-agnostic features, including `init_logging`, `results` and `scores`.
 
-### [`RunnerState`][danling.runner.RunnerState]
+### [`Config`][danling.runner.Config]
 
-[`RunnerState`][danling.runner.RunnerState] stores the state of a run (e.g. `epochs`, `run_id`, `network`, etc.).
+[`Config`][danling.runner.Config] stores the state of a run (e.g. `epoch`, `run_id`, `network`, etc.).
 
-With `RunnerState` and corresponding weights, you can resume a run from any point.
-Therefore, all members in `RunnerState` will be saved in the checkpoint, and thus should be json serialisable.
+With `Config` and corresponding weights, you can resume a run from any point.
+Therefore, all members in `Config` will be saved in the checkpoint, and thus should be json serialisable.
 
 ## Experiments Management

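As a concrete illustration of the two-level split described in this README diff, a hedged sketch (the config keys shown are illustrative, not mandated by the commit):

    import danling as dl

    # Everything that must survive a restart lives in Config and is saved with the checkpoint.
    config = dl.Config({"epoch_end": 2, "network": {"name": "resnet18"}})

    # The Runner adds runtime-only state (rank, world_size, ...) on top of the Config.
    runner = dl.Runner(config)
    print(runner.config.epoch_end)
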
‎danling/runner/__init__.py

+7-4
@@ -17,19 +17,22 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 # See the LICENSE file for more details.
 
-from . import defaults
 from .accelerate_runner import AccelerateRunner
 from .base_runner import BaseRunner
-from .state import RunnerState
+from .config import Config
+from .deepspeed_runner import DeepSpeedRunner
+from .runner import Runner
 from .torch_runner import TorchRunner
 from .utils import on_local_main_process, on_main_process
 
 __all__ = [
-    "RunnerState",
+    "Config",
+    "Runner",
     "BaseRunner",
+    "TorchRunner",
     "AccelerateRunner",
+    "DeepSpeedRunner",
     "TorchRunner",
     "on_main_process",
     "on_local_main_process",
-    "defaults",
 ]

‎danling/runner/accelerate_runner.py

+115-492
Large diffs are not rendered by default.

‎danling/runner/base_runner.py

+547-404
Large diffs are not rendered by default.

danling/runner/state.py → danling/runner/config.py  (+22 −41)

@@ -19,27 +19,27 @@
 
 from __future__ import annotations
 
-from random import randint
 from typing import Optional
 from uuid import UUID, uuid5
 
-from chanfig import NestedDict
+import chanfig
+
+from danling import defaults
 
-from . import defaults
 from .utils import get_git_hash
 
 
-class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
+class Config(chanfig.Config):  # pylint: disable=too-many-instance-attributes
     r"""
-    `RunnerState` is a `NestedDict` that contains all states of a `Runner`.
+    `Config` is a [`Config`][chanfig.Config] that contains all states of a `Runner`.
 
-    `RunnerState` is designed to store all critical information of a Run so that you can resume a run
+    `Config` is designed to store all critical information of a Run so that you can resume a run
     from a state and corresponding weights or even restart a run from a state.
 
-    `RunnerState` is also designed to be serialisable and hashable, so that you can save it to a file.
-    `RunnerState` is saved in checkpoint together with weights by default.
+    `Config` is also designed to be serialisable and hashable, so that you can save it to a file.
+    `Config` is saved in checkpoint together with weights by default.
 
-    Since `RunnerState` is a [`NestedDict`][chanfig.NestedDict], you can access its attributes by
+    Since `Config` is a [`Config`][chanfig.Config], you can access its attributes by
     `state["key"]` or `state.key`.
 
     Attributes: General:
@@ -59,18 +59,14 @@ class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
             Defaults to `False`.
 
     Attributes: Progress:
-        iters (int): The number of data samples processed.
-            equals to `steps` when `batch_size = 1`.
-        steps (int): The number of `step` calls.
+        steps (int): The number of `steps` calls.
         epochs (int): The number of complete passes over the datasets.
-        iter_end (int): End running iters.
-            Note that `step_end` not initialised since this variable may not apply to some Runners.
-        step_end (int): End running steps.
+        step_end (int): End running step.
             Note that `step_end` not initialised since this variable may not apply to some Runners.
-        epoch_end (int): End running epochs.
+        epoch_end (int): End running epoch.
             Note that `epoch_end` not initialised since this variable may not apply to some Runners.
 
-    In general you should only use one of `iter_end`, `step_end`, `epoch_end` to indicate the length of running.
+    In general you should use either `step_end` or `epoch_end` to indicate the length of running.
 
     Attributes: IO:
         project_root (str): The root directory for all experiments.
@@ -98,29 +94,26 @@ class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
             If <= 0, save only the latest and the best checkpoints.
 
     Note:
-        `RunnerState` is a `NestedDict`, so you can access its attributes by `state["name"]` or `state.name`.
+        `Config` is a [`Config`][chanfig.Config], so you can access its attributes by `state["name"]` or `state.name`.
 
     See Also:
         [`BaseRunner`][danling.runner.BaseRunner]: The base runner class.
     """
 
     # DO NOT set default value in class, as they won't be stored in `__dict__`.
 
-    run_name: str = defaults.DEFAULT_RUN_NAME
+    run_name: str = defaults.RUN_NAME
     run_id: str
-    experiment_name: str = defaults.DEFAULT_EXPERIMENT_NAME
+    experiment_name: str = defaults.EXPERIMENT_NAME
     experiment_id: str
 
-    seed: int
+    seed: Optional[int] = None
     deterministic: bool = False
 
-    iters: int = 0
     steps: int = 0
     epochs: int = 0
-    iter_begin: int = 0
     step_begin: int = 0
     epoch_begin: int = 0
-    iter_end: Optional[int] = None
     step_end: Optional[int] = None
     epoch_end: Optional[int] = None
 
@@ -134,24 +127,12 @@ class RunnerState(NestedDict):  # pylint: disable=too-many-instance-attributes
     log_interval: Optional[int] = None
     save_interval: Optional[int] = None
 
-    distributed: Optional[bool] = None
-    dist_backend: Optional[str] = None
-    init_method: Optional[str] = None
-    master_addr: Optional[str] = None
-    master_port: Optional[int] = None
-
-    def __init__(self, *args, **kwargs):
-        for k, v in self.__class__.__dict__.items():
-            if not (k.startswith("__") and k.endswith("__")) and (not (isinstance(v, property) or callable(v))):
-                self.set(k, v)
-        if "seed" not in self:
-            self.seed = randint(0, 2**32 - 1)
-        super().__init__(*args, **kwargs)
+    def __post_init__(self):
         if "experiment_id" not in self:
-            self.experiment_id = get_git_hash() or defaults.DEFAULT_EXPERIMENT_ID
+            self.experiment_id = get_git_hash() or defaults.EXPERIMENT_ID
         if "run_id" not in self:
             self.run_id = self.run_uuid.hex
-        self.setattr("ignored_keys_in_hash", defaults.DEFAULT_IGNORED_KEYS_IN_HASH)
+        self.setattr("ignored_keys_in_hash", defaults.IGNORED_NAMES_IN_HASH)
 
     @property
     def experiment_uuid(self) -> UUID:
@@ -167,8 +148,8 @@ def run_uuid(self) -> UUID:
             UUID of the run.
         """
 
-        ignored_keys_in_hash = self.getattr("ignored_keys_in_hash", defaults.DEFAULT_IGNORED_KEYS_IN_HASH)
-        state: NestedDict = NestedDict({k: v for k, v in self.dict().items() if k not in ignored_keys_in_hash})
+        ignored_keys_in_hash = self.getattr("ignored_keys_in_hash", defaults.IGNORED_NAMES_IN_HASH)
+        state: chanfig.Config = chanfig.Config({k: v for k, v in self.dict().items() if k not in ignored_keys_in_hash})
         return uuid5(self.experiment_uuid, state.yamls())
 
     def __hash__(self) -> int:

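The pattern used by this commit's demos is to subclass the renamed `Config` and fill in nested options; a short sketch (the attribute values are illustrative):

    import danling as dl

    class MyConfig(dl.Config):
        epoch_end: int = 2

        def __init__(self):
            super().__init__()
            self.network.name = "resnet18"
            self.optim.lr = 1e-3

    config = MyConfig()
    assert config.optim.lr == config["optim.lr"]  # attribute and item access are interchangeable
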
‎danling/runner/deepspeed_runner.py

+211
@@ -0,0 +1,211 @@
+# DanLing
+# Copyright (C) 2022-Present DanLing
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the following licenses:
+# - The Unlicense
+# - GNU Affero General Public License v3.0 or later
+# - GNU General Public License v2.0 or later
+# - BSD 4-Clause "Original" or "Old" License
+# - MIT License
+# - Apache License 2.0
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the LICENSE file for more details.
+
+from __future__ import annotations
+
+import os
+import shutil
+
+import torch
+from chanfig import NestedDict
+from lazy_imports import try_import
+from torch import distributed as dist
+from torch import nn
+from torch.nn.utils import clip_grad_value_
+
+from danling.runner.config import Config
+from danling.utils import catch
+
+from .torch_runner import TorchRunner
+
+with try_import() as ds:
+    import deepspeed
+
+
+class DeepSpeedRunner(TorchRunner):
+
+    def __init__(self, config: Config) -> None:
+        ds.check()
+        super().__init__(config)
+
+    def init_distributed(self) -> None:
+        r"""
+        Set up distributed training.
+
+        Initialise process group and set up DDP variables.
+        """
+
+        backend = self.config.get("backend", os.getenv("BACKEND"))
+        init_method = self.config.get("init_method", os.getenv("INIT_METHOD"))
+        world_size = int(self.config.get("world_size", os.getenv("WORLD_SIZE", "1")))
+        rank = int(self.config.get("rank", os.getenv("RANK", "0")))
+        if world_size > 1:
+            if torch.cuda.is_available():
+                torch.cuda.set_device(self.get_local_rank())
+            deepspeed.init_distributed(dist_backend=backend, init_method=init_method, world_size=world_size, rank=rank)
+            object_list = [self.id, self.timestamp]
+            dist.broadcast_object_list(object_list)
+            self.id, self.timestamp = object_list
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.config.deepspeed = self.get_deepspeed_config()
+        self.model, self.optimizer, _, self.scheduler = deepspeed.initialize(
+            model=self.model,
+            optimizer=self.optimizer,
+            lr_scheduler=self.scheduler,
+            config=self.config.deepspeed,
+        )
+
+    def advance(self, loss) -> None:
+        self.backward(loss)
+        if self.config.get("max_grad_value") is not None:
+            clip_grad_value_(self.model.parameters(), self.config["max_grad_value"])
+        self.model.step()
+        if self.ema is not None:
+            self.ema.update()
+        self.config.steps = self.model.global_steps
+
+    def backward(self, loss: torch.Tensor) -> None:
+        return self.model.backward(loss)
+
+    def get_local_rank(self) -> int:
+        local_rank = self.config.get("local_rank", os.getenv("LOCAL_RANK"))
+        if local_rank is not None:
+            return int(local_rank)
+        rank = self.config.get("rank", os.getenv("RANK"))
+        world_size = self.config.get("world_size", os.getenv("WORLD_SIZE"))
+        if world_size is None or rank is None:
+            raise ValueError("Please provide either `local_rank` or `world_size` and `rank`")
+        return int(world_size) % int(rank)
+
+    def unwrap(self, model: nn.Module) -> nn.Module:
+        while isinstance(model, (deepspeed.DeepSpeedEngine, nn.parallel.DistributedDataParallel)):
+            model = model.module
+        return model
+
+    @property
+    def deepspeed(self) -> NestedDict | None:
+        if isinstance(self.model, deepspeed.DeepSpeedEngine):
+            return self.model.config
+        return None
+
+    @catch
+    def save_checkpoint(self, name: str = "latest", epoch: int | None = None, save_best: bool = True) -> None:
+        r"""
+        Save checkpoint to `self.checkpoint_dir`.
+
+        Args:
+            name: Name of the checkpoint. Defaults to `"latest"`.
+            epoch: Epoch to save. Defaults to `self.epochs`.
+            save_best: If `True`, when `self.is_best` is `True`, the checkpoint will also be copied to
+                `self.checkpoint_dir/best`.
+
+        If `self.config.save_interval` is positive and `epochs + 1` is a multiple of `save_interval`,
+        the checkpoint will also be copied to `self.checkpoint_dir/epoch-{epochs}`.
+        """
+
+        epoch = epoch or self.epochs
+        save_interval = self.config.get("save_interval", -1)
+        latest_path = os.path.join(self.checkpoint_dir, name)
+        os.makedirs(latest_path, exist_ok=True)
+        self.yaml(os.path.join(latest_path, "runner.yaml"))
+        self.model.save_checkpoint(
+            self.checkpoint_dir, tag=name, client_state={"runner": self.config.dict()}, save_latest=False
+        )
+        if save_interval > 0 and (epoch + 1) % save_interval == 0:
+            save_path = os.path.join(self.checkpoint_dir, f"epoch-{epoch}")
+            shutil.copytree(latest_path, save_path, dirs_exist_ok=True)
+        if save_best and self.is_best:
+            best_path = os.path.join(self.checkpoint_dir, "best")
+            shutil.copytree(latest_path, best_path, dirs_exist_ok=True)
+
+    def load_checkpoint(self, checkpoint: bytes | str | os.PathLike, *args, **kwargs) -> None:  # type: ignore[override]
+        """
+        Load model, optimizer, and scheduler from checkpoint.
+
+        Args:
+            checkpoint: Checkpoint (or its path) to load.
+            *args: Additional arguments to pass to `self.load`.
+            **kwargs: Additional keyword arguments to pass to `self.load`.
+
+        Raises:
+            ValueError: If `model` is not defined.
+            ValueError: If `model` is not an instance of `deepspeed.DeepSpeedEngine`.
+
+        See Also:
+            [`from_checkpoint`][danling.BaseRunner.from_checkpoint]: Build runner from checkpoint.
+            [`load_pretrained`][danling.BaseRunner.load_pretrained]: Load model parameters from pretrained checkpoint.
+        """
+
+        if self.model is None:
+            raise ValueError("model is not defined")
+        if not isinstance(self.model, deepspeed.DeepSpeedEngine):
+            raise ValueError("model is not an instance of `deepspeed.DeepSpeedEngine`")
+
+        self.model.load_checkpoint(checkpoint)
+        self.config.checkpoint = checkpoint
+
+    def load_pretrained(self, checkpoint: bytes | str | os.PathLike, *args, **kwargs) -> None:  # type: ignore[override]
+        """
+        Load model from pretrained checkpoint.
+
+        This method only loads the model weights.
+
+        Args:
+            checkpoint: Pretrained checkpoint directory.
+            *args: Additional arguments to pass to `self.load`.
+            **kwargs: Additional keyword arguments to pass to `self.load`.
+
+        Raises:
+            ValueError: If `model` is not defined.
+
+        See Also:
+            [`load_checkpoint`][danling.BaseRunner.load_checkpoint]: Load model, optimizer, and scheduler from
+                checkpoint.
+        """
+
+        if self.model is None:
+            raise ValueError("model is not defined")
+
+        self.model.load_checkpoint(checkpoint, load_module_only=True)
+        self.config.pretrained = checkpoint
+
+    def load_config(
+        self, checkpoint: bytes | str | os.PathLike, overwrite: bool = False, *args, **kwargs  # type: ignore[override]
+    ) -> None:
+        r"""
+        Load config from checkpoint.
+
+        Args:
+            checkpoint: Checkpoint (or its path) to load.
+            overwrite: If `True`, overwrite the current config with the loaded config.
+                Defaults to `False`.
+            *args: Additional arguments to pass to `self.load`.
+            **kwargs: Additional keyword arguments to pass to `self.load`.
+
+        Raises:
+            FileNotFoundError: If `checkpoint` does not exists.
+        """
+
+        if isinstance(checkpoint, bytes):
+            checkpoint = checkpoint.decode()
+
+        config = self.load(os.path.join(checkpoint, "runner.yaml"), *args, **kwargs)
+        self.config.merge(config, overwrite=overwrite)
+        self.step_begin = config["steps"] + 1
+        self.epoch_begin = config["epochs"] + 1

‎danling/runner/runner.py

+79
@@ -0,0 +1,79 @@
+# DanLing
+# Copyright (C) 2022-Present DanLing
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the following licenses:
+# - The Unlicense
+# - GNU Affero General Public License v3.0 or later
+# - GNU General Public License v2.0 or later
+# - BSD 4-Clause "Original" or "Old" License
+# - MIT License
+# - Apache License 2.0
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the LICENSE file for more details.
+
+from __future__ import annotations
+
+from lazy_imports import try_import
+
+from .base_runner import BaseRunner
+from .config import Config
+from .torch_runner import TorchRunner
+
+with try_import() as ac:
+    import accelerate  # noqa: F401
+
+    from .accelerate_runner import AccelerateRunner
+
+with try_import() as ds:
+    import deepspeed  # noqa: F401
+
+    from .deepspeed_runner import DeepSpeedRunner
+
+
+class Runner(BaseRunner):
+    r"""
+    Dynamic runner class that selects the appropriate platform based on configuration.
+
+    This runner dynamically changes its class to combine with the appropriate platform
+    (torch, accelerate, or deepspeed) based on the 'platform' configuration option.
+
+    Valid platform options are:
+
+    - "auto" (default)
+    - "torch"
+    - "accelerate"
+    - "deepspeed"
+
+    Examples:
+        >>> config = Config({"platform": "accelerate"})
+        >>> runner = Runner(config)
+
+    See Also:
+        - [`BaseRunner`][danling.runner.BaseRunner]: Base class for all runners.
+        - [`TorchRunner`][danling.runner.TorchRunner]: Runner for PyTorch.
+        - [`AccelerateRunner`][danling.runner.AccelerateRunner]: Runner for Accelerate.
+        - [`DeepSpeedRunner`][danling.runner.DeepSpeedRunner]: Runner for DeepSpeed.
+    """
+
+    def __init__(self, config: Config) -> None:
+        platform = config.get("platform", "auto").lower()
+
+        if platform == "auto":
+            platform = "deepspeed" if ds.is_successful() else "torch"
+
+        if platform == "accelerate":
+            ac.check()
+            self.__class__ = type("AccelerateRunner", (self.__class__, AccelerateRunner), {})
+        elif platform == "deepspeed":
+            ds.check()
+            self.__class__ = type("DeepSpeedRunner", (self.__class__, DeepSpeedRunner), {})
+        elif platform == "torch":
+            self.__class__ = type("TorchRunner", (self.__class__, TorchRunner), {})
+        else:
+            raise ValueError(f"Unknown platform: {platform}. Valid options are: torch, accelerate, deepspeed")
+
+        super().__init__(config)

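Because `Runner.__init__` rebuilds the class from `(self.__class__, <PlatformRunner>)`, a user-defined subclass keeps its own methods while gaining the selected platform. A sketch under that assumption (`MyRunner` and its `train_step` are illustrative):

    import danling as dl

    class MyRunner(dl.Runner):
        def train_step(self, data):
            ...  # project-specific training logic

    runner = MyRunner(dl.Config({"platform": "torch"}))
    assert isinstance(runner, MyRunner) and isinstance(runner, dl.TorchRunner)
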
‎danling/runner/torch_runner.py

+662-2
Large diffs are not rendered by default.

‎danling/runner/utils.py

+47-4
@@ -21,13 +21,18 @@
 
 import os
 import sys
+from collections.abc import Mapping
 from contextlib import suppress
 from datetime import datetime
 from enum import auto
 from functools import wraps
+from math import isnan
 from typing import Any
 from warnings import warn
 
+import torch
+from chanfig import FlatDict, NestedDict
+
 from danling.utils import base62
 
 try:
@@ -49,13 +54,13 @@ class RunnerMode(StrEnum):  # pylint: disable=too-few-public-methods
 
     Attributes:
         train: Training mode.
-        eval: Evaluation mode.
-        inf: Inference mode.
+        evaluate: Evaluation mode.
+        infer: Inference mode.
     """
 
     train = auto()
-    eval = auto()
-    inf = auto()
+    evaluate = auto()
+    infer = auto()
 
 
 def get_time_str() -> str:
@@ -123,3 +128,41 @@ def wrapper(self, *args, **kwargs) -> Any | None:
         return None
 
     return wrapper
+
+
+def format_result(result, format_spec: str = ".4f", depth: int = 0) -> str:
+    longest_key = max(len(k) for k in result.keys())
+    repr_list = [_format_result(result, format_spec)]
+    for k, v in result.items():
+        if isinstance(v, Mapping):
+            initials = " " * (longest_key - len(k)) + "\t" * depth
+            repr_list.append(f"{initials}{k}: {format_result(v, format_spec, depth + 1)}")
+    return "\n".join(repr_list)
+
+
+def _format_result(result, format_spec: str = ".4f") -> str:
+    repr_str = ""
+    for k, v in result.items():
+        if isinstance(v, (Mapping,)):
+            continue
+        padding = 1
+        if isinstance(v, (float,)):
+            is_negative = v < 0 if not isnan(v) else False
+            v = format(v, format_spec) if not isnan(v) else " NaN "
+            padding = padding if is_negative else padding + 1
+        repr_str += f"\t{k}:{' ' * padding}{v}"
+    return repr_str
+
+
+def to_device(data: Any, device: torch.device):
+    if isinstance(data, list):
+        return [to_device(i, device) for i in data]
+    if isinstance(data, tuple):
+        return tuple(to_device(i, device) for i in data)
+    if isinstance(data, NestedDict):
+        return NestedDict({k: to_device(v, device) for k, v in data.all_items()})
+    if isinstance(data, dict):
+        return FlatDict({k: to_device(v, device) for k, v in data.items()})
+    if hasattr(data, "to"):
+        return data.to(device)
+    return data

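The helpers added to `danling/runner/utils.py` above are small and self-contained; a usage sketch (the sample values are made up):

    import torch
    from danling.runner.utils import format_result, to_device

    result = {"loss": 0.1234, "val": {"loss": 0.2345, "acc": 0.9}}
    print(format_result(result))  # floats rendered with ".4f"; nested mappings get their own lines

    batch = {"input_ids": torch.zeros(2, 8, dtype=torch.long), "labels": torch.tensor([0, 1])}
    batch = to_device(batch, torch.device("cpu"))  # recursively moves tensors; plain dicts come back as chanfig FlatDict
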
‎danling/tensors/nested_tensor.py

+3
@@ -1146,6 +1146,9 @@ def reshape(self, *shape) -> Tensor:
 
         return self.tensor.reshape(*shape)
 
+    def __iter__(self):
+        return iter(self._storage)
+
 
 NestedTensorFunc = TorchFuncRegistry()
 

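The new `__iter__` makes a `NestedTensor` iterable over its underlying storage; a small sketch:

    import torch
    from danling import NestedTensor

    nt = NestedTensor([torch.tensor([1, 2, 3]), torch.tensor([4, 5])])
    for t in nt:  # yields the stored tensors one by one, without padding
        print(t.shape)
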
‎danling/utils/__init__.py

-2
@@ -22,7 +22,6 @@
 except ImportError:
     from cached_property import cached_property  # type: ignore
 
-from . import defaults
 from .basex import Base58, Base62, Base64, BaseX, base58, base62, base64
 from .contextmanagers import debug
 from .decorators import catch, flexible_decorator, method_cache
@@ -55,5 +54,4 @@
     "base58",
     "base62",
     "base64",
-    "defaults",
 ]

‎demo/accelerate_imdb.py

+121
@@ -0,0 +1,121 @@
+# DanLing
+# Copyright (C) 2022-Present DanLing
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the following licenses:
+# - The Unlicense
+# - GNU Affero General Public License v3.0 or later
+# - GNU General Public License v2.0 or later
+# - BSD 4-Clause "Original" or "Old" License
+# - MIT License
+# - Apache License 2.0
+
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+# See the LICENSE file for more details.
+
+import torch
+from chanfig import Registry
+from datasets import load_dataset
+from torch import nn, optim
+from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
+
+import danling as dl
+
+OPTIMIZERS = Registry()
+OPTIMIZERS.register(optim.AdamW, "adamw")
+OPTIMIZERS.register(optim.SGD, "sgd")
+
+
+class IMDBConfig(dl.Config):
+    epoch_end: int = 2
+    log: bool = False
+    tensorboard: bool = False
+    log_interval: int = 1000
+    score_split: str = "val"
+    score_name: str = "loss"
+    debug: bool = False
+    patience: int = 1
+
+    def __init__(self):
+        super().__init__()
+        self.pretrained = "prajjwal1/bert-tiny"
+        self.dataset.path = "stanfordnlp/imdb"
+        self.dataloader.batch_size = 8
+        self.optim.name = "adamw"
+        self.optim.lr = 1e-3
+        self.optim.weight_decay = 1e-4
+        self.sched.strategy = "cosine"
+
+    def post(self):
+        super().post()
+        self.transformers = AutoConfig.from_pretrained(self.pretrained)
+        self.experiment_name = f"{self.pretrained}_{self.optim.name}@{self.optim.lr}"
+
+
+class IMDBRunner(dl.AccelerateRunner):
+    def __init__(self, config: dl.Config):
+        super().__init__(config)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained)
+        self.datasets.train = load_dataset(split="train", **self.dataset)
+        self.datasets.val = load_dataset(split="train", **self.dataset)
+        # only run on a few samples to speed up testing process
+        self.datasets.train._data = self.datasets.train._data[:64]
+        self.datasets.val._data = self.datasets.val._data[:64]
+        self.datasets.train = self.preprocess_data(self.datasets.train)
+        self.datasets.val = self.preprocess_data(self.datasets.val)
+
+        self.model = AutoModelForSequenceClassification.from_config(self.config.transformers)
+        self.optimizer = OPTIMIZERS.build(params=self.model.parameters(), **self.optim)
+        self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.sched)
+        self.criterion = nn.CrossEntropyLoss()
+
+        self.metrics = dl.metrics.binary_metrics()
+        self.meters.loss.reset()
+        self.meters.time.reset()
+
+    def preprocess_data(self, dataset):
+        def tokenization(example):
+            example["text"] = self.tokenizer(example["text"], truncation=True, max_length=510)["input_ids"]
+            return example
+
+        def transform(data):
+            text = dl.NestedTensor(data.pop("text"))
+            data["input_ids"] = text.tensor
+            data["attention_mask"] = text.mask
+            data["labels"] = torch.tensor(data.pop("label"))
+            return data
+
+        dataset = dataset.map(tokenization, batched=True)
+        dataset.set_transform(transform)
+        dataset.__getitems__ = dataset.__getitem__
+        return dataset
+
+    def train_step(self, data) -> torch.Tensor:
+        with self.autocast(), self.accumulate():
+            pred = self.model(**data)
+            loss = pred["loss"]
+            self.advance(loss)
+            self.metrics.update(pred["logits"][:, 0], data["labels"])
+        return pred, loss
+
+    def evaluate_step(self, data) -> torch.Tensor:
+        pred = self.model(**data)
+        loss = pred["loss"]
+        self.metrics.update(pred["logits"][:, 0], data["labels"])
+        return pred, loss
+
+    @staticmethod
+    def collate_fn(batch):
+        return batch
+
+
+if __name__ == "__main__":
+    config = IMDBConfig()
+    config.parse()
+    with dl.debug(config.get("debug", False)):
+        runner = IMDBRunner(config)
+        runner.train()
+        runner.evaluate(["val"])

demo/vision/torch_mnist.py → demo/torch_mnist.py  (+8 −8)

@@ -18,7 +18,7 @@
 # See the LICENSE file for more details.
 
 import torchvision
-from chanfig import Config, Registry
+from chanfig import Registry
 from torch import nn, optim
 
 import danling as dl
@@ -28,11 +28,10 @@
 OPTIMIZERS.register(optim.SGD, "sgd")
 
 
-class MNISTConfig(Config):
+class MNISTConfig(dl.Config):
     epoch_end: int = 2
     log: bool = False
     tensorboard: bool = False
-    log_interval: int = 1000
     score_split: str = "val"
     score_name: str = "loss"
     debug: bool = False
@@ -43,18 +42,19 @@ def __init__(self):
         self.network.name = "resnet18"
         self.dataset.download = True
         self.dataset.root = "data"
-        self.dataloader.batch_size = 8
+        self.dataloader.batch_size = 256
         self.optim.name = "adamw"
         self.optim.lr = 1e-3
         self.optim.weight_decay = 1e-4
         self.sched.strategy = "cosine"
 
     def post(self):
+        super().post()
         self.experiment_name = f"{self.network.name}_{self.optim.name}@{self.optim.lr}"
 
 
 class MNISTRunner(dl.TorchRunner):
-    def __init__(self, config: Config):
+    def __init__(self, config: dl.Config):
         super().__init__(config)
 
         self.dataset.transform = torchvision.transforms.Compose(
@@ -66,13 +66,13 @@ def __init__(self, config: Config):
         self.datasets.train = torchvision.datasets.MNIST(train=True, **self.dataset)
         self.datasets.val = torchvision.datasets.MNIST(train=False, **self.dataset)
         # only run on a few samples to speed up testing process
-        self.datasets.train.data = self.datasets.train.data[:64]
-        self.datasets.val.data = self.datasets.val.data[:64]
+        self.datasets.train.data = self.datasets.train.data[:100]
+        self.datasets.val.data = self.datasets.val.data[:100]
 
         self.model = getattr(torchvision.models, self.network.name)(pretrained=False, num_classes=10)
         self.model.conv1 = nn.Conv2d(1, 64, 1, bias=False)
         self.optimizer = OPTIMIZERS.build(params=self.model.parameters(), **self.optim)
-        self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.trainable_steps, **self.sched)
+        self.scheduler = dl.optim.LRScheduler(self.optimizer, total_steps=self.total_steps, **self.sched)
         self.criterion = nn.CrossEntropyLoss()
 
         self.metrics = dl.metrics.multiclass_metrics(num_classes=10)

docs/docs/runner/state.md → docs/docs/runner/config.md  (+2 −2)

@@ -4,6 +4,6 @@ authors:
 date: 2022-05-04
 ---
 
-# RunnerState
+# Config
 
-::: danling.runner.state
+::: danling.runner.config

‎docs/docs/runner/deepspeed_runner.md

+9
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2022-05-04
+---
+
+# DeepSpeedRunner
+
+::: danling.runner.DeepSpeedRunner

‎docs/docs/runner/runner.md

+9
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2022-05-04
+---
+
+# Runner
+
+::: danling.runner.Runner

‎docs/docs/runner/torch_runner.md

+9
@@ -0,0 +1,9 @@
+---
+authors:
+  - Zhiyuan Chen
+date: 2022-05-04
+---
+
+# TorchRunner
+
+::: danling.runner.TorchRunner

‎docs/mkdocs.yml

+6-2
@@ -11,9 +11,12 @@ nav:
   - DanLing: index.md
   - Runner:
       - runner/index.md
-      - RunnerState: runner/runner_state.md
-      - BaseRunner: runner/base_runner.md
+      - Config: runner/config.md
+      - Runner: runner/runner.md
+      - TorchRunner: runner/torch_runner.md
+      - DeepSpeedRunner: runner/deepspeed_runner.md
       - AccelerateRunner: runner/accelerate_runner.md
+      - BaseRunner: runner/base_runner.md
       - Utilities: runner/utils.md
   - Tensors:
       - NestedTensor: tensors/nested_tensor.md
@@ -196,6 +199,7 @@ plugins:
           - https://pytorch.org/docs/stable/objects.inv
          - https://pytorch.org/torcheval/stable/objects.inv
          - https://huggingface.co/docs/transformers/master/en/objects.inv
+          - https://huggingface.co/docs/accelerate/master/en/objects.inv
          - https://chanfig.danling.org/objects.inv
          - https://lightning.ai/docs/torchmetrics/stable/objects.inv
       rendering:

‎pyproject.toml

+12-1
@@ -47,6 +47,18 @@ dependencies = [
   "strenum; python_version<'3.11'",
   "tqdm",
 ]
+optional-dependencies.accelerate = [
+  "accelerate",
+  "torch",
+  "torcheval",
+  "torchmetrics",
+]
+optional-dependencies.deepspeed = [
+  "deepspeed",
+  "torch",
+  "torcheval",
+  "torchmetrics",
+]
 optional-dependencies.jax = [
   "flax",
   "jax",
@@ -55,7 +67,6 @@ optional-dependencies.tensorflow = [
   "tensorflow",
 ]
 optional-dependencies.torch = [
-  "accelerate",
   "torch",
   "torcheval",
   "torchmetrics",

‎requirements.txt

+12
@@ -1,7 +1,19 @@
+# This file is for testing purposes only.
+# Please refer to pyproject.toml for the actual dependencies.
+
+accelerate
 cached-property; python_version < "3.8"
 chanfig >= 0.0.96
+datasets
 gitpython
 lazy-imports
+portalocker>=2.0.0
 strenum; python_version < "3.11"
+torch
+# torchdata
 torcheval
+torchmetrics
+# torchtext
+torchvision
 tqdm
+transformers

‎tests/optim/test_lr_scheduler.py

+2-2
@@ -34,11 +34,11 @@
 class Test:
     optimizer = optim.SGD([{"params": torch.tensor([0])}], lr=1, momentum=0.9)
 
-    def _get_lrs(self, strategy, method, steps: int = 100, final_lr_ratio: float = 0.001):
+    def _get_lrs(self, strategy, method, total_steps: int = 100, final_lr_ratio: float = 0.001):
         lrs = []
         scheduler = LRScheduler(
             self.optimizer,
-            total_steps=steps,
+            total_steps=total_steps,
             final_lr_ratio=final_lr_ratio,
             strategy=strategy,
             method=method,

‎tests/runner/test_base_runner.py

+3-4
@@ -17,7 +17,6 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 # See the LICENSE file for more details.
 
-from chanfig import Config as Config_
 from chanfig import NestedDict
 
 import danling as dl
@@ -30,7 +29,7 @@ def init_distributed(self) -> None:
         pass
 
 
-class Config(Config_):
+class Config(dl.Config):
     __test__ = False
 
     def __init__(self):
@@ -107,7 +106,7 @@ def test_results(self):
 
     def test_conflict(self):
         runner = self.runner
-        state = runner.state
+        config = runner.config
         runner.conflict = False
         assert not runner.conflict
-        assert state.conflict == 1
+        assert config.conflict == 1

danling/utils/defaults.py → tests/runner/test_imdb.py  (+16 −1)

@@ -17,4 +17,19 @@
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 # See the LICENSE file for more details.
 
-DEFAULT_EXCLUDE = (KeyboardInterrupt, SystemExit)
+import sys
+
+sys.path.insert(0, "demo")
+
+from accelerate_imdb import IMDBConfig, IMDBRunner  # noqa: E402
+
+
+class Test:
+    config = IMDBConfig().boot()
+    runner = IMDBRunner(config)
+
+    def test_train(self):
+        self.runner.train()
+
+    def test_evaluate(self):
+        self.runner.evaluate(["val"])

‎tests/runner/test_mnist.py

+1-1
@@ -19,7 +19,7 @@
 
 import sys
 
-sys.path.insert(0, "demo/vision")
+sys.path.insert(0, "demo")
 
 from torch_mnist import MNISTConfig, MNISTRunner  # noqa: E402
 
