diff --git a/examples/app_multi_node/README.md b/examples/app_multi_node/README.md
index 23e7afa23d68e..0fd2f369bb786 100644
--- a/examples/app_multi_node/README.md
+++ b/examples/app_multi_node/README.md
@@ -28,9 +28,9 @@ lightning run app train_lite.py
 
 Using Lite, you retain control over your loops while accessing all Lightning distributed strategies in a minimal way.
 
-## Multi Node with PyTorch Lightning
+## Multi Node with Lightning Trainer
 
-Lightning supports running PyTorch Lightning from a script or within a Lightning Work.
+Lightning supports running the Lightning Trainer from a script or within a Lightning Work.
 
 You can either run a script directly
diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index b08c7edae7bf7..915545f8677b1 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -74,6 +74,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed bi-directional queues sending delta with Drive Component name changes ([#15642](https://github.com/Lightning-AI/lightning/pull/15642))
 
+- Fixed CloudRuntime collecting works from structures and accelerated multi-node startup time ([#15650](https://github.com/Lightning-AI/lightning/pull/15650))
+
 
 ## [1.8.0] - 2022-11-01
diff --git a/src/lightning_app/components/database/server.py b/src/lightning_app/components/database/server.py
index a5499aaae17b8..01bd8f3b12033 100644
--- a/src/lightning_app/components/database/server.py
+++ b/src/lightning_app/components/database/server.py
@@ -4,6 +4,7 @@
 import sys
 import tempfile
 import threading
+import traceback
 from typing import List, Optional, Type, Union
 
 import uvicorn
@@ -36,6 +37,9 @@ def install_signal_handlers(self):
         """Ignore Uvicorn Signal Handlers."""
 
 
+_lock = threading.Lock()
+
+
 class Database(LightningWork):
     def __init__(
         self,
@@ -146,25 +150,29 @@ class CounterModel(SQLModel, table=True):
         self._exit_event = None
 
     def store_database(self):
-        with tempfile.TemporaryDirectory() as tmpdir:
-            tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
+        try:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                tmp_db_filename = os.path.join(tmpdir, os.path.basename(self.db_filename))
 
-            source = sqlite3.connect(self.db_filename)
-            dest = sqlite3.connect(tmp_db_filename)
+                source = sqlite3.connect(self.db_filename)
+                dest = sqlite3.connect(tmp_db_filename)
 
-            source.backup(dest)
+                source.backup(dest)
 
-            source.close()
-            dest.close()
+                source.close()
+                dest.close()
 
-            drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
-            drive.put(os.path.basename(tmp_db_filename))
+                drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir)
+                drive.put(os.path.basename(tmp_db_filename))
 
-            print("Stored the database to the Drive.")
+                print("Stored the database to the Drive.")
+        except Exception:
+            traceback.print_exc()
 
     def periodic_store_database(self, store_interval):
         while not self._exit_event.is_set():
-            self.store_database()
+            with _lock:
+                self.store_database()
             self._exit_event.wait(store_interval)
 
     def run(self, token: Optional[str] = None) -> None:
@@ -210,4 +218,5 @@ def db_url(self) -> Optional[str]:
 
     def on_exit(self):
         self._exit_event.set()
-        self.store_database()
+        with _lock:
+            self.store_database()
diff --git a/src/lightning_app/components/multi_node/base.py b/src/lightning_app/components/multi_node/base.py
index 02adf218d3e36..4f2005771212a 100644
--- a/src/lightning_app/components/multi_node/base.py
+++ b/src/lightning_app/components/multi_node/base.py
@@ -3,7 +3,6 @@
 from lightning_app import structures
 from lightning_app.core.flow import LightningFlow
 from lightning_app.core.work import LightningWork
-from lightning_app.utilities.enum import WorkStageStatus
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
@@ -52,46 +51,31 @@ def run(
             work_kwargs: Keyword arguments to be provided to the work on instantiation.
         """
         super().__init__()
-        self.ws = structures.List()
-        self._work_cls = work_cls
-        self.num_nodes = num_nodes
-        self._cloud_compute = cloud_compute
-        self._work_args = work_args
-        self._work_kwargs = work_kwargs
-        self.has_started = False
+        self.ws = structures.List(
+            *[
+                work_cls(
+                    *work_args,
+                    cloud_compute=cloud_compute,
+                    **work_kwargs,
+                    parallel=True,
+                )
+                for _ in range(num_nodes)
+            ]
+        )
 
     def run(self) -> None:
-        if not self.has_started:
-
-            # 1. Create & start the works
-            if not self.ws:
-                for node_rank in range(self.num_nodes):
-                    self.ws.append(
-                        self._work_cls(
-                            *self._work_args,
-                            cloud_compute=self._cloud_compute,
-                            **self._work_kwargs,
-                            parallel=True,
-                        )
-                    )
-
-                    # Starting node `node_rank`` ...
-                    self.ws[-1].start()
-
-            # 2. Wait for all machines to be started !
-            if not all(w.status.stage == WorkStageStatus.STARTED for w in self.ws):
-                return
-
-            self.has_started = True
+        # 1. Wait for all works to be started!
+        if not all(w.internal_ip for w in self.ws):
+            return
 
-        # Loop over all node machines
-        for node_rank in range(self.num_nodes):
+        # 2. Loop over all node machines
+        for node_rank in range(len(self.ws)):
 
             # 3. Run the user code in a distributed way!
             self.ws[node_rank].run(
                 main_address=self.ws[0].internal_ip,
                 main_port=self.ws[0].port,
-                num_nodes=self.num_nodes,
+                num_nodes=len(self.ws),
                 node_rank=node_rank,
             )
diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py
index 9620f4bb96cc6..255f498507f67 100644
--- a/src/lightning_app/core/app.py
+++ b/src/lightning_app/core/app.py
@@ -472,6 +472,8 @@ def _run(self) -> bool:
         self._original_state = deepcopy(self.state)
         done = False
 
+        self._start_with_flow_works()
+
         if self.should_publish_changes_to_api and self.api_publish_state_queue:
             logger.debug("Publishing the state with changes")
             # Push two states to optimize start in the cloud.
@@ -668,3 +670,11 @@ def _send_flow_to_work_deltas(self, state) -> None:
             if deep_diff:
                 logger.debug(f"Sending deep_diff to {w.name} : {deep_diff}")
                 self.flow_to_work_delta_queues[w.name].put(deep_diff)
+
+    def _start_with_flow_works(self):
+        for w in self.works:
+            if w._start_with_flow:
+                parallel = w.parallel
+                w._parallel = True
+                w.start()
+                w._parallel = parallel
diff --git a/src/lightning_app/runners/cloud.py b/src/lightning_app/runners/cloud.py
index c551c1c76ec57..1011ba64463b8 100644
--- a/src/lightning_app/runners/cloud.py
+++ b/src/lightning_app/runners/cloud.py
@@ -142,78 +142,77 @@ def dispatch(
         v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))
 
         works: List[V1Work] = []
-        for flow in self.app.flows:
-            for work in flow.works(recurse=False):
-                if not work._start_with_flow:
-                    continue
-
-                work_requirements = "\n".join(work.cloud_build_config.requirements)
-                build_spec = V1BuildSpec(
-                    commands=work.cloud_build_config.build_commands(),
-                    python_dependencies=V1PythonDependencyInfo(
-                        package_manager=V1PackageManager.PIP, packages=work_requirements
-                    ),
-                    image=work.cloud_build_config.image,
-                )
-                user_compute_config = V1UserRequestedComputeConfig(
-                    name=work.cloud_compute.name,
-                    count=1,
-                    disk_size=work.cloud_compute.disk_size,
-                    preemptible=work.cloud_compute.preemptible,
-                    shm_size=work.cloud_compute.shm_size,
-                )
+        for work in self.app.works:
+            if not work._start_with_flow:
+                continue
 
-                drive_specs: List[V1LightningworkDrives] = []
-                for drive_attr_name, drive in [
-                    (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
-                ]:
-                    if drive.protocol == "lit://":
-                        drive_type = V1DriveType.NO_MOUNT_S3
-                        source_type = V1SourceType.S3
-                    else:
-                        raise RuntimeError(
-                            f"unknown drive protocol `{drive.protocol}`. Please verify this "
-                            f"drive type has been configured for use in the cloud dispatcher."
-                        )
+            work_requirements = "\n".join(work.cloud_build_config.requirements)
+            build_spec = V1BuildSpec(
+                commands=work.cloud_build_config.build_commands(),
+                python_dependencies=V1PythonDependencyInfo(
+                    package_manager=V1PackageManager.PIP, packages=work_requirements
+                ),
+                image=work.cloud_build_config.image,
+            )
+            user_compute_config = V1UserRequestedComputeConfig(
+                name=work.cloud_compute.name,
+                count=1,
+                disk_size=work.cloud_compute.disk_size,
+                preemptible=work.cloud_compute.preemptible,
+                shm_size=work.cloud_compute.shm_size,
+            )
 
-                    drive_specs.append(
-                        V1LightningworkDrives(
-                            drive=V1Drive(
-                                metadata=V1Metadata(
-                                    name=f"{work.name}.{drive_attr_name}",
-                                ),
-                                spec=V1DriveSpec(
-                                    drive_type=drive_type,
-                                    source_type=source_type,
-                                    source=f"{drive.protocol}{drive.id}",
-                                ),
-                                status=V1DriveStatus(),
+            drive_specs: List[V1LightningworkDrives] = []
+            for drive_attr_name, drive in [
+                (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
+            ]:
+                if drive.protocol == "lit://":
+                    drive_type = V1DriveType.NO_MOUNT_S3
+                    source_type = V1SourceType.S3
+                else:
+                    raise RuntimeError(
+                        f"unknown drive protocol `{drive.protocol}`. Please verify this "
+                        f"drive type has been configured for use in the cloud dispatcher."
+                    )
+
+                drive_specs.append(
+                    V1LightningworkDrives(
+                        drive=V1Drive(
+                            metadata=V1Metadata(
+                                name=f"{work.name}.{drive_attr_name}",
+                            ),
+                            spec=V1DriveSpec(
+                                drive_type=drive_type,
+                                source_type=source_type,
+                                source=f"{drive.protocol}{drive.id}",
                             ),
-                            mount_location=str(drive.root_folder),
+                            status=V1DriveStatus(),
                         ),
-                    )
+                        mount_location=str(drive.root_folder),
+                    ),
+                )
 
-                # TODO: Move this to the CloudCompute class and update backend
-                if work.cloud_compute.mounts is not None:
-                    mounts = work.cloud_compute.mounts
-                    if isinstance(mounts, Mount):
-                        mounts = [mounts]
-                    for mount in mounts:
-                        drive_specs.append(
-                            _create_mount_drive_spec(
-                                work_name=work.name,
-                                mount=mount,
-                            )
+            # TODO: Move this to the CloudCompute class and update backend
+            if work.cloud_compute.mounts is not None:
+                mounts = work.cloud_compute.mounts
+                if isinstance(mounts, Mount):
+                    mounts = [mounts]
+                for mount in mounts:
+                    drive_specs.append(
+                        _create_mount_drive_spec(
+                            work_name=work.name,
+                            mount=mount,
                         )
+                    )
 
-                random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
-                work_spec = V1LightningworkSpec(
-                    build_spec=build_spec,
-                    drives=drive_specs,
-                    user_requested_compute_config=user_compute_config,
-                    network_config=[V1NetworkConfig(name=random_name, port=work.port)],
-                )
-                works.append(V1Work(name=work.name, spec=work_spec))
+            random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
+            work_spec = V1LightningworkSpec(
+                build_spec=build_spec,
+                drives=drive_specs,
+                user_requested_compute_config=user_compute_config,
+                network_config=[V1NetworkConfig(name=random_name, port=work.port)],
+            )
+            works.append(V1Work(name=work.name, spec=work_spec))
 
         # We need to collect a spec for each flow that contains a frontend so that the backend knows
         # for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py
index 5ec99f222e47d..07b03da7d9201 100644
--- a/src/lightning_app/utilities/proxies.py
+++ b/src/lightning_app/utilities/proxies.py
@@ -103,8 +103,6 @@ class ProxyWorkRun:
     caller_queue: "BaseQueue"
 
     def __post_init__(self):
-        self.cache_calls = self.work.cache_calls
-        self.parallel = self.work.parallel
        self.work_state = None
 
     def __call__(self, *args, **kwargs):
@@ -123,7 +121,7 @@ def __call__(self, *args, **kwargs):
 
         # The if/else conditions are left uncompressed to simplify readability
         # for readers.
-        if self.cache_calls:
+        if self.work.cache_calls:
             if not entered or stopped_on_sigterm:
                 _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
             else:
@@ -137,7 +135,7 @@ def __call__(self, *args, **kwargs):
                 # the previous task has completed and we can re-queue the next one.
                 # overriding the return value for next loop iteration.
                 _send_data_to_caller_queue(self, self.work, self.caller_queue, data, call_hash)
-                if not self.parallel:
+                if not self.work.parallel:
                     raise CacheMissException("Task never called before. Triggered now")
 
     def _validate_call_args(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
diff --git a/tests/tests_app/components/database/test_client_server.py b/tests/tests_app/components/database/test_client_server.py
index 6ebec90ff9b1e..7b193d8f74c20 100644
--- a/tests/tests_app/components/database/test_client_server.py
+++ b/tests/tests_app/components/database/test_client_server.py
@@ -2,6 +2,7 @@
 import sys
 import tempfile
 import time
+import traceback
 from pathlib import Path
 from time import sleep
 from typing import List, Optional
@@ -197,7 +198,9 @@ def run(self):
             assert len(self._client.select_all()) == 1
             self._exit()
 
-    with tempfile.TemporaryDirectory() as tmpdir:
-
-        app = LightningApp(Flow(tmpdir))
-        MultiProcessRuntime(app).dispatch()
+    try:
+        with tempfile.TemporaryDirectory() as tmpdir:
+            app = LightningApp(Flow(tmpdir))
+            MultiProcessRuntime(app).dispatch()
+    except Exception:
+        traceback.print_exc()
diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py
index a0069f1314841..d81c72c06f071 100644
--- a/tests/tests_app/core/test_lightning_api.py
+++ b/tests/tests_app/core/test_lightning_api.py
@@ -42,7 +42,7 @@
 
 class WorkA(LightningWork):
     def __init__(self):
-        super().__init__(parallel=True)
+        super().__init__(parallel=True, start_with_flow=False)
         self.var_a = 0
         self.drive = Drive("lit://test_app_state_api")
diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py
index d95cac9899761..1b438f14632bb 100644
--- a/tests/tests_app/core/test_lightning_app.py
+++ b/tests/tests_app/core/test_lightning_app.py
@@ -247,10 +247,9 @@ def test_get_component_by_name_raises():
         app.get_component_by_name("root.b.w_b.c")
 
 
-@pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime])
-def test_nested_component(runtime_cls):
+def test_nested_component():
     app = LightningApp(A(), log_level="debug")
-    runtime_cls(app, start_server=False).dispatch()
+    MultiProcessRuntime(app, start_server=False).dispatch()
     assert app.root.w_a.c == 1
     assert app.root.b.w_b.c == 1
     assert app.root.b.c.w_c.c == 1
@@ -601,9 +600,10 @@ def run(self):
 
 
 class CheckpointFlow(LightningFlow):
-    def __init__(self, work: LightningWork, depth=0):
+    def __init__(self, work: CheckpointCounter, depth=0):
         super().__init__()
         self.depth = depth
+
         if depth == 0:
             self.counter = 0
 
@@ -613,10 +613,9 @@ def __init__(self, work: LightningWork, depth=0):
             self.flow = CheckpointFlow(work, depth + 1)
 
     def run(self):
-        if hasattr(self, "counter"):
-            self.counter += 1
-            if self.counter > 5:
-                self._exit()
+        if self.works()[0].counter == 5:
+            self._exit()
+
         if self.depth >= 10:
             self.work.run()
         else:
@@ -627,19 +626,16 @@ def test_lightning_app_checkpointing_with_nested_flows():
     work = CheckpointCounter()
     app = LightningApp(CheckpointFlow(work))
     app.checkpointing = True
-    SingleProcessRuntime(app, start_server=False).dispatch()
+    MultiProcessRuntime(app, start_server=False).dispatch()
 
-    assert app.root.counter == 6
     assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 5
 
     work = CheckpointCounter()
     app = LightningApp(CheckpointFlow(work))
-    assert app.root.counter == 0
     assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 0
 
     app.load_state_dict_from_checkpoint_dir(app.checkpoint_dir)
     # The counter was incremented to 6 after the latest checkpoint was created.
-    assert app.root.counter == 5
     assert app.root.flow.flow.flow.flow.flow.flow.flow.flow.flow.flow.work.counter == 5
@@ -956,8 +952,8 @@ def run(self):
 
 def test_state_size_constant_growth():
     app = LightningApp(SizeFlow())
     MultiProcessRuntime(app, start_server=False).dispatch()
-    assert app.root._state_sizes[0] <= 6952
-    assert app.root._state_sizes[20] <= 26080
+    assert app.root._state_sizes[0] <= 7824
+    assert app.root._state_sizes[20] <= 26500
 
 
 class FlowUpdated(LightningFlow):
diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py
index 23a465968efc8..f6764bb692868 100644
--- a/tests/tests_app/runners/test_cloud.py
+++ b/tests/tests_app/runners/test_cloud.py
@@ -402,7 +402,6 @@ def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, t
         monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
         monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
         app = mock.MagicMock()
-        flow = mock.MagicMock()
 
         work = MyWork(start_with_flow=start_with_flow)
         monkeypatch.setattr(work, "_name", "test-work")
@@ -412,8 +411,7 @@ def test_call_with_work_app(self, lightningapps, start_with_flow, monkeypatch, t
         monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
         monkeypatch.setattr(work, "_port", 8080)
 
-        flow.works = lambda recurse: [work]
-        app.flows = [flow]
+        app.works = [work]
 
         cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
         monkeypatch.setattr(
             "lightning_app.runners.cloud._get_project",
@@ -575,7 +573,6 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch
         monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
         monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
         app = mock.MagicMock()
-        flow = mock.MagicMock()
 
         mocked_drive = MagicMock(spec=Drive)
         setattr(mocked_drive, "id", "foobar")
@@ -598,8 +595,7 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch
         monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
         monkeypatch.setattr(work, "_port", 8080)
 
-        flow.works = lambda recurse: [work]
-        app.flows = [flow]
+        app.works = [work]
 
         cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
         monkeypatch.setattr(
             "lightning_app.runners.cloud._get_project",
@@ -712,7 +708,6 @@ def test_call_with_work_app_and_app_comment_command_execution_set(self, lightnin
         monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
         monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
         app = mock.MagicMock()
-        flow = mock.MagicMock()
 
         work = MyWork()
         monkeypatch.setattr(work, "_state", {"_port"})
@@ -723,8 +718,7 @@ def test_call_with_work_app_and_app_comment_command_execution_set(self, lightnin
         monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
         monkeypatch.setattr(work, "_port", 8080)
 
-        flow.works = lambda recurse: [work]
-        app.flows = [flow]
+        app.works = [work]
 
         cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
         monkeypatch.setattr(
             "lightning_app.runners.cloud._get_project",
@@ -829,7 +823,6 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
         monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
         monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
         app = mock.MagicMock()
-        flow = mock.MagicMock()
 
         mocked_lit_drive = MagicMock(spec=Drive)
         setattr(mocked_lit_drive, "id", "foobar")
@@ -853,8 +846,7 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
         monkeypatch.setattr(work._cloud_compute, "disk_size", 0)
         monkeypatch.setattr(work, "_port", 8080)
 
-        flow.works = lambda recurse: [work]
-        app.flows = [flow]
+        app.works = [work]
 
         cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
         monkeypatch.setattr(
             "lightning_app.runners.cloud._get_project",
@@ -1034,7 +1026,6 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo
         monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock())
         monkeypatch.setattr(cloud, "_prepare_lightning_wheels_and_requirements", mock.MagicMock())
         app = mock.MagicMock()
-        flow = mock.MagicMock()
 
         mocked_drive = MagicMock(spec=Drive)
         setattr(mocked_drive, "id", "foobar")
@@ -1063,8 +1054,7 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo
         monkeypatch.setattr(work._cloud_compute, "mounts", mocked_mount)
         monkeypatch.setattr(work, "_port", 8080)
 
-        flow.works = lambda recurse: [work]
-        app.flows = [flow]
+        app.works = [work]
 
         cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file=(source_code_root_dir / "entrypoint.py"))
         monkeypatch.setattr(
             "lightning_app.runners.cloud._get_project",
diff --git a/tests/tests_app/storage/test_drive.py b/tests/tests_app/storage/test_drive.py
index bee8de5e093a8..d39623bd74296 100644
--- a/tests/tests_app/storage/test_drive.py
+++ b/tests/tests_app/storage/test_drive.py
@@ -50,7 +50,8 @@ def test_synchronization_lit_drive(tmpdir):
     os.remove("a.txt")
     app = LightningApp(SyncFlowLITDrives(tmpdir))
     MultiProcessRuntime(app, start_server=False).dispatch()
-    os.remove("a.txt")
+    if os.path.exists("a.txt"):
+        os.remove("a.txt")
 
 
 class LITDriveWork(LightningWork):
diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py
index 682138d20654e..c9d35423c56b2 100644
--- a/tests/tests_app/utilities/test_proxies.py
+++ b/tests/tests_app/utilities/test_proxies.py
@@ -216,7 +216,7 @@ def __init__(self):
 
 class WorkTimeout(LightningWork):
     def __init__(self):
-        super().__init__(parallel=True)
+        super().__init__(parallel=True, start_with_flow=False)
         self.counter = 0
 
     def run(self):
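
Taken together, the `multi_node/base.py` changes above make the multi-node flow instantiate one work per node eagerly in `__init__` (inside a `structures.List`, which is what lets `CloudRuntime` discover them via `app.works`) and gate the distributed launch on every work reporting an `internal_ip`. The following is a minimal usage sketch, not taken from the PR: it assumes the flow in `base.py` is exported as `MultiNode`, and `MyDistributedWork` plus the import paths are illustrative.

```python
from lightning_app import CloudCompute, LightningApp, LightningWork
from lightning_app.components.multi_node import MultiNode  # assumed export path


class MyDistributedWork(LightningWork):
    # Hypothetical user work. The multi-node flow calls run() once per node,
    # but only after every node's work has an internal_ip assigned.
    def run(self, main_address: str, main_port: int, num_nodes: int, node_rank: int):
        print(f"node {node_rank}/{num_nodes} rendezvous at {main_address}:{main_port}")


# One work per node is created up front in __init__, so the cloud dispatcher
# can build all machine specs before the app starts, cutting startup time.
app = LightningApp(
    MultiNode(
        MyDistributedWork,
        num_nodes=2,
        cloud_compute=CloudCompute("gpu"),
    )
)
```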
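Similarly, the `start_with_flow` flag exercised by the test changes pairs with the new `LightningApp._start_with_flow_works()` hook: works created with `start_with_flow=True` are launched as soon as the app's main loop begins, while `start_with_flow=False` defers the launch until the flow first calls `run()` on the work. A hedged sketch follows — only the constructor call is taken from the diff; the surrounding flow is illustrative.

```python
from lightning_app import LightningApp, LightningFlow, LightningWork


class DeferredWork(LightningWork):
    def __init__(self):
        # start_with_flow=False opts this work out of the eager launch in
        # LightningApp._start_with_flow_works(), so its machine is requested
        # only when the flow first calls run() on it.
        super().__init__(parallel=True, start_with_flow=False)

    def run(self):
        print("started on first flow call, not at app startup")


class Root(LightningFlow):
    def __init__(self):
        super().__init__()
        self.work = DeferredWork()

    def run(self):
        self.work.run()


app = LightningApp(Root())
```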