Skip to content

Commit aefb9ab

Browse files
authored
(app) Introduce LightningTrainingComponent (#13830)
1 parent cd92e35 commit aefb9ab

File tree

21 files changed

+485
-8
lines changed

21 files changed

+485
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,4 @@ src/lightning_app/ui/*
163163
*examples/template_react_ui*
164164
hars*
165165
artifacts/*
166+
*docs/examples*

docs/source-app/api_reference/components.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ ___________________
2020

2121
~python.popen.PopenPythonScript
2222
~python.tracer.TracerPythonScript
23+
~training.LightningTrainingComponent
2324
~serve.gradio.ServeGradio
2425
~serve.serve.ModelInferenceAPI

examples/app_multi_node/app.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Example Lightning App: run ``train.py`` as a two-node training job."""
from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute

# The component creates one script runner per node (here: two) and hands every
# runner the same script and the same compute configuration.
app = LightningApp(
    LightningTrainingComponent(script_path="train.py", num_nodes=2, cloud_compute=CloudCompute("gpu-fast-multi"))
)
File renamed without changes.
File renamed without changes.

examples/app_multi_node/train.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""Minimal training script: fit the demo ``BoringModel`` for a single epoch."""
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

if __name__ == "__main__":
    # Keep `trainer` bound to a module-level name: a script runner that inspects
    # the script's globals after execution needs to find the Trainer instance.
    trainer = Trainer(max_epochs=1)
    model = BoringModel()
    trainer.fit(model=model)

src/lightning_app/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))
1212

13+
- Add `LightningTrainingComponent`, which orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))
14+
1315
### Changed
1416

1517
- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))

src/lightning_app/components/python/tracer.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,24 @@
22
import os
33
import signal
44
import sys
5-
from typing import Any, Dict, List, Optional, Union
5+
from copy import deepcopy
6+
from typing import Any, Dict, List, Optional, TypedDict, Union
67

78
from lightning_app import LightningWork
9+
from lightning_app.storage.drive import Drive
810
from lightning_app.storage.payload import Payload
911
from lightning_app.utilities.app_helpers import _collect_child_process_pids
12+
from lightning_app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
1013
from lightning_app.utilities.tracer import Tracer
1114

1215
logger = logging.getLogger(__name__)
1316

1417

18+
# Describes a packaged copy of the code to run:
#   drive: the Drive that stores the archive
#   name:  the file name of the archive (fetched from the drive and extracted
#          as a gzip-compressed tarball before the script is executed)
Code = TypedDict("Code", {"drive": Drive, "name": str})
21+
22+
1523
class TracerPythonScript(LightningWork):
1624
def on_before_run(self):
1725
"""Called before the python script is executed."""
@@ -31,6 +39,7 @@ def __init__(
3139
script_args: Optional[Union[list, str]] = None,
3240
outputs: Optional[List[str]] = None,
3341
env: Optional[Dict] = None,
42+
code: Optional[Code] = None,
3443
**kwargs,
3544
):
3645
"""The TracerPythonScript class enables to easily run a python script.
@@ -97,17 +106,46 @@ def __init__(
97106
if isinstance(script_args, str):
98107
script_args = script_args.split(" ")
99108
self.script_args = script_args if script_args else []
109+
self.original_args = deepcopy(self.script_args)
100110
self.env = env
101111
self.outputs = outputs or []
102112
for name in self.outputs:
103113
setattr(self, name, None)
114+
self.params = None
115+
self.drive = code.get("drive") if code else None
116+
self.code_name = code.get("name") if code else None
117+
self.restart_count = 0
118+
119+
def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs):
120+
"""
121+
Arguments:
122+
params: A dictionary of arguments to be added to script_args.
123+
restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
124+
"""
125+
if restart_count:
126+
self.restart_count = restart_count
127+
128+
if params:
129+
self.params = params
130+
self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]
131+
132+
if self.drive:
133+
assert self.code_name
134+
if os.path.exists(self.code_name):
135+
clean_tarfile(self.code_name, "r:gz")
136+
137+
if self.code_name in self.drive.list():
138+
self.drive.get(self.code_name)
139+
extract_tarfile(self.code_name, ".", "r:gz")
104140

105-
def run(self, **kwargs):
106141
if not os.path.exists(self.script_path):
107142
raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.")
143+
108144
kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}
145+
109146
init_globals = globals()
110147
init_globals.update(kwargs)
148+
111149
self.on_before_run()
112150
env_copy = os.environ.copy()
113151
if self.env:
@@ -125,5 +163,11 @@ def on_exit(self):
125163
for child_pid in _collect_child_process_pids(os.getpid()):
126164
os.kill(child_pid, signal.SIGTERM)
127165

166+
@staticmethod
167+
def _to_script_args(k: str, v: str) -> str:
    """Format one ``params`` entry as a ``--key=value`` command-line token.

    A key that already starts with ``--`` is used unchanged; any other key is
    given the ``--`` prefix first.
    """
    flag = k if k.startswith("--") else f"--{k}"
    return f"{flag}={v}"
171+
128172

129173
__all__ = ["TracerPythonScript"]
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import logging
2+
import os
3+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
4+
5+
from lightning import CloudCompute
6+
from lightning_app import LightningFlow, structures
7+
from lightning_app.components.python import TracerPythonScript
8+
from lightning_app.storage.path import Path
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
class PyTorchLightningScriptRunner(TracerPythonScript):
    """Run a PyTorch Lightning script as one node of a multi-node training job.

    The first ``run`` call (made without ``internal_urls``) only logs that the
    node started. Once the addresses of all nodes are passed in, the work
    exports the distributed environment variables and executes the script.
    After the script finishes, checkpoint information is copied from the
    script's ``Trainer`` onto the work's state.
    """

    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        node_rank: int = 1,
        num_nodes: int = 1,
        sanity_serving: bool = False,
        cloud_compute: Optional[CloudCompute] = None,
        parallel: bool = True,
        raise_exception: bool = True,
        env: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            node_rank: Rank of this node, exported as ``NODE_RANK``.
                NOTE(review): rank ``0`` is the one treated specially by this class, so the
                default of ``1`` looks unusual - confirm it is intended.
            num_nodes: Total number of nodes, exported as ``PL_TRAINER_NUM_NODES``.
            sanity_serving: Whether to add a ``ServableModuleValidator`` callback to the
                ``Trainer`` created by the script (done on the node of rank 0 only).
            cloud_compute: The cloud compute object forwarded to the parent work.
            parallel: Forwarded to the parent work.
            raise_exception: Forwarded to the parent work.
            env: Extra environment variables applied to ``os.environ`` before the script runs.
            kwargs: Any additional arguments for the parent ``TracerPythonScript``.
        """
        super().__init__(
            script_path,
            script_args,
            raise_exception=raise_exception,
            parallel=parallel,
            cloud_compute=cloud_compute,
            **kwargs,
        )
        self.node_rank = node_rank
        self.num_nodes = num_nodes
        # Filled in by ``on_after_run`` once the script has completed.
        self.best_model_path = None
        self.best_model_score = None
        self.monitor = None
        self.sanity_serving = sanity_serving
        self.has_finished = False
        self.env = env

    def configure_tracer(self):
        """Extend the parent tracer so that ``Trainer.__init__`` calls go through
        ``_trainer_init_pre_middleware`` first."""
        from pytorch_lightning import Trainer

        tracer = super().configure_tracer()
        tracer.add_traced(Trainer, "__init__", pre_fn=self._trainer_init_pre_middleware)
        return tracer

    def run(self, internal_urls: Optional[List[Tuple[str, str]]] = None, **kwargs) -> None:
        """Execute the script once the addresses of all nodes are known.

        Arguments:
            internal_urls: One ``(ip, port)`` pair per node. The first pair is used as the
                master address and port. When missing or empty, nothing is executed.
            kwargs: Forwarded to ``TracerPythonScript.run``.
        """
        if not internal_urls:
            # Note: This is called only once.
            _logger.info("The node %s started !", self.node_rank)
            return None

        if self.env:
            os.environ.update(self.env)

        # Applied after ``self.env`` so that the distributed settings take precedence.
        distributed_env_vars = {
            "MASTER_ADDR": internal_urls[0][0],
            "MASTER_PORT": str(internal_urls[0][1]),
            "NODE_RANK": str(self.node_rank),
            "PL_TRAINER_NUM_NODES": str(self.num_nodes),
            "PL_TRAINER_DEVICES": "auto",
            "PL_TRAINER_ACCELERATOR": "auto",
        }

        os.environ.update(distributed_env_vars)
        return super().run(**kwargs)

    def on_after_run(self, script_globals):
        """Record checkpoint information from the script's trainer and mark the work as finished.

        Arguments:
            script_globals: The global namespace of the executed script. It is searched for a
                ``LightningCLI`` or ``Trainer`` instance; the first one found is used.

        Raises:
            RuntimeError: If the script's globals contain neither a ``LightningCLI`` nor a ``Trainer``.
        """
        from pytorch_lightning import Trainer
        from pytorch_lightning.cli import LightningCLI

        for v in script_globals.values():
            if isinstance(v, LightningCLI):
                trainer = v.trainer
                break
            elif isinstance(v, Trainer):
                trainer = v
                break
        else:
            raise RuntimeError("No trainer instance found.")

        checkpoint_callback = trainer.checkpoint_callback
        # The trainer has no checkpoint callback when checkpointing is disabled. In that
        # case there is nothing to record, but the work must still be marked as finished.
        if checkpoint_callback is not None:
            self.monitor = checkpoint_callback.monitor

            # Compare against None: a score of exactly 0 is a valid best score.
            if checkpoint_callback.best_model_score is not None:
                self.best_model_path = Path(checkpoint_callback.best_model_path)
                self.best_model_score = float(checkpoint_callback.best_model_score)
            else:
                self.best_model_path = Path(checkpoint_callback.last_model_path)

        self.has_finished = True

    def _trainer_init_pre_middleware(self, trainer, *args, **kwargs):
        """Hook executed before ``Trainer.__init__``.

        On the node of rank 0, and only when ``sanity_serving`` is enabled, a
        ``ServableModuleValidator`` is appended to the trainer's callbacks. In every other
        case the arguments are returned untouched.

        Returns:
            A tuple ``({}, args, kwargs)`` with the possibly updated keyword arguments.
        """
        if self.node_rank != 0 or not self.sanity_serving:
            return {}, args, kwargs

        # Imported lazily so that the dependency is only required when it is used.
        from pytorch_lightning.serve import ServableModuleValidator

        # ``callbacks`` may be absent, ``None``, a single callback or a sequence of callbacks.
        callbacks = kwargs.get("callbacks")
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, (list, tuple)):
            callbacks = list(callbacks)
        else:
            callbacks = [callbacks]

        kwargs["callbacks"] = callbacks + [ServableModuleValidator()]
        return {}, args, kwargs

    @property
    def is_running_in_cloud(self) -> bool:
        """Whether the ``LIGHTNING_APP_STATE_URL`` environment variable is set."""
        return "LIGHTNING_APP_STATE_URL" in os.environ
111+
112+
113+
class LightningTrainingComponent(LightningFlow):
    def __init__(
        self,
        script_path: str,
        script_args: Optional[Union[list, str]] = None,
        num_nodes: int = 1,
        cloud_compute: CloudCompute = CloudCompute("default"),
        sanity_serving: bool = False,
        script_runner: Type[TracerPythonScript] = PyTorchLightningScriptRunner,
        **script_runner_kwargs,
    ):
        """This component enables performing distributed multi-node multi-device training.

        Example::

            from lightning import LightningApp
            from lightning.app.components.training import LightningTrainingComponent
            from lightning.app.utilities.packaging.cloud_compute import CloudCompute

            app = LightningApp(
                LightningTrainingComponent(
                    "train.py",
                    num_nodes=2,
                    cloud_compute=CloudCompute("gpu"),
                ),
            )

        Arguments:
            script_path: Path to the script to be executed.
            script_args: The arguments to be passed to the script.
            num_nodes: Number of nodes.
            cloud_compute: The cloud compute object used in the cloud.
            sanity_serving: Whether to validate that the model correctly implements
                the ServableModule API
            script_runner: The work class instantiated once per node to execute the script.
            script_runner_kwargs: Additional keyword arguments passed to every ``script_runner``.
        """
        super().__init__()
        self.ws = structures.List()
        self.has_initialized = False
        self.script_path = script_path
        self.script_args = script_args
        self.num_nodes = num_nodes
        self._cloud_compute = cloud_compute  # TODO: Add support for cloudCompute
        self.sanity_serving = sanity_serving
        self._script_runner = script_runner
        self._script_runner_kwargs = script_runner_kwargs

    def run(self, **run_kwargs):
        """Create one script runner per node on the first call, then drive them.

        Until every runner exposes an internal IP, the runners are called without
        arguments. Afterwards each runner receives the ``(ip, port)`` pairs of all
        runners, and all runners are stopped once every one of them has finished.

        Arguments:
            run_kwargs: Forwarded to each runner's ``run`` together with ``internal_urls``.
        """
        if not self.has_initialized:
            for node_rank in range(self.num_nodes):
                self.ws.append(
                    self._script_runner(
                        script_path=self.script_path,
                        script_args=self.script_args,
                        cloud_compute=self._cloud_compute,
                        node_rank=node_rank,
                        sanity_serving=self.sanity_serving,
                        num_nodes=self.num_nodes,
                        **self._script_runner_kwargs,
                    )
                )

            self.has_initialized = True

        for work in self.ws:
            if all(w.internal_ip for w in self.ws):
                internal_urls = [(w.internal_ip, w.port) for w in self.ws]
                work.run(internal_urls=internal_urls, **run_kwargs)
                if all(w.has_finished for w in self.ws):
                    for w in self.ws:
                        w.stop()
            else:
                work.run()

    @property
    def best_model_score(self) -> Optional[float]:
        """The best model score reported by the first runner, or ``None`` when no runner
        has been created yet."""
        if len(self.ws) == 0:
            return None
        return self.ws[0].best_model_score

    @property
    def best_model_paths(self) -> List[Optional[Path]]:
        """The ``best_model_path`` of every runner, in the order the runners were created."""
        return [work.best_model_path for work in self.ws]

0 commit comments

Comments
 (0)