
[App] Accelerate Multi Node Startup Time #15650

Merged · 21 commits · Nov 11, 2022
4 changes: 2 additions & 2 deletions examples/app_multi_node/README.md
@@ -28,9 +28,9 @@ lightning run app train_lite.py

Using Lite, you retain control over your loops while accessing in a minimal way all Lightning distributed strategies.

-## Multi Node with PyTorch Lightning
+## Multi Node with Lightning Trainer

-Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
+Lightning supports running Lightning Trainer from a script or within a Lightning Work.

You can either run a script directly
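For orientation, here is a minimal, hypothetical sketch of the pattern the renamed section documents: running a Trainer-style work across nodes with the MultiNode component changed later in this PR. The import path, work class, and compute name are assumptions, not code from the diff:

```python
from lightning_app import CloudCompute, LightningApp, LightningWork
from lightning_app.components import MultiNode  # path assumed from this PR's file layout


class TrainerWork(LightningWork):
    """Hypothetical work; a real one would build a Trainer and call fit()."""

    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        # These arguments match the signature MultiNode.run() dispatches below.
        print(f"node {node_rank}/{num_nodes} -> {main_address}:{main_port}")


# After this PR, MultiNode creates one work per node eagerly and only
# dispatches run() once every machine reports an internal IP.
app = LightningApp(MultiNode(TrainerWork, num_nodes=2, cloud_compute=CloudCompute("gpu")))
```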
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -74,6 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))


+- Fixed CloudRuntime works collection with structures and accelerated multi node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
+

## [1.8.0] - 2022-11-01

33 changes: 21 additions & 12 deletions src/lightning_app/components/database/server.py
@@ -4,6 +4,7 @@
import sys
import tempfile
import threading
+import traceback
from typing import List, Optional, Type, Union

import uvicorn
@@ -36,6 +37,9 @@ def install_signal_handlers(self):
        """Ignore Uvicorn Signal Handlers."""


+_lock = threading.Lock()
+
+
class Database(LightningWork):
    def __init__(
        self,
@@ -146,25 +150,29 @@ class CounterModel(SQLModel, table=True):
        self._exit_event = None

    def store_database(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
+        try:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))

-            source = sqlite3.connect(self.db_filename)
-            dest = sqlite3.connect(tmp_db_filename)
+                source = sqlite3.connect(self.db_filename)
+                dest = sqlite3.connect(tmp_db_filename)

-            source.backup(dest)
+                source.backup(dest)

-            source.close()
-            dest.close()
+                source.close()
+                dest.close()

-            drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
-            drive.put(os.path.basename(tmp_db_filename))
+                drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
+                drive.put(os.path.basename(tmp_db_filename))

-            print("Stored the database to the Drive.")
+                print("Stored the database to the Drive.")
+        except Exception:
+            print(traceback.print_exc())

    def periodic_store_database(self, store_interval):
        while not self._exit_event.is_set():
-            self.store_database()
+            with _lock:
+                self.store_database()
            self._exit_event.wait(store_interval)

    def run(self, token: Optional[str] = None) -> None:
@@ -210,4 +218,5 @@ def db_url(self) -> Optional[str]:

    def on_exit(self):
        self._exit_event.set()
-        self.store_database()
+        with _lock:
+            self.store_database()
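For reference, a standalone sketch of the sqlite3 online-backup pattern that store_database wraps; the file names here are hypothetical, and the new module-level _lock above simply keeps the periodic thread and on_exit from running the copy concurrently:

```python
import sqlite3

source = sqlite3.connect("database.db")  # hypothetical live database file
dest = sqlite3.connect("snapshot.db")    # hypothetical destination copy

# Connection.backup() (Python 3.7+) copies the database page by page and
# is safe to call while the source connection is still serving queries,
# which is why the Database work can snapshot itself without downtime.
source.backup(dest)

source.close()
dest.close()
```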
50 changes: 17 additions & 33 deletions src/lightning_app/components/multi_node/base.py
@@ -3,7 +3,6 @@
from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
-from lightning_app.utilities.enum import WorkStageStatus
from lightning_app.utilities.packaging.cloud_compute import CloudCompute


@@ -52,46 +51,31 @@ def run(
            work_kwargs: Keyword arguments to be provided to the work on instantiation.
        """
        super().__init__()
-        self.ws = structures.List()
-        self._work_cls = work_cls
-        self.num_nodes = num_nodes
-        self._cloud_compute = cloud_compute
-        self._work_args = work_args
-        self._work_kwargs = work_kwargs
-        self.has_started = False
+        self.ws = structures.List(
+            *[
+                work_cls(
+                    *work_args,
+                    cloud_compute=cloud_compute,
+                    **work_kwargs,
+                    parallel=True,
+                )
+                for _ in range(num_nodes)
+            ]
+        )

    def run(self) -> None:
-        if not self.has_started:
-
-            # 1. Create & start the works
-            if not self.ws:
-                for node_rank in range(self.num_nodes):
-                    self.ws.append(
-                        self._work_cls(
-                            *self._work_args,
-                            cloud_compute=self._cloud_compute,
-                            **self._work_kwargs,
-                            parallel=True,
-                        )
-                    )
-
-                    # Starting node `node_rank` ...
-                    self.ws[-1].start()
-
-            # 2. Wait for all machines to be started !
-            if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
-                return
-
-            self.has_started = True
+        # 1. Wait for all works to be started !
+        if not all(w.internal_ip for w in self.ws):
+            return

-        # Loop over all node machines
-        for node_rank in range(self.num_nodes):
+        # 2. Loop over all node machines
+        for node_rank in range(len(self.ws)):

            # 3. Run the user code in a distributed way !
            self.ws[node_rank].run(
                main_address=self.ws[0].internal_ip,
                main_port=self.ws[0].port,
-                num_nodes=self.num_nodes,
+                num_nodes=len(self.ws),
                node_rank=node_rank,
            )

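The rewritten run() exploits the fact that a flow's run() is invoked repeatedly by the event loop, so returning early until every work reports an internal_ip acts as a non-blocking readiness barrier. A self-contained toy of that control pattern, with hypothetical names (FakeWork, run_once):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeWork:
    internal_ip: Optional[str] = None  # populated once the machine is up


def run_once(works: List[FakeWork]) -> bool:
    """One pass of the event loop, mirroring MultiNode.run() above."""
    # Early return = barrier: nothing runs until every node has an IP.
    if not all(w.internal_ip for w in works):
        return False
    for node_rank, _ in enumerate(works):
        print(f"dispatch node_rank={node_rank} main_address={works[0].internal_ip}")
    return True


works = [FakeWork(), FakeWork()]
assert not run_once(works)         # machines still booting: no dispatch
works[0].internal_ip = "10.0.0.1"
works[1].internal_ip = "10.0.0.2"
assert run_once(works)             # barrier passed: user code dispatched
```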
10 changes: 10 additions & 0 deletions src/lightning_app/core/app.py
@@ -472,6 +472,8 @@ def _run(self) -> bool:
        self._original_state = deepcopy(self.state)
        done = False

+        self._start_with_flow_works()
+
        if self.should_publish_changes_to_api and self.api_publish_state_queue:
            logger.debug("Publishing the state with changes")
            # Push two states to optimize start in the cloud.
@@ -668,3 +670,11 @@ def _send_flow_to_work_deltas(self, state) -> None:
            if deep_diff:
                logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
                self.flow_to_work_delta_queues[w.name].put(deep_diff)
+
+    def _start_with_flow_works(self):
+        for w in self.works:
+            if w._start_with_flow:
+                parallel = w.parallel
+                w._parallel = True
+                w.start()
+                w._parallel = parallel
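The save/override/restore around `w._parallel` makes start() non-blocking for every flagged work, then puts the original semantics back before run() is ever dispatched. A hypothetical helper (not in the PR) expressing the same dance as a context manager:

```python
from contextlib import contextmanager


@contextmanager
def temporarily_parallel(work):
    """Hypothetical equivalent of the save/override/restore in
    _start_with_flow_works: force `parallel` so start() does not block."""
    saved = work.parallel
    work._parallel = True
    try:
        yield work
    finally:
        work._parallel = saved


# Usage sketch:
#     with temporarily_parallel(work):
#         work.start()
```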
129 changes: 64 additions & 65 deletions src/lightning_app/runners/cloud.py
@@ -142,78 +142,77 @@ def dispatch(
        v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))

        works: List[V1Work] = []
-        for flow in self.app.flows:
-            for work in flow.works(recurse=False):
-                if not work._start_with_flow:
-                    continue
-
-                work_requirements = "\n".join(work.cloud_build_config.requirements)
-                build_spec = V1BuildSpec(
-                    commands=work.cloud_build_config.build_commands(),
-                    python_dependencies=V1PythonDependencyInfo(
-                        package_manager=V1PackageManager.PIP, packages=work_requirements
-                    ),
-                    image=work.cloud_build_config.image,
-                )
-                user_compute_config = V1UserRequestedComputeConfig(
-                    name=work.cloud_compute.name,
-                    count=1,
-                    disk_size=work.cloud_compute.disk_size,
-                    preemptible=work.cloud_compute.preemptible,
-                    shm_size=work.cloud_compute.shm_size,
-                )
-
-                drive_specs: List[V1LightningworkDrives] = []
-                for drive_attr_name, drive in [
-                    (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
-                ]:
-                    if drive.protocol == "lit://":
-                        drive_type = V1DriveType.NO_MOUNT_S3
-                        source_type = V1SourceType.S3
-                    else:
-                        raise RuntimeError(
-                            f"unknown drive protocol `{drive.protocol}`. Please verify this "
-                            f"drive type has been configured for use in the cloud dispatcher."
-                        )
-
-                    drive_specs.append(
-                        V1LightningworkDrives(
-                            drive=V1Drive(
-                                metadata=V1Metadata(
-                                    name=f"{work.name}.{drive_attr_name}",
-                                ),
-                                spec=V1DriveSpec(
-                                    drive_type=drive_type,
-                                    source_type=source_type,
-                                    source=f"{drive.protocol}{drive.id}",
-                                ),
-                                status=V1DriveStatus(),
-                            ),
-                            mount_location=str(drive.root_folder),
-                        ),
-                    )
-
-                # TODO: Move this to the CloudCompute class and update backend
-                if work.cloud_compute.mounts is not None:
-                    mounts = work.cloud_compute.mounts
-                    if isinstance(mounts, Mount):
-                        mounts = [mounts]
-                    for mount in mounts:
-                        drive_specs.append(
-                            _create_mount_drive_spec(
-                                work_name=work.name,
-                                mount=mount,
-                            )
-                        )
-
-                random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
-                work_spec = V1LightningworkSpec(
-                    build_spec=build_spec,
-                    drives=drive_specs,
-                    user_requested_compute_config=user_compute_config,
-                    network_config=[V1NetworkConfig(name=random_name, port=work.port)],
-                )
-                works.append(V1Work(name=work.name, spec=work_spec))
+        for work in self.app.works:
+            if not work._start_with_flow:
+                continue
+
+            work_requirements = "\n".join(work.cloud_build_config.requirements)
+            build_spec = V1BuildSpec(
+                commands=work.cloud_build_config.build_commands(),
+                python_dependencies=V1PythonDependencyInfo(
+                    package_manager=V1PackageManager.PIP, packages=work_requirements
+                ),
+                image=work.cloud_build_config.image,
+            )
+            user_compute_config = V1UserRequestedComputeConfig(
+                name=work.cloud_compute.name,
+                count=1,
+                disk_size=work.cloud_compute.disk_size,
+                preemptible=work.cloud_compute.preemptible,
+                shm_size=work.cloud_compute.shm_size,
+            )
+
+            drive_specs: List[V1LightningworkDrives] = []
+            for drive_attr_name, drive in [
+                (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
+            ]:
+                if drive.protocol == "lit://":
+                    drive_type = V1DriveType.NO_MOUNT_S3
+                    source_type = V1SourceType.S3
+                else:
+                    raise RuntimeError(
+                        f"unknown drive protocol `{drive.protocol}`. Please verify this "
+                        f"drive type has been configured for use in the cloud dispatcher."
+                    )
+
+                drive_specs.append(
+                    V1LightningworkDrives(
+                        drive=V1Drive(
+                            metadata=V1Metadata(
+                                name=f"{work.name}.{drive_attr_name}",
+                            ),
+                            spec=V1DriveSpec(
+                                drive_type=drive_type,
+                                source_type=source_type,
+                                source=f"{drive.protocol}{drive.id}",
+                            ),
+                            status=V1DriveStatus(),
+                        ),
+                        mount_location=str(drive.root_folder),
+                    ),
+                )
+
+            # TODO: Move this to the CloudCompute class and update backend
+            if work.cloud_compute.mounts is not None:
+                mounts = work.cloud_compute.mounts
+                if isinstance(mounts, Mount):
+                    mounts = [mounts]
+                for mount in mounts:
+                    drive_specs.append(
+                        _create_mount_drive_spec(
+                            work_name=work.name,
+                            mount=mount,
+                        )
+                    )
+
+            random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
+            work_spec = V1LightningworkSpec(
+                build_spec=build_spec,
+                drives=drive_specs,
+                user_requested_compute_config=user_compute_config,
+                network_config=[V1NetworkConfig(name=random_name, port=work.port)],
+            )
+            works.append(V1Work(name=work.name, spec=work_spec))

        # We need to collect a spec for each flow that contains a frontend so that the backend knows
        # for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
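The dispatch loop now walks the app's flattened works list rather than nesting over flows, which, per the CHANGELOG entry above, also collects works held inside structures. A toy illustration of the difference, with all names hypothetical:

```python
from typing import List


class Work:
    def __init__(self, name: str):
        self.name = name


class Flow:
    """Toy flow holding one direct work and two works inside a list structure."""

    def __init__(self):
        self.worker = Work("root.worker")
        self.ws = [Work("root.ws.0"), Work("root.ws.1")]  # structure-held

    def works(self, recurse: bool = False) -> List[Work]:
        # Shallow traversal: only direct work attributes are visible.
        return [self.worker]

    @property
    def all_works(self) -> List[Work]:
        # Flattened view, analogous to `self.app.works`: structure
        # members are included, so they receive cloud specs too.
        return [self.worker, *self.ws]


flow = Flow()
print([w.name for w in flow.works()])    # ['root.worker']
print([w.name for w in flow.all_works])  # adds 'root.ws.0', 'root.ws.1'
```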
6 changes: 2 additions & 4 deletions src/lightning_app/utilities/proxies.py
@@ -103,8 +103,6 @@ class ProxyWorkRun:
    caller_queue: "BaseQueue"

    def __post_init__(self):
-        self.cache_calls = self.work.cache_calls
-        self.parallel = self.work.parallel
        self.work_state = None

    def __call__(self, *args, **kwargs):
@@ -123,7 +121,7 @@ def __call__(self, *args, **kwargs):

        # The if/else conditions are left un-compressed to simplify readability
        # for the readers.
-        if self.cache_calls:
+        if self.work.cache_calls:
            if not entered or stopped_on_sigterm:
                _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
            else:
@@ -137,7 +135,7 @@
            # the previous task has completed and we can re-queue the next one.
            # overriding the return value for next loop iteration.
            _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
-            if not self.parallel:
+            if not self.work.parallel:
                raise CacheMissException("Task never called before. Triggered now")

    def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
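ProxyWorkRun previously snapshotted cache_calls and parallel at construction; reading them through self.work keeps the proxy honest if the flags change later, which plausibly matters once _start_with_flow_works (above) mutates w._parallel at runtime. A minimal illustration of the stale-snapshot hazard, with hypothetical classes:

```python
class Work:
    def __init__(self):
        self.parallel = False


class Proxy:
    def __init__(self, work):
        self.work = work
        self.parallel_snapshot = work.parallel  # old style: frozen at init

    @property
    def parallel(self):
        # New style: always read through the work, as the PR now does.
        return self.work.parallel


work = Work()
proxy = Proxy(work)
work.parallel = True            # flag flipped after the proxy exists
print(proxy.parallel_snapshot)  # False (stale)
print(proxy.parallel)           # True (live)
```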
11 changes: 7 additions & 4 deletions tests/tests_app/components/database/test_client_server.py
@@ -2,6 +2,7 @@
import sys
import tempfile
import time
+import traceback
from pathlib import Path
from time import sleep
from typing import List, Optional
@@ -197,7 +198,9 @@ def run(self):
        assert len(self._client.select_all()) == 1
        self._exit()

-    with tempfile.TemporaryDirectory() as tmpdir:
-        app = LightningApp(Flow(tmpdir))
-        MultiProcessRuntime(app).dispatch()
+    try:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            app = LightningApp(Flow(tmpdir))
+            MultiProcessRuntime(app).dispatch()
+    except Exception:
+        print(traceback.print_exc())
2 changes: 1 addition & 1 deletion tests/tests_app/core/test_lightning_api.py
@@ -42,7 +42,7 @@

class WorkA(LightningWork):
    def __init__(self):
-        super().__init__(parallel=True)
+        super().__init__(parallel=True, start_with_flow=False)
        self.var_a = 0
        self.drive = Drive("lit://test_app_state_api")
