From 76bd1ba7898cfbc846befdea0808f5da44a6ec91 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 18:46:30 +0000 Subject: [PATCH 01/34] update --- examples/app_boring/app_dynamic.py | 2 +- examples/app_commands_and_api/app.py | 2 +- examples/app_mount/app.py | 2 +- examples/app_multi_node/train_pl.py | 15 ++++++++------- examples/app_template_streamlit_ui/app.py | 2 +- examples/app_v0/app.py | 2 +- examples/app_works_on_default_machine/app_v2.py | 2 +- src/lightning_app/core/app.py | 16 +++++++++++----- src/lightning_app/testing/testing.py | 2 +- src/lightning_app/utilities/app_helpers.py | 5 +++++ tests/tests_app/core/test_lightning_api.py | 10 +++++----- tests/tests_app/core/test_lightning_app.py | 10 +++++----- tests/tests_app/storage/test_path.py | 4 ++-- tests/tests_app/storage/test_payload.py | 2 +- tests/tests_app/structures/test_structures.py | 2 +- tests/tests_app/utilities/test_proxies.py | 2 +- tests/tests_app_examples/collect_failures/app.py | 2 +- .../custom_work_dependencies/app.py | 2 +- tests/tests_app_examples/idle_timeout/app.py | 2 +- 19 files changed, 49 insertions(+), 37 deletions(-) diff --git a/examples/app_boring/app_dynamic.py b/examples/app_boring/app_dynamic.py index 1308d0c15c469..cfd303505a0ed 100644 --- a/examples/app_boring/app_dynamic.py +++ b/examples/app_boring/app_dynamic.py @@ -64,4 +64,4 @@ def configure_layout(self): return {"name": "Boring Tab", "content": self.dict["dst_w"].url + "/file" if "dst_w" in self.dict else ""} -app = L.LightningApp(BoringApp(), debug=True) +app = L.LightningApp(BoringApp(), log_level="debug") diff --git a/examples/app_commands_and_api/app.py b/examples/app_commands_and_api/app.py index ea00cf72a9e4b..8c62510dea280 100644 --- a/examples/app_commands_and_api/app.py +++ b/examples/app_commands_and_api/app.py @@ -50,4 +50,4 @@ def configure_api(self): ] -app = LightningApp(FlowCommands(), debug=True) +app = LightningApp(FlowCommands(), log_level="debug") diff --git a/examples/app_mount/app.py b/examples/app_mount/app.py index 9754735f0e3d8..11da2f02552d8 100644 --- a/examples/app_mount/app.py +++ b/examples/app_mount/app.py @@ -32,4 +32,4 @@ def run(self): self.work_1.run() -app = L.LightningApp(Flow(), debug=True) +app = L.LightningApp(Flow(), log_level="debug") diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_pl.py index e887eaef7c075..2971f01aeb0b3 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_pl.py @@ -14,11 +14,12 @@ def run(): trainer.fit(model) -# Run over 2 nodes of 4 x V100 -app = L.LightningApp( - PyTorchLightningMultiNode( - PyTorchLightningDistributed, - num_nodes=2, - cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 +if __name__ == "__main__": + # Run over 2 nodes of 4 x V100 + app = L.LightningApp( + PyTorchLightningMultiNode( + PyTorchLightningDistributed, + num_nodes=2, + cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 + ) ) -) diff --git a/examples/app_template_streamlit_ui/app.py b/examples/app_template_streamlit_ui/app.py index 45bb775984cd3..6f344ac98eb8d 100644 --- a/examples/app_template_streamlit_ui/app.py +++ b/examples/app_template_streamlit_ui/app.py @@ -45,4 +45,4 @@ def configure_layout(self): return [{"name": "StreamLitUI", "content": self.streamlit_ui}] -app = LightningApp(HelloWorld(), debug=True) +app = LightningApp(HelloWorld(), log_level="debug") diff --git a/examples/app_v0/app.py b/examples/app_v0/app.py index 84512fb474bf3..bf8803fe13598 100644 --- a/examples/app_v0/app.py +++ 
b/examples/app_v0/app.py @@ -46,4 +46,4 @@ def configure_layout(self): return [tab1, tab2, tab3] -app = L.LightningApp(V0App(), debug=True) +app = L.LightningApp(V0App(), log_level="debug") diff --git a/examples/app_works_on_default_machine/app_v2.py b/examples/app_works_on_default_machine/app_v2.py index f1d3c36d2a184..ee60e77e3db73 100644 --- a/examples/app_works_on_default_machine/app_v2.py +++ b/examples/app_works_on_default_machine/app_v2.py @@ -50,4 +50,4 @@ def configure_layout(self): return [{"name": w.name, "content": w} for i, w in enumerate(self.works())] -app = LightningApp(Flow(), debug=True) +app = LightningApp(Flow(), log_level="debug") diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 0ed3ea22bce19..9c675fe891389 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -6,7 +6,7 @@ import warnings from copy import deepcopy from time import time -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from deepdiff import DeepDiff, Delta from lightning_utilities.core.apply_func import apply_to_collection @@ -27,7 +27,7 @@ from lightning_app.storage import Drive, Path from lightning_app.storage.path import _storage_root_dir from lightning_app.utilities import frontend -from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef, Logger +from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef, debugger_is_active, Logger from lightning_app.utilities.commands.base import _process_requests from lightning_app.utilities.component import _convert_paths_after_init, _validate_root_flow from lightning_app.utilities.enum import AppStage, CacheCallsKeys @@ -52,7 +52,7 @@ def __init__( self, root: Union["LightningFlow", "LightningWork"], flow_cloud_compute: Optional["lightning_app.CloudCompute"] = None, - debug: bool = False, + log_level: Literal["info", "debug"] = "info", info: frontend.AppInfo = None, root_path: str = "", ): @@ -70,7 +70,7 @@ def __init__( root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested components, running infinitely. It must define a `run()` method that the app can call. flow_cloud_compute: The default Cloud Compute used for flow, Rest API and frontend's. - debug: Whether to activate the Lightning Logger debug mode. + log_level: Whether to activate the Lightning Logger debug mode. This can be helpful when reporting bugs on Lightning repo. info: Provide additional info about the app which will be used to update html title, description and image meta tags and specify any additional tags as list of html strings. @@ -151,7 +151,7 @@ def __init__( _convert_paths_after_init(self.root) # Lazily enable debugging. - if debug or DEBUG_ENABLED: + if log_level == "debug" or DEBUG_ENABLED: if not DEBUG_ENABLED: os.environ["LIGHTNING_DEBUG"] = "2" _console.setLevel(logging.DEBUG) @@ -162,6 +162,12 @@ def __init__( # this should happen once for all apps before the ui server starts running. 
frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path) + if debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))): + os.environ["LIGHTNING_DISPATCHED"] = "1" + from lightning.app.runners import MultiProcessRuntime + + MultiProcessRuntime(self).dispatch() + def get_component_by_name(self, component_name: str): """Returns the instance corresponding to the given component name.""" from lightning_app.structures import Dict as LightningDict diff --git a/src/lightning_app/testing/testing.py b/src/lightning_app/testing/testing.py index 6bf17707cf512..f4c8c001acad7 100644 --- a/src/lightning_app/testing/testing.py +++ b/src/lightning_app/testing/testing.py @@ -167,7 +167,7 @@ def run(self): def run_work_isolated(work, *args, start_server: bool = False, **kwargs): """This function is used to run a work a single time with multiprocessing runtime.""" MultiProcessRuntime( - LightningApp(_SingleWorkFlow(work, args, kwargs), debug=True), + LightningApp(_SingleWorkFlow(work, args, kwargs), log_level="debug"), start_server=start_server, ).dispatch() # pop the stopped status. diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py index 36109dd628fe4..29923380468c1 100644 --- a/src/lightning_app/utilities/app_helpers.py +++ b/src/lightning_app/utilities/app_helpers.py @@ -488,3 +488,8 @@ def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict: def is_static_method(klass_or_instance, attr) -> bool: return isinstance(inspect.getattr_static(klass_or_instance, attr), staticmethod) + + +def debugger_is_active() -> bool: + """Return if the debugger is currently active.""" + return hasattr(sys, "gettrace") and sys.gettrace() is not None diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index e5494757cdb61..a0069f1314841 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -75,7 +75,7 @@ def run(self): @pytest.mark.parametrize("runtime_cls", [MultiProcessRuntime]) def test_app_state_api(runtime_cls): """This test validates the AppState can properly broadcast changes from work within its own process.""" - app = LightningApp(_A(), debug=True) + app = LightningApp(_A(), log_level="debug") runtime_cls(app, start_server=True).dispatch() assert app.root.work_a.var_a == -1 _set_work_context() @@ -110,7 +110,7 @@ def run(self): @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) def test_app_state_api_with_flows(runtime_cls, tmpdir): """This test validates the AppState can properly broadcast changes from flows.""" - app = LightningApp(A2(), debug=True) + app = LightningApp(A2(), log_level="debug") runtime_cls(app, start_server=True).dispatch() assert app.root.var_a == -1 @@ -185,7 +185,7 @@ def maybe_apply_changes(self): def test_app_stage_from_frontend(runtime_cls): """This test validates that delta from the `api_delta_queue` manipulating the ['app_state']['stage'] would start and stop the app.""" - app = AppStageTestingApp(FlowA(), debug=True) + app = AppStageTestingApp(FlowA(), log_level="debug") app.stage = AppStage.BLOCKING runtime_cls(app, start_server=True).dispatch() @@ -197,7 +197,7 @@ def test_update_publish_state_and_maybe_refresh_ui(): - receives a notification to refresh the UI and makes a GET Request (streamlit). 
""" - app = AppStageTestingApp(FlowA(), debug=True) + app = AppStageTestingApp(FlowA(), log_level="debug") publish_state_queue = _MockQueue("publish_state_queue") api_response_queue = _MockQueue("api_response_queue") @@ -224,7 +224,7 @@ class InfiniteQueue(_MockQueue): def get(self, timeout: int = 0): return self._queue[0] - app = AppStageTestingApp(FlowA(), debug=True) + app = AppStageTestingApp(FlowA(), log_level="debug") app._update_layout() app.stage = AppStage.BLOCKING publish_state_queue = InfiniteQueue("publish_state_queue") diff --git a/tests/tests_app/core/test_lightning_app.py b/tests/tests_app/core/test_lightning_app.py index 8eee33d4f8f0d..de4ae6a56d94d 100644 --- a/tests/tests_app/core/test_lightning_app.py +++ b/tests/tests_app/core/test_lightning_app.py @@ -107,7 +107,7 @@ def run(self): @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime]) def test_simple_app(component_cls, runtime_cls, tmpdir): comp = component_cls() - app = LightningApp(comp, debug=True) + app = LightningApp(comp, log_level="debug") assert app.root == comp expected = { "app_state": ANY, @@ -249,7 +249,7 @@ def test_get_component_by_name_raises(): @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) def test_nested_component(runtime_cls): - app = LightningApp(A(), debug=True) + app = LightningApp(A(), log_level="debug") runtime_cls(app, start_server=False).dispatch() assert app.root.w_a.c == 1 assert app.root.b.w_b.c == 1 @@ -361,7 +361,7 @@ def _apply_restarting(self): @pytest.mark.parametrize("runtime_cls", [SingleProcessRuntime, MultiProcessRuntime]) def test_app_restarting_move_to_blocking(runtime_cls, tmpdir): """Validates sending restarting move the app to blocking again.""" - app = SimpleApp2(CounterFlow(), debug=True) + app = SimpleApp2(CounterFlow(), log_level="debug") runtime_cls(app, start_server=False).dispatch() @@ -395,7 +395,7 @@ def run_once(self): @mock.patch("lightning_app.frontend.stream_lit.StreamlitFrontend.stop_server") def test_app_starts_with_complete_state_copy(_, __): """Test that the LightningApp captures the initial state in a separate copy when _run() gets called.""" - app = AppWithFrontend(FlowWithFrontend(), debug=True) + app = AppWithFrontend(FlowWithFrontend(), log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() assert app.run_once_call_count == 3 @@ -992,7 +992,7 @@ def test_debug_mode_logging(): from lightning_app.core.app import _console - app = LightningApp(A4(), debug=True) + app = LightningApp(A4(), log_level="debug") assert _console.level == logging.DEBUG assert os.getenv("LIGHTNING_DEBUG") == "2" diff --git a/tests/tests_app/storage/test_path.py b/tests/tests_app/storage/test_path.py index 78694c2edb2ed..3cd501f7344c8 100644 --- a/tests/tests_app/storage/test_path.py +++ b/tests/tests_app/storage/test_path.py @@ -377,7 +377,7 @@ def run(self): def test_multiprocess_path_in_work_and_flow(tmpdir): root = SourceToDestFlow(tmpdir) - app = LightningApp(root, debug=True) + app = LightningApp(root, log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() @@ -551,7 +551,7 @@ def run(self): def test_path_get_overwrite(tmpdir): """Test that .get(overwrite=True) overwrites the entire directory and replaces all files.""" root = OverwriteFolderFlow(tmpdir) - app = LightningApp(root, debug=True) + app = LightningApp(root, log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() diff --git a/tests/tests_app/storage/test_payload.py b/tests/tests_app/storage/test_payload.py index 
4e7a297e1f0b8..ebe563f15ec29 100644 --- a/tests/tests_app/storage/test_payload.py +++ b/tests/tests_app/storage/test_payload.py @@ -146,7 +146,7 @@ def run(self): def test_payload_works(tmpdir): """This tests validates the payload api can be used to transfer return values from a work to another.""" with mock.patch("lightning_app.storage.path._storage_root_dir", lambda: pathlib.Path(tmpdir)): - app = LightningApp(Flow(), debug=True) + app = LightningApp(Flow(), log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() os.remove("value_all") diff --git a/tests/tests_app/structures/test_structures.py b/tests/tests_app/structures/test_structures.py index 91c7dfe91c32c..7b84e31402f36 100644 --- a/tests/tests_app/structures/test_structures.py +++ b/tests/tests_app/structures/test_structures.py @@ -494,6 +494,6 @@ def run(self): def test_structures_with_payload(): - app = LightningApp(FlowPayload(), debug=True) + app = LightningApp(FlowPayload(), log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() os.remove("payload") diff --git a/tests/tests_app/utilities/test_proxies.py b/tests/tests_app/utilities/test_proxies.py index 557d88c3d836d..832021dc0553c 100644 --- a/tests/tests_app/utilities/test_proxies.py +++ b/tests/tests_app/utilities/test_proxies.py @@ -266,7 +266,7 @@ def __call__(self): @mock.patch("lightning_app.runners.backends.mp_process.WorkRunner", WorkRunnerPatch) def test_proxy_timeout(): - app = LightningApp(FlowTimeout(), debug=True) + app = LightningApp(FlowTimeout(), log_level="debug") MultiProcessRuntime(app, start_server=False).dispatch() call_hash = app.root.work._calls[CacheCallsKeys.LATEST_CALL_HASH] diff --git a/tests/tests_app_examples/collect_failures/app.py b/tests/tests_app_examples/collect_failures/app.py index 6675cff61dea9..f9491a8de27aa 100644 --- a/tests/tests_app_examples/collect_failures/app.py +++ b/tests/tests_app_examples/collect_failures/app.py @@ -43,4 +43,4 @@ def run(self): if __name__ == "__main__": - app = LightningApp(RootFlow(), debug=True) + app = LightningApp(RootFlow(), log_level="debug") diff --git a/tests/tests_app_examples/custom_work_dependencies/app.py b/tests/tests_app_examples/custom_work_dependencies/app.py index 0b3ba1d5f5c0b..9b45bc7c190f6 100644 --- a/tests/tests_app_examples/custom_work_dependencies/app.py +++ b/tests/tests_app_examples/custom_work_dependencies/app.py @@ -50,4 +50,4 @@ def run(self): self._exit() -app = LightningApp(CustomWorkBuildConfigChecker(), debug=True) +app = LightningApp(CustomWorkBuildConfigChecker(), log_level="debug") diff --git a/tests/tests_app_examples/idle_timeout/app.py b/tests/tests_app_examples/idle_timeout/app.py index aa6442180c16e..ab96ca8b074a5 100644 --- a/tests/tests_app_examples/idle_timeout/app.py +++ b/tests/tests_app_examples/idle_timeout/app.py @@ -68,4 +68,4 @@ def run(self): self._exit() -app = LightningApp(RootFlow(), debug=True) +app = LightningApp(RootFlow(), log_level="debug") From a753b8a6c3e0cbab5baaef77a7a5d6a95c4a10a6 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 18:57:36 +0000 Subject: [PATCH 02/34] update --- examples/app_multi_node/train_pl.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_pl.py index 2971f01aeb0b3..e887eaef7c075 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_pl.py @@ -14,12 +14,11 @@ def run(): trainer.fit(model) -if __name__ == "__main__": - # Run over 2 nodes of 4 
x V100 - app = L.LightningApp( - PyTorchLightningMultiNode( - PyTorchLightningDistributed, - num_nodes=2, - cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 - ) +# Run over 2 nodes of 4 x V100 +app = L.LightningApp( + PyTorchLightningMultiNode( + PyTorchLightningDistributed, + num_nodes=2, + cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 ) +) From 447adc2e047a57b289b67be7389da58aa0decfad Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 19:17:36 +0000 Subject: [PATCH 03/34] update --- src/lightning_app/core/app.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 9c675fe891389..6022f718fd8bc 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -6,7 +6,7 @@ import warnings from copy import deepcopy from time import time -from typing import Dict, List, Literal, Optional, Tuple, TYPE_CHECKING, Union +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union from deepdiff import DeepDiff, Delta from lightning_utilities.core.apply_func import apply_to_collection @@ -52,7 +52,7 @@ def __init__( self, root: Union["LightningFlow", "LightningWork"], flow_cloud_compute: Optional["lightning_app.CloudCompute"] = None, - log_level: Literal["info", "debug"] = "info", + log_level: str = "info", info: frontend.AppInfo = None, root_path: str = "", ): @@ -150,6 +150,9 @@ def __init__( # is only available after all Flows and Works have been instantiated. _convert_paths_after_init(self.root) + if log_level not in ("debug", "info"): + raise Exception(f"Log Level should be in ['debug', 'info']. Found {log_level}") + # Lazily enable debugging. if log_level == "debug" or DEBUG_ENABLED: if not DEBUG_ENABLED: From f63c47d3014cdec92a3970d0b5442b1f514a233c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 19:23:16 +0000 Subject: [PATCH 04/34] update --- src/lightning_app/core/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 6022f718fd8bc..c893f625f20a9 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -167,7 +167,7 @@ def __init__( if debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))): os.environ["LIGHTNING_DISPATCHED"] = "1" - from lightning.app.runners import MultiProcessRuntime + from lightning_app.runners import MultiProcessRuntime MultiProcessRuntime(self).dispatch() From a67632ca253cac342995fd80646adb5e3150a23b Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 19:30:40 +0000 Subject: [PATCH 05/34] update --- src/lightning_app/core/app.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index c893f625f20a9..9de4798615322 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -165,7 +165,13 @@ def __init__( # this should happen once for all apps before the ui server starts running. 
frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path) - if debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))): + print(os.environ) + + if ( + debugger_is_active() + and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))) + and os.getenv("PYTEST_CURRENT_TEST", None) is None + ): os.environ["LIGHTNING_DISPATCHED"] = "1" from lightning_app.runners import MultiProcessRuntime From 50efaaf79b523acbb9820cf910cd4fb9dcc27ef1 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 19:31:07 +0000 Subject: [PATCH 06/34] update --- src/lightning_app/core/app.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 9de4798615322..62e1dce194998 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -165,8 +165,6 @@ def __init__( # this should happen once for all apps before the ui server starts running. frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path) - print(os.environ) - if ( debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))) From ad6468700bcb18b265ece45693719fea67a7b256 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 8 Nov 2022 19:39:49 +0000 Subject: [PATCH 07/34] update --- src/lightning_app/core/app.py | 6 +----- tests/tests_app/conftest.py | 2 ++ tests/tests_app_examples/conftest.py | 2 ++ 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 62e1dce194998..c893f625f20a9 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -165,11 +165,7 @@ def __init__( # this should happen once for all apps before the ui server starts running. frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path) - if ( - debugger_is_active() - and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))) - and os.getenv("PYTEST_CURRENT_TEST", None) is None - ): + if debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))): os.environ["LIGHTNING_DISPATCHED"] = "1" from lightning_app.runners import MultiProcessRuntime diff --git a/tests/tests_app/conftest.py b/tests/tests_app/conftest.py index 891cf97fd0c8d..434c849f569c9 100644 --- a/tests/tests_app/conftest.py +++ b/tests/tests_app/conftest.py @@ -20,6 +20,8 @@ "template_react_ui": "https://github.com/Lightning-AI/lightning-template-react.git", } +os.environ["LIGHTNING_DISPATCHED"] = "1" + def pytest_sessionstart(*_): """Pytest hook that get called after the Session object has been created and before performing collection and diff --git a/tests/tests_app_examples/conftest.py b/tests/tests_app_examples/conftest.py index b5a845c42c516..493d2c941eac3 100644 --- a/tests/tests_app_examples/conftest.py +++ b/tests/tests_app_examples/conftest.py @@ -11,6 +11,8 @@ from lightning_app.utilities.packaging.app_config import _APP_CONFIG_FILENAME from lightning_app.utilities.state import AppState +os.environ["LIGHTNING_DISPATCHED"] = "1" + def pytest_sessionfinish(session, exitstatus): """Pytest hook that get called after whole test run finished, right before returning the exit status to the From 1ff6571fac1e81903d191a511cae220801fc9aaa Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 08:50:45 +0000 Subject: [PATCH 08/34] update --- src/lightning_app/runners/runtime.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lightning_app/runners/runtime.py b/src/lightning_app/runners/runtime.py index 
ad0eb1c6bcc8f..5947d410ea8eb 100644 --- a/src/lightning_app/runners/runtime.py +++ b/src/lightning_app/runners/runtime.py @@ -1,4 +1,5 @@ import multiprocessing +import os import sys from dataclasses import dataclass, field from pathlib import Path @@ -52,6 +53,9 @@ def dispatch( from lightning_app.runners.runtime_type import RuntimeType from lightning_app.utilities.component import _set_flow_context + # Used to indicate Lightning has been dispatched + os.environ["LIGHTNING_DISPATCHED"] = "1" + _set_flow_context() runtime_type = RuntimeType(runtime_type) From b75524fe974c51448e3bfcc9758547d0e66591c2 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 12:21:13 +0000 Subject: [PATCH 09/34] update --- src/lightning/__init__.py | 3 +++ src/lightning_app/core/app.py | 11 ++++++++--- src/lightning_app/utilities/app_helpers.py | 10 +++++++++- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py index cae7ecd152fd4..30950d8c6bdbb 100644 --- a/src/lightning/__init__.py +++ b/src/lightning/__init__.py @@ -45,6 +45,9 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None: lightning.app._PROJECT_ROOT = os.path.dirname(lightning.app._PROJECT_ROOT) +# Enable breakpoint within forked processes. +__builtins__["breakpoint"] = pdb.set_trace + __all__ = [ "LightningApp", "LightningFlow", diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index c893f625f20a9..9cd1a7241dbf1 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -27,7 +27,12 @@ from lightning_app.storage import Drive, Path from lightning_app.storage.path import _storage_root_dir from lightning_app.utilities import frontend -from lightning_app.utilities.app_helpers import _delta_to_app_state_delta, _LightningAppRef, debugger_is_active, Logger +from lightning_app.utilities.app_helpers import ( + _delta_to_app_state_delta, + _LightningAppRef, + _should_dispatch_app, + Logger, +) from lightning_app.utilities.commands.base import _process_requests from lightning_app.utilities.component import _convert_paths_after_init, _validate_root_flow from lightning_app.utilities.enum import AppStage, CacheCallsKeys @@ -70,7 +75,7 @@ def __init__( root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested components, running infinitely. It must define a `run()` method that the app can call. flow_cloud_compute: The default Cloud Compute used for flow, Rest API and frontend's. - log_level: Whether to activate the Lightning Logger debug mode. + log_level: The log level for the app, one of [`info`, `debug`]. This can be helpful when reporting bugs on Lightning repo. info: Provide additional info about the app which will be used to update html title, description and image meta tags and specify any additional tags as list of html strings. @@ -165,7 +170,7 @@ def __init__( # this should happen once for all apps before the ui server starts running. 
frontend.update_index_file(FRONTEND_DIR, info=info, root_path=root_path) - if debugger_is_active() and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))): + if _should_dispatch_app(): os.environ["LIGHTNING_DISPATCHED"] = "1" from lightning_app.runners import MultiProcessRuntime diff --git a/src/lightning_app/utilities/app_helpers.py b/src/lightning_app/utilities/app_helpers.py index 29923380468c1..3f2de886bcc64 100644 --- a/src/lightning_app/utilities/app_helpers.py +++ b/src/lightning_app/utilities/app_helpers.py @@ -490,6 +490,14 @@ def is_static_method(klass_or_instance, attr) -> bool: return isinstance(inspect.getattr_static(klass_or_instance, attr), staticmethod) -def debugger_is_active() -> bool: +def _debugger_is_active() -> bool: """Return if the debugger is currently active.""" return hasattr(sys, "gettrace") and sys.gettrace() is not None + + +def _should_dispatch_app() -> bool: + return ( + _debugger_is_active() + and not bool(int(os.getenv("LIGHTNING_DISPATCHED", "0"))) + and "LIGHTNING_APP_STATE_URL" not in os.environ + ) From c63307e1fc4282a1777129b2810034273eb8abb8 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 12:52:38 +0000 Subject: [PATCH 10/34] update --- src/lightning_app/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index e00ae73f41226..dcb1fc26c6f9b 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Expose `RunWorkExecutor` to the work and provides default ones for the `MultiNode` Component ([#15561](https://github.com/Lightning-AI/lightning/pull/15561)) +- Added support for running Lightning App with IDE debugger ([#15590](https://github.com/Lightning-AI/lightning/pull/15590)) + ### Changed From 34b893eef1c1cae18afc48c7acb9ab2787794e96 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 15:00:06 +0000 Subject: [PATCH 11/34] update --- examples/app_multi_node/train_lite.py | 3 +- examples/app_multi_node/train_pl.py | 4 +- .../app_multi_node/train_pytorch_spawn.py | 3 +- .../components/multi_node/pytorch_spawn.py | 19 ++++++++- src/lightning_app/utilities/proxies.py | 40 +++++++++++++------ 5 files changed, 48 insertions(+), 21 deletions(-) diff --git a/examples/app_multi_node/train_lite.py b/examples/app_multi_node/train_lite.py index ed9777a1064f6..60d16f93dd17f 100644 --- a/examples/app_multi_node/train_lite.py +++ b/examples/app_multi_node/train_lite.py @@ -6,8 +6,7 @@ class LitePyTorchDistributed(L.LightningWork): - @staticmethod - def run(): + def run(self): # 1. Create LightningLite. 
lite = LightningLite(strategy="ddp", precision="bf16") diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_pl.py index e887eaef7c075..11232b9195517 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_pl.py @@ -4,8 +4,8 @@ class PyTorchLightningDistributed(L.LightningWork): - @staticmethod - def run(): + + def run(self): model = BoringModel() trainer = L.Trainer( max_epochs=10, diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index d3a570fd4ff77..f9838ca04cb62 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -7,9 +7,8 @@ class PyTorchDistributed(L.LightningWork): - # Note: Only staticmethod are support for now with `PyTorchSpawnMultiNode` - @staticmethod def run( + self, world_size: int, node_rank: int, global_rank: str, diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 62ccfb95174eb..8f52c08d65a8a 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -22,6 +22,9 @@ def run( class _PyTorchSpawnRunExecutor(WorkRunExecutor): + + enable_start_observer: bool = False + def __call__( self, main_address: str, @@ -33,19 +36,31 @@ def __call__( nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( - self.run, args=(self.work_run, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs + self.run, args=(self.work, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs ) @staticmethod def run( local_rank: int, - work_run: Callable, + work: "LightningWork", main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): + if local_rank == 0: + _proxy_setattr(work, self.delta_queue, self.state_observer, cleanup=cleanup) + setattr_proxy = LightningWorkSetAttrProxy( + self.work_name, + self.work, + delta_queue=self.delta_queue, + state_observer=state_observer, + lock=_state_observer_lock, + ) + self.work._setattr_replacement = setattr_proxy + + import torch # 1. 
Setting distributed environment diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index a1d3972d7ecd0..8c8f2b8840662 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -278,10 +278,11 @@ def run(self): work: "LightningWork" delta_queue: "BaseQueue" state_observer: "WorkStateObserver" + lock: threading.Lock def __call__(self, name: str, value: Any) -> None: logger.debug(f"Setting {name}: {value}") - with _state_observer_lock: + with self.lock: state = deepcopy(self.work.state) self.work._default_setattr(name, value) delta = Delta(DeepDiff(state, self.work.state, verbose_level=2)) @@ -292,7 +293,8 @@ def __call__(self, name: str, value: Any) -> None: self.delta_queue.put(ComponentDelta(id=self.work_name, delta=delta)) # add the delta to the buffer to let WorkStateObserver know we already sent this one to the Flow - self.state_observer._delta_memory.append(delta) + if self.state_observer: + self.state_observer._delta_memory.append(delta) @dataclass @@ -306,6 +308,8 @@ class WorkRunExecutor: work: "LightningWork" work_run: Callable + delta_queue: "BaseQueue" + enable_start_observer: bool = True def __call__(self, *args, **kwargs): return self.work_run(*args, **kwargs) @@ -404,7 +408,8 @@ def run_once(self): self._transfer_path_attributes() # 6. Create the state observer thread. - self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue) + if self.run_executor_cls.enable_start_observer: + self.state_observer = WorkStateObserver(self.work, delta_queue=self.delta_queue) # 7. Deepcopy the work state and send the first `RUNNING` status delta to the flow. reference_state = deepcopy(self.work.state) @@ -435,7 +440,8 @@ def run_once(self): # 11. Start the state observer thread. It will look for state changes and send them back to the Flow # The observer has to be initialized here, after the set_state call above so that the thread can start with # the proper initial state of the work - self.state_observer.start() + if self.run_executor_cls.enable_start_observer: + self.state_observer.start() # 12. Run the `work_run` method. # If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook. @@ -482,7 +488,8 @@ def run_once(self): return # 13. Destroy the state observer. - self.state_observer.join(0) + if self.run_executor_cls.enable_start_observer: + self.state_observer.join(0) self.state_observer = None # 14. Copy all artifacts to the shared storage so other Works can access them while this Work gets scaled down @@ -531,14 +538,7 @@ def _sigterm_signal_handler(self, signum, frame, call_hash: str) -> None: raise LightningSigtermStateException(0) def _proxy_setattr(self, cleanup: bool = False): - if cleanup: - setattr_proxy = None - else: - assert self.state_observer - setattr_proxy = LightningWorkSetAttrProxy( - self.work_name, self.work, delta_queue=self.delta_queue, state_observer=self.state_observer - ) - self.work._setattr_replacement = setattr_proxy + _proxy_setattr(self.work, self.delta_queue, self.state_observer, cleanup=cleanup) def _process_call_args( self, args: Tuple[Any, ...], kwargs: Dict[str, Any] @@ -645,3 +645,17 @@ def persist_artifacts(work: "LightningWork") -> None: f"All {destination_paths} artifacts from Work {work.name} successfully " "stored at {artifacts_path(work.name)}." 
) + + +def _proxy_setattr(work, delta_queue, state_observer: WorkStateObserver, cleanup: bool = False): + if cleanup: + setattr_proxy = None + else: + setattr_proxy = LightningWorkSetAttrProxy( + work.name, + work, + delta_queue=delta_queue, + state_observer=state_observer, + lock=_state_observer_lock, + ) + work._setattr_replacement = setattr_proxy \ No newline at end of file From 92c8ce6d1f74933aa0443db060b2276979b82217 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 15:00:30 +0000 Subject: [PATCH 12/34] update --- examples/app_multi_node/train_pl.py | 1 - examples/app_multi_node/train_pytorch_spawn.py | 1 - src/lightning_app/components/multi_node/pytorch_spawn.py | 2 +- src/lightning_app/utilities/proxies.py | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_pl.py index 11232b9195517..364680772722d 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_pl.py @@ -4,7 +4,6 @@ class PyTorchLightningDistributed(L.LightningWork): - def run(self): model = BoringModel() trainer = L.Trainer( diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index f9838ca04cb62..583bf4a48810a 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -6,7 +6,6 @@ class PyTorchDistributed(L.LightningWork): - def run( self, world_size: int, diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 8f52c08d65a8a..46331baec33b6 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -36,7 +36,7 @@ def __call__( nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( - self.run, args=(self.work, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs + self.run, args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs ) @staticmethod diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 8c8f2b8840662..63c51d1b3aaca 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -658,4 +658,4 @@ def _proxy_setattr(work, delta_queue, state_observer: WorkStateObserver, cleanup state_observer=state_observer, lock=_state_observer_lock, ) - work._setattr_replacement = setattr_proxy \ No newline at end of file + work._setattr_replacement = setattr_proxy From 83d303fe58abf12df440ce9ee47f2dd23ea42e7e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 18:02:56 +0000 Subject: [PATCH 13/34] update --- examples/app_multi_node/train_lite.py | 10 +++++- .../app_multi_node/train_pytorch_spawn.py | 8 +++++ .../components/multi_node/lite.py | 18 +++++------ src/lightning_app/components/multi_node/pl.py | 18 +++++------ .../components/multi_node/pytorch_spawn.py | 32 +++++++------------ src/lightning_app/utilities/proxies.py | 6 ++-- 6 files changed, 50 insertions(+), 42 deletions(-) diff --git a/examples/app_multi_node/train_lite.py b/examples/app_multi_node/train_lite.py index 60d16f93dd17f..f33960efe5f8c 100644 --- a/examples/app_multi_node/train_lite.py +++ b/examples/app_multi_node/train_lite.py @@ -6,6 +6,10 @@ class LitePyTorchDistributed(L.LightningWork): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) 
+ self.step = 0 + def run(self): # 1. Create LightningLite. lite = LightningLite(strategy="ddp", precision="bf16") @@ -18,6 +22,9 @@ def run(self): # 3. Train the model for 50 steps. for step in range(50): + + self.step = step + model.zero_grad() x = torch.randn(64, 32).to(lite.device) output = model(x) @@ -33,5 +40,6 @@ def run(self): LitePyTorchDistributed, cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 num_nodes=2, - ) + ), + log_level="debug", ) diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index 583bf4a48810a..07835153a6e60 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -6,6 +6,10 @@ class PyTorchDistributed(L.LightningWork): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.step = 0 + def run( self, world_size: int, @@ -25,6 +29,10 @@ def run( # 3. Train the model for 50 steps. for step in range(50): + + # 4. Update step + self.step = step + model.zero_grad() x = torch.randn(64, 32).to(device) output = model(x) diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 5295d0beb869e..0e711b049d7b3 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -1,13 +1,13 @@ import os from dataclasses import dataclass -from typing import Any, Callable, Type +from typing import Any, Type from typing_extensions import Protocol, runtime_checkable +from lightning.app.utilities.proxies import _proxy_setattr, unwrap from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork -from lightning_app.utilities.app_helpers import is_static_method from lightning_app.utilities.packaging.cloud_compute import CloudCompute from lightning_app.utilities.tracer import Tracer @@ -24,13 +24,18 @@ class _LiteRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, - work_run: Callable, + work: "LightningWork", + delta_queue, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): + if local_rank == 0: + _proxy_setattr(work, delta_queue, None) + pass + from lightning.lite import LightningLite from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy @@ -68,7 +73,7 @@ def pre_fn(lite, *args, **kwargs): tracer = Tracer() tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn) tracer._instrument() - work_run() + unwrap(work.run)() tracer._restore() @@ -82,11 +87,6 @@ def __init__( **work_kwargs: Any, ) -> None: assert issubclass(work_cls, _LiteWorkProtocol) - if not is_static_method(work_cls, "run"): - raise TypeError( - f"The provided {work_cls} run method needs to be static for now." - "HINT: Remove `self` and add staticmethod decorator." - ) # Note: Private way to modify the work run executor # Probably exposed to the users in the future if needed. 
diff --git a/src/lightning_app/components/multi_node/pl.py b/src/lightning_app/components/multi_node/pl.py index c11b72b6ce68d..159566ec9dd20 100644 --- a/src/lightning_app/components/multi_node/pl.py +++ b/src/lightning_app/components/multi_node/pl.py @@ -1,13 +1,13 @@ import os from dataclasses import dataclass -from typing import Any, Callable, Type +from typing import Any, Type from typing_extensions import Protocol, runtime_checkable +from lightning.app.utilities.proxies import _proxy_setattr, unwrap from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork -from lightning_app.utilities.app_helpers import is_static_method from lightning_app.utilities.packaging.cloud_compute import CloudCompute from lightning_app.utilities.tracer import Tracer @@ -24,13 +24,18 @@ class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, - work_run: Callable, + work: "LightningWork", + delta_queue, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): + if local_rank == 0: + _proxy_setattr(work, delta_queue, None) + pass + from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy from lightning.pytorch import Trainer as LTrainer from pytorch_lightning import Trainer as PLTrainer @@ -67,7 +72,7 @@ def pre_fn(trainer, *args, **kwargs): tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn) tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn) tracer._instrument() - work_run() + unwrap(work.run)() tracer._restore() @@ -81,11 +86,6 @@ def __init__( **work_kwargs: Any, ) -> None: assert issubclass(work_cls, _PyTorchLightningWorkProtocol) - if not is_static_method(work_cls, "run"): - raise TypeError( - f"The provided {work_cls} run method needs to be static for now." - "HINT: Remove `self` and add staticmethod decorator." - ) # Note: Private way to modify the work run executor # Probably exposed to the users in the future if needed. diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 46331baec33b6..eee38a5058241 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -1,12 +1,11 @@ -from typing import Any, Callable, Type +from typing import Any, Type from typing_extensions import Protocol, runtime_checkable from lightning_app.components.multi_node.base import MultiNode from lightning_app.core.work import LightningWork -from lightning_app.utilities.app_helpers import is_static_method from lightning_app.utilities.packaging.cloud_compute import CloudCompute -from lightning_app.utilities.proxies import WorkRunExecutor +from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor @runtime_checkable @@ -34,15 +33,21 @@ def __call__( ): import torch + # Remove the wrapper. 
+ self.work._setattr_replacement = None + nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( - self.run, args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs + self.run, + args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), + nprocs=nprocs, ) @staticmethod def run( local_rank: int, work: "LightningWork", + delta_queue, main_address: str, main_port: int, num_nodes: int, @@ -50,16 +55,8 @@ def run( nprocs: int, ): if local_rank == 0: - _proxy_setattr(work, self.delta_queue, self.state_observer, cleanup=cleanup) - setattr_proxy = LightningWorkSetAttrProxy( - self.work_name, - self.work, - delta_queue=self.delta_queue, - state_observer=state_observer, - lock=_state_observer_lock, - ) - self.work._setattr_replacement = setattr_proxy - + _proxy_setattr(work, delta_queue, None) + pass import torch @@ -78,7 +75,7 @@ def run( elif world_size > 1: raise Exception("Torch distributed should be available.") - work_run(world_size, node_rank, global_rank, local_rank) + unwrap(work.run)(world_size, node_rank, global_rank, local_rank) class PyTorchSpawnMultiNode(MultiNode): @@ -91,11 +88,6 @@ def __init__( **work_kwargs: Any, ) -> None: assert issubclass(work_cls, _PyTorchSpawnWorkProtocol) - if not is_static_method(work_cls, "run"): - raise TypeError( - f"The provided {work_cls} run method needs to be static for now." - "HINT: Remove `self` and add staticmethod decorator." - ) # Note: Private way to modify the work run executor # Probably exposed to the users in the future if needed. diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 63c51d1b3aaca..0c43750ee8854 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -277,7 +277,7 @@ def run(self): work_name: str work: "LightningWork" delta_queue: "BaseQueue" - state_observer: "WorkStateObserver" + state_observer: Optional["WorkStateObserver"] lock: threading.Lock def __call__(self, name: str, value: Any) -> None: @@ -446,7 +446,7 @@ def run_once(self): # 12. Run the `work_run` method. # If an exception is raised, send a `FAILED` status delta to the flow and call the `on_exception` hook. try: - ret = self.run_executor_cls(self.work, work_run)(*args, **kwargs) + ret = self.run_executor_cls(self.work, work_run, self.delta_queue)(*args, **kwargs) except LightningSigtermStateException as e: raise e except BaseException as e: @@ -647,7 +647,7 @@ def persist_artifacts(work: "LightningWork") -> None: ) -def _proxy_setattr(work, delta_queue, state_observer: WorkStateObserver, cleanup: bool = False): +def _proxy_setattr(work, delta_queue, state_observer: Optional[WorkStateObserver], cleanup: bool = False): if cleanup: setattr_proxy = None else: From c0308c967e75519c31425438010f5a640289adc4 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 18:03:36 +0000 Subject: [PATCH 14/34] update --- src/lightning_app/components/multi_node/pytorch_spawn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index eee38a5058241..e387a248acde0 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -34,6 +34,7 @@ def __call__( import torch # Remove the wrapper. 
+ setattr_fn = self.work._setattr_replacement self.work._setattr_replacement = None nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 @@ -43,6 +44,9 @@ def __call__( nprocs=nprocs, ) + # Re-attach the wrapper. + self.work._setattr_replacement = setattr_fn + @staticmethod def run( local_rank: int, From e0b406bf5888e93fd6a96b0a2e29438de01f7ed4 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 19:47:39 +0000 Subject: [PATCH 15/34] update --- src/lightning_app/components/multi_node/pytorch_spawn.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index e387a248acde0..86eaf19f5527e 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -5,7 +5,7 @@ from lightning_app.components.multi_node.base import MultiNode from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute -from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor +from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver @runtime_checkable @@ -59,7 +59,9 @@ def run( nprocs: int, ): if local_rank == 0: - _proxy_setattr(work, delta_queue, None) + state_observer = WorkStateObserver(work, delta_queue=delta_queue) + state_observer.start() + _proxy_setattr(work, delta_queue, state_observer) pass import torch @@ -81,6 +83,9 @@ def run( unwrap(work.run)(world_size, node_rank, global_rank, local_rank) + if local_rank == 0: + state_observer.join(0) + class PyTorchSpawnMultiNode(MultiNode): def __init__( From 55d34064571b8da83f4d28c1df1d3b6635d26d77 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 20:00:32 +0000 Subject: [PATCH 16/34] update --- src/lightning_app/components/multi_node/lite.py | 9 +++++++-- src/lightning_app/components/multi_node/pl.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 0e711b049d7b3..07d4d2747589b 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -4,7 +4,7 @@ from typing_extensions import Protocol, runtime_checkable -from lightning.app.utilities.proxies import _proxy_setattr, unwrap +from lightning.app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork @@ -33,7 +33,9 @@ def run( nprocs: int, ): if local_rank == 0: - _proxy_setattr(work, delta_queue, None) + state_observer = WorkStateObserver(work, delta_queue=delta_queue) + state_observer.start() + _proxy_setattr(work, delta_queue, state_observer) pass from lightning.lite import LightningLite @@ -76,6 +78,9 @@ def pre_fn(lite, *args, **kwargs): unwrap(work.run)() tracer._restore() + if local_rank == 0: + state_observer.join(0) + class LiteMultiNode(MultiNode): def __init__( diff --git a/src/lightning_app/components/multi_node/pl.py b/src/lightning_app/components/multi_node/pl.py index 159566ec9dd20..29bc9fa3a94fa 100644 --- a/src/lightning_app/components/multi_node/pl.py +++ b/src/lightning_app/components/multi_node/pl.py @@ -4,7 +4,7 @@ from 
typing_extensions import Protocol, runtime_checkable -from lightning.app.utilities.proxies import _proxy_setattr, unwrap +from lightning.app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork @@ -33,7 +33,9 @@ def run( nprocs: int, ): if local_rank == 0: - _proxy_setattr(work, delta_queue, None) + state_observer = WorkStateObserver(work, delta_queue=delta_queue) + state_observer.start() + _proxy_setattr(work, delta_queue, state_observer) pass from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy @@ -75,6 +77,9 @@ def pre_fn(trainer, *args, **kwargs): unwrap(work.run)() tracer._restore() + if local_rank == 0: + state_observer.join(0) + class PyTorchLightningMultiNode(MultiNode): def __init__( From 3614a90426e15ddcaff269c7ddb59ede1bbc354e Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 21:25:35 +0000 Subject: [PATCH 17/34] update --- src/lightning_app/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 62b117a21ed13..311b2e735aa1f 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for running Lightning App with VSCode IDE debugger ([#15590](https://github.com/Lightning-AI/lightning/pull/15590)) +- Enabled MultiNode Components to support state broadcasting ([#15607](https://github.com/Lightning-AI/lightning/pull/15607)) + ### Changed From bc11d7b15ef8627aa426db28b19e03ea3f6ed4e7 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 21:26:20 +0000 Subject: [PATCH 18/34] update --- src/lightning_app/components/multi_node/pl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_app/components/multi_node/pl.py b/src/lightning_app/components/multi_node/pl.py index 29bc9fa3a94fa..6ae787e8f1978 100644 --- a/src/lightning_app/components/multi_node/pl.py +++ b/src/lightning_app/components/multi_node/pl.py @@ -4,11 +4,11 @@ from typing_extensions import Protocol, runtime_checkable -from lightning.app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute +from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.utilities.tracer import Tracer From 756685e7e4fca58fb25ad63cb086d778799d6974 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 9 Nov 2022 22:19:42 +0000 Subject: [PATCH 19/34] update --- src/lightning_app/components/multi_node/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index 07d4d2747589b..c3863f5ea9f1b 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -4,11 +4,11 @@ from typing_extensions import Protocol, runtime_checkable -from lightning.app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from 
lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute +from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.utilities.tracer import Tracer From 2a4a66014a2fdddfbbda72281146c6f9d5f046ea Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 10:27:37 +0000 Subject: [PATCH 20/34] update --- src/lightning_app/utilities/proxies.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 0c43750ee8854..dec4fb41495b1 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -278,11 +278,10 @@ def run(self): work: "LightningWork" delta_queue: "BaseQueue" state_observer: Optional["WorkStateObserver"] - lock: threading.Lock def __call__(self, name: str, value: Any) -> None: logger.debug(f"Setting {name}: {value}") - with self.lock: + with _state_observer_lock: state = deepcopy(self.work.state) self.work._default_setattr(name, value) delta = Delta(DeepDiff(state, self.work.state, verbose_level=2)) @@ -656,6 +655,5 @@ def _proxy_setattr(work, delta_queue, state_observer: Optional[WorkStateObserver work, delta_queue=delta_queue, state_observer=state_observer, - lock=_state_observer_lock, ) work._setattr_replacement = setattr_proxy From 2382cca0baa96747714e2fd8c660aa08c23d60a5 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 14:52:39 +0000 Subject: [PATCH 21/34] update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 005eba2846a31..bc8d9c7658dcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ module = [ "lightning_app.components.multi_node.lite", "lightning_app.components.multi_node.base", "lightning_app.components.multi_node.pytorch_spawn", - "lightning_app.components.multi_node.pl", + "lightning_app.components.multi_node.trainer", "lightning_app.api.http_methods", "lightning_app.api.request_types", "lightning_app.cli.commands.app_commands", From 8c23ca31b6130b36810bdca6cc8f75f992030883 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:05:21 +0000 Subject: [PATCH 22/34] update --- examples/app_multi_node/train_pl.py | 2 +- examples/app_multi_node/train_pytorch.py | 27 +++++++------- .../app_multi_node/train_pytorch_spawn.py | 37 ++++++++----------- 3 files changed, 29 insertions(+), 37 deletions(-) diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_pl.py index 364680772722d..8193e813d9b5b 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_pl.py @@ -7,7 +7,7 @@ class PyTorchLightningDistributed(L.LightningWork): def run(self): model = BoringModel() trainer = L.Trainer( - max_epochs=10, + max_steps=1000, strategy="ddp", ) trainer.fit(model) diff --git a/examples/app_multi_node/train_pytorch.py b/examples/app_multi_node/train_pytorch.py index 825112a9c17f1..e1e8b5bb5c6db 100644 --- a/examples/app_multi_node/train_pytorch.py +++ b/examples/app_multi_node/train_pytorch.py @@ -18,29 +18,28 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no init_method=f"tcp://{main_address}:{main_port}", ) - # 2. Prepare distributed model - model = torch.nn.Linear(32, 2) + # 2. 
Prepare the model + model = torch.nn.Sequential( + torch.nn.Linear(1, 1), + torch.nn.ReLU(), + torch.nn.Linear(1, 1), + ) # 3. Setup distributed training - if torch.cuda.is_available(): - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - - model = model.to(device) - model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None) + device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu") + model = DistributedDataParallel(model.to(device), device_ids=[device.index] if torch.cuda.is_available() else None) # 4. Prepare loss and optimizer criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - # 5. Train the model for 50 steps. - for step in range(50): + # 5. Train the model for 1000 steps. + for step in range(1000): model.zero_grad() - x = torch.randn(64, 32).to(device) + x = torch.tensor([0.8]).to(device) + target = torch.tensor([1.0]).to(device) output = model(x) - loss = criterion(output, torch.ones_like(output)) + loss = criterion(output, target) print(f"global_rank: {global_rank} step: {step} loss: {loss}") loss.backward() optimizer.step() diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index 0ece52926a903..8354ec3aa5b94 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -6,10 +6,6 @@ class PyTorchDistributed(L.LightningWork): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.step = 0 - def run( self, world_size: int, @@ -17,33 +13,30 @@ def run( global_rank: str, local_rank: int, ): - # 1. Prepare distributed model - model = torch.nn.Linear(32, 2) + # 1. Prepare the model + model = torch.nn.Sequential( + torch.nn.Linear(1, 1), + torch.nn.ReLU(), + torch.nn.Linear(1, 1), + ) # 2. Setup distributed training - if torch.cuda.is_available(): - device = torch.device(f"cuda:{local_rank}") - torch.cuda.set_device(device) - else: - device = torch.device("cpu") - - model = model.to(device) - model = DistributedDataParallel(model, device_ids=[device.index] if torch.cuda.is_available() else None) + device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu") + model = DistributedDataParallel( + model.to(device), device_ids=[device.index] if torch.cuda.is_available() else None + ) # 3. Prepare loss and optimizer criterion = torch.nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - # 4. Train the model for 50 steps. - for step in range(50): - - # 4. Update step - self.step = step - + # 4. Train the model for 1000 steps. 
+ for step in range(1000): model.zero_grad() - x = torch.randn(64, 32).to(device) + x = torch.tensor([0.8]).to(device) + target = torch.tensor([1.0]).to(device) output = model(x) - loss = criterion(output, torch.ones_like(output)) + loss = criterion(output, target) print(f"global_rank: {global_rank} step: {step} loss: {loss}") loss.backward() optimizer.step() From 324c9fefff1ff35f0e524eccc59e2c76ac2c71d6 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:16:34 +0000 Subject: [PATCH 23/34] update --- examples/app_multi_node/train_lite.py | 3 +-- .../components/multi_node/pytorch_spawn.py | 19 +++++++------------ src/lightning_app/utilities/proxies.py | 12 ++++++++++++ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/app_multi_node/train_lite.py b/examples/app_multi_node/train_lite.py index 33341e0870d49..a5632366de08f 100644 --- a/examples/app_multi_node/train_lite.py +++ b/examples/app_multi_node/train_lite.py @@ -40,6 +40,5 @@ def run(self): LitePyTorchDistributed, cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 num_nodes=2, - ), - log_level="debug", + ) ) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 86eaf19f5527e..79dc0d9ab471e 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -34,18 +34,13 @@ def __call__( import torch # Remove the wrapper. - setattr_fn = self.work._setattr_replacement - self.work._setattr_replacement = None - - nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 - torch.multiprocessing.spawn( - self.run, - args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), - nprocs=nprocs, - ) - - # Re-attach the wrapper. 
- self.work._setattr_replacement = setattr_fn + with self.disable_setattr_wrapper(): + nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 + torch.multiprocessing.spawn( + self.run, + args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), + nprocs=nprocs, + ) @staticmethod def run( diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 8f652a92a372c..2395eb24ed6c0 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -7,6 +7,7 @@ import time import traceback import warnings +from contextlib import contextmanager from copy import deepcopy from dataclasses import dataclass, field from functools import partial @@ -350,6 +351,17 @@ class WorkRunExecutor: def __call__(self, *args, **kwargs): return self.work_run(*args, **kwargs) + @contextmanager + def disable_setattr_wrapper(self): + setattr_fn = self.work._setattr_replacement + self.work._setattr_replacement = None + try: + yield + except Exception as e: + self.work._setattr_replacement = setattr_fn + raise e + self.work._setattr_replacement = setattr_fn + @dataclass class WorkRunner: From 434a20f7cbf09aabbc81ce50dbabe15148af2a89 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:24:46 +0000 Subject: [PATCH 24/34] update --- src/lightning_app/components/multi_node/pytorch_spawn.py | 2 +- src/lightning_app/utilities/proxies.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 79dc0d9ab471e..b8ae9843c0cb3 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -34,7 +34,7 @@ def __call__( import torch # Remove the wrapper. - with self.disable_setattr_wrapper(): + with self.make_work_pickalable(): nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( self.run, diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 2395eb24ed6c0..e1d032271d978 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -352,15 +352,17 @@ def __call__(self, *args, **kwargs): return self.work_run(*args, **kwargs) @contextmanager - def disable_setattr_wrapper(self): + def enable_spawn(self): setattr_fn = self.work._setattr_replacement self.work._setattr_replacement = None + backend = self.work._backend try: yield except Exception as e: self.work._setattr_replacement = setattr_fn raise e self.work._setattr_replacement = setattr_fn + self.work._backend = backend @dataclass From 78c41468b04cbaa67a20272c72841fe7d33a9dc6 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:25:20 +0000 Subject: [PATCH 25/34] update --- src/lightning_app/components/multi_node/pytorch_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index b8ae9843c0cb3..f3792b1fd34c4 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -34,7 +34,7 @@ def __call__( import torch # Remove the wrapper. 
- with self.make_work_pickalable(): + with self.enable_spawn(): nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( self.run, From 061fdacf63e409415e6043b8a09f9af8efe18ec0 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 15:34:09 +0000 Subject: [PATCH 26/34] update --- examples/app_multi_node/train_lite.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/examples/app_multi_node/train_lite.py b/examples/app_multi_node/train_lite.py index a5632366de08f..80bda48c0e1f9 100644 --- a/examples/app_multi_node/train_lite.py +++ b/examples/app_multi_node/train_lite.py @@ -6,25 +6,21 @@ class LitePyTorchDistributed(L.LightningWork): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.step = 0 - def run(self): - # 1. Create LightningLite. + # 1. Prepare the model + model = torch.nn.Sequential( + torch.nn.Linear(1, 1), + torch.nn.ReLU(), + torch.nn.Linear(1, 1), + ) + + # 2. Create LightningLite. lite = LightningLite(strategy="ddp", precision=16) - - # 2. Prepare distributed model and optimizer. - model = torch.nn.Linear(32, 2) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - model, optimizer = lite.setup(model, optimizer) + model, optimizer = lite.setup(model, torch.optim.SGD(model.parameters(), lr=0.01)) criterion = torch.nn.MSELoss() - # 3. Train the model for 50 steps. - for step in range(50): - - self.step = step - + # 3. Train the model for 1000 steps. + for step in range(1000): model.zero_grad() x = torch.randn(64, 32).to(lite.device) output = model(x) From 831b19e34de4e84b8d199cb7af616114cf1e58f3 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 16:13:53 +0000 Subject: [PATCH 27/34] update --- examples/app_multi_node/train_pytorch.py | 2 +- examples/app_multi_node/train_pytorch_spawn.py | 2 +- src/lightning_app/components/multi_node/pytorch_spawn.py | 1 - src/lightning_app/utilities/proxies.py | 4 +++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/app_multi_node/train_pytorch.py b/examples/app_multi_node/train_pytorch.py index e1e8b5bb5c6db..9ce662fa40009 100644 --- a/examples/app_multi_node/train_pytorch.py +++ b/examples/app_multi_node/train_pytorch.py @@ -27,7 +27,7 @@ def distributed_train(local_rank: int, main_address: str, main_port: int, num_no # 3. Setup distributed training device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu") - model = DistributedDataParallel(model.to(device), device_ids=[device.index] if torch.cuda.is_available() else None) + model = DistributedDataParallel(model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None) # 4. Prepare loss and optimizer criterion = torch.nn.MSELoss() diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index 8354ec3aa5b94..d29ec83562ffb 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -23,7 +23,7 @@ def run( # 2. Setup distributed training device = torch.device(f"cuda:{local_rank}") if torch.cuda.is_available() else torch.device("cpu") model = DistributedDataParallel( - model.to(device), device_ids=[device.index] if torch.cuda.is_available() else None + model.to(device), device_ids=[local_rank] if torch.cuda.is_available() else None ) # 3. 
Prepare loss and optimizer diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index f3792b1fd34c4..3da1384fb68c1 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -33,7 +33,6 @@ def __call__( ): import torch - # Remove the wrapper. with self.enable_spawn(): nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 torch.multiprocessing.spawn( diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index e1d032271d978..63492fa14e015 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -356,10 +356,12 @@ def enable_spawn(self): setattr_fn = self.work._setattr_replacement self.work._setattr_replacement = None backend = self.work._backend + self.work._backend = None try: yield except Exception as e: self.work._setattr_replacement = setattr_fn + self.work._backend = backend raise e self.work._setattr_replacement = setattr_fn self.work._backend = backend @@ -519,7 +521,7 @@ def run_once(self): used_runpy = True if user_exception: trace.append(p) - if "ret = self.run_executor_cls(self.work, work_run)(*args, **kwargs)" in p: + if "ret = self.run_executor_cls(" in p: user_exception = True if used_runpy: From f2d60909342a79a0dbed40ac16e180d2b33ca773 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 16:17:41 +0000 Subject: [PATCH 28/34] update --- tests/tests_app_examples/custom_work_dependencies/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_app_examples/custom_work_dependencies/app.py b/tests/tests_app_examples/custom_work_dependencies/app.py index cb993a7c263ec..c27d91ab142b6 100644 --- a/tests/tests_app_examples/custom_work_dependencies/app.py +++ b/tests/tests_app_examples/custom_work_dependencies/app.py @@ -51,4 +51,4 @@ def run(self): self._exit() -app = LightningApp(CustomWorkBuildConfigChecker(), log_level="debug") +app = LightningApp(CustomWorkBuildConfigChecker()) From c8d727f6e2ac4b95d39892da3e85a9a06522ae89 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 17:49:21 +0000 Subject: [PATCH 29/34] update --- .../app_multi_node/train_pytorch_spawn.py | 5 ++- .../components/multi_node/lite.py | 17 ++------- .../components/multi_node/pytorch_spawn.py | 38 ++++++++++++------- .../components/multi_node/trainer.py | 17 ++------- src/lightning_app/core/queues.py | 33 ++++++++++++++-- src/lightning_app/utilities/proxies.py | 33 ++++++++++------ 6 files changed, 83 insertions(+), 60 deletions(-) diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index d29ec83562ffb..949a8ac81a186 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -47,6 +47,7 @@ def run( PyTorchSpawnMultiNode( PyTorchDistributed, num_nodes=2, - cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 - ) + cloud_compute=L.CloudCompute("cpu"), # 4 x V100 + ), + log_level="debug", ) diff --git a/src/lightning_app/components/multi_node/lite.py b/src/lightning_app/components/multi_node/lite.py index c3863f5ea9f1b..2a9b33b0880d1 100644 --- a/src/lightning_app/components/multi_node/lite.py +++ b/src/lightning_app/components/multi_node/lite.py @@ -1,6 +1,6 @@ import os from dataclasses import dataclass -from typing import Any, Type +from typing import Any, Callable, Type from 
typing_extensions import Protocol, runtime_checkable @@ -8,7 +8,6 @@ from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute -from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.utilities.tracer import Tracer @@ -24,20 +23,13 @@ class _LiteRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, - work: "LightningWork", - delta_queue, + work_run: Callable, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): - if local_rank == 0: - state_observer = WorkStateObserver(work, delta_queue=delta_queue) - state_observer.start() - _proxy_setattr(work, delta_queue, state_observer) - pass - from lightning.lite import LightningLite from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy @@ -75,12 +67,9 @@ def pre_fn(lite, *args, **kwargs): tracer = Tracer() tracer.add_traced(LightningLite, "__init__", pre_fn=pre_fn) tracer._instrument() - unwrap(work.run)() + work_run() tracer._restore() - if local_rank == 0: - state_observer.join(0) - class LiteMultiNode(MultiNode): def __init__( diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index 3da1384fb68c1..dc25bfd1bdaca 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -1,8 +1,9 @@ -from typing import Any, Type +from typing import Any, Callable, Type from typing_extensions import Protocol, runtime_checkable from lightning_app.components.multi_node.base import MultiNode +from lightning_app.core.queues import MultiProcessQueue from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver @@ -35,28 +36,40 @@ def __call__( with self.enable_spawn(): nprocs = torch.cuda.device_count() if torch.cuda.is_available() else 1 + queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict() torch.multiprocessing.spawn( - self.run, - args=(self.work, self.delta_queue, main_address, main_port, num_nodes, node_rank, nprocs), + self.dispatch_run, + args=(self.__class__, self.work, queue, main_address, main_port, num_nodes, node_rank, nprocs), nprocs=nprocs, ) + @staticmethod + def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs): + if local_rank == 0: + if isinstance(delta_queue, dict): + delta_queue = WorkRunExecutor.process_queue(delta_queue) + work._request_queue = WorkRunExecutor.process_queue(work._request_queue) + work._response_queue = WorkRunExecutor.process_queue(work._response_queue) + + state_observer = WorkStateObserver(work, delta_queue=delta_queue) + state_observer.start() + _proxy_setattr(work, delta_queue, state_observer) + + cls.run(local_rank, unwrap(work.run), *args, **kwargs) + + if local_rank == 0: + state_observer.join(0) + @staticmethod def run( local_rank: int, - work: "LightningWork", - delta_queue, + work_run: Callable, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): - if local_rank == 0: - state_observer = WorkStateObserver(work, delta_queue=delta_queue) - state_observer.start() - _proxy_setattr(work, delta_queue, state_observer) - pass import torch 
@@ -75,10 +88,7 @@ def run( elif world_size > 1: raise Exception("Torch distributed should be available.") - unwrap(work.run)(world_size, node_rank, global_rank, local_rank) - - if local_rank == 0: - state_observer.join(0) + work_run(world_size, node_rank, global_rank, local_rank) class PyTorchSpawnMultiNode(MultiNode): diff --git a/src/lightning_app/components/multi_node/trainer.py b/src/lightning_app/components/multi_node/trainer.py index ec48cb7a4a718..222f71ce59557 100644 --- a/src/lightning_app/components/multi_node/trainer.py +++ b/src/lightning_app/components/multi_node/trainer.py @@ -1,6 +1,6 @@ import os from dataclasses import dataclass -from typing import Any, Type +from typing import Any, Callable, Type from typing_extensions import Protocol, runtime_checkable @@ -8,7 +8,6 @@ from lightning_app.components.multi_node.pytorch_spawn import _PyTorchSpawnRunExecutor from lightning_app.core.work import LightningWork from lightning_app.utilities.packaging.cloud_compute import CloudCompute -from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkStateObserver from lightning_app.utilities.tracer import Tracer @@ -24,20 +23,13 @@ class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, - work: "LightningWork", - delta_queue, + work_run: Callable, main_address: str, main_port: int, num_nodes: int, node_rank: int, nprocs: int, ): - if local_rank == 0: - state_observer = WorkStateObserver(work, delta_queue=delta_queue) - state_observer.start() - _proxy_setattr(work, delta_queue, state_observer) - pass - from lightning.lite.strategies import DDPSpawnShardedStrategy, DDPSpawnStrategy from lightning.pytorch import Trainer as LTrainer from pytorch_lightning import Trainer as PLTrainer @@ -74,12 +66,9 @@ def pre_fn(trainer, *args, **kwargs): tracer.add_traced(PLTrainer, "__init__", pre_fn=pre_fn) tracer.add_traced(LTrainer, "__init__", pre_fn=pre_fn) tracer._instrument() - unwrap(work.run)() + work_run() tracer._restore() - if local_rank == 0: - state_observer.join(0) - class LightningTrainerMultiNode(MultiNode): def __init__( diff --git a/src/lightning_app/core/queues.py b/src/lightning_app/core/queues.py index 5d8f4e06ad429..a7fee9a3b6e12 100644 --- a/src/lightning_app/core/queues.py +++ b/src/lightning_app/core/queues.py @@ -235,12 +235,12 @@ def __init__( """ if name is None: raise ValueError("You must specify a name for the queue") - host = host or REDIS_HOST - port = port or REDIS_PORT - password = password or REDIS_PASSWORD + self.host = host or REDIS_HOST + self.port = port or REDIS_PORT + self.password = password or REDIS_PASSWORD self.name = name self.default_timeout = default_timeout - self.redis = redis.Redis(host=host, port=port, password=password) + self.redis = redis.Redis(host=self.host, port=self.port, password=self.password) def put(self, item: Any) -> None: from lightning_app import LightningWork @@ -329,6 +329,20 @@ def is_running(self) -> bool: except redis.exceptions.ConnectionError: return False + def to_dict(self): + return { + "type": "redis", + "name": self.name, + "default_timeout": self.default_timeout, + "host": self.host, + "port": self.port, + "password": self.password, + } + + @classmethod + def from_dict(cls, state): + return cls(**state) + class HTTPQueue(BaseQueue): def __init__(self, name: str, default_timeout: float): @@ -414,6 +428,17 @@ def _split_app_id_and_queue_name(queue_name): app_id, queue_name = queue_name.split("_", 1) return app_id, queue_name + def to_dict(self): + return { + "type": 
"http", + "name": self.name, + "default_timeout": self.default_timeout, + } + + @classmethod + def from_dict(cls, state): + return cls(**state) + def debug_log_callback(message: str, *args: Any, **kwargs: Any) -> None: if QUEUE_DEBUG_ENABLED or (Path(LIGHTNING_DIR) / "QUEUE_DEBUG_ENABLED").exists(): diff --git a/src/lightning_app/utilities/proxies.py b/src/lightning_app/utilities/proxies.py index 63492fa14e015..5ec99f222e47d 100644 --- a/src/lightning_app/utilities/proxies.py +++ b/src/lightning_app/utilities/proxies.py @@ -12,7 +12,7 @@ from dataclasses import dataclass, field from functools import partial from threading import Event, Thread -from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Optional, Set, Tuple, Type, TYPE_CHECKING, Union from deepdiff import DeepDiff, Delta from lightning_utilities.core.apply_func import apply_to_collection @@ -352,19 +352,28 @@ def __call__(self, *args, **kwargs): return self.work_run(*args, **kwargs) @contextmanager - def enable_spawn(self): - setattr_fn = self.work._setattr_replacement + def enable_spawn(self) -> Generator: self.work._setattr_replacement = None - backend = self.work._backend self.work._backend = None - try: - yield - except Exception as e: - self.work._setattr_replacement = setattr_fn - self.work._backend = backend - raise e - self.work._setattr_replacement = setattr_fn - self.work._backend = backend + self._clean_queues() + yield + + def _clean_queues(self): + if "LIGHTNING_APP_STATE_URL" in os.environ: + self.work._request_queue = self.work._request_queue.to_dict() + self.work._response_queue = self.work._response_queue.to_dict() + + @staticmethod + def process_queue(queue): + from lightning_app.core.queues import HTTPQueue, RedisQueue + + if isinstance(queue, dict): + queue_type = queue.pop("type") + if queue_type == "redis": + return RedisQueue.from_dict(queue) + else: + return HTTPQueue.from_dict(queue) + return queue @dataclass From 2f8cc1ad81c0f4fdd3a2443d77e165cc032e7641 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 18:05:22 +0000 Subject: [PATCH 30/34] update --- MANIFEST.in | 7 +++++++ examples/app_multi_node/train_pytorch_spawn.py | 3 +-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 10af40c3dd1cf..7453da2135ca1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,3 +4,10 @@ exclude __pycache__ include .actions/setup_tools.py include .actions/assistant.py include *.cff # citation info +recursive-include src/lightning *.md +recursive-include requirements *.txt +recursive-include src/lightning/app/ui * +recursive-include src/lightning/cli/*-template * +prune src/lightning_app +prune src/lightning_lite +prune src/pytorch_lightning diff --git a/examples/app_multi_node/train_pytorch_spawn.py b/examples/app_multi_node/train_pytorch_spawn.py index 949a8ac81a186..c09978dbb4d0f 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -48,6 +48,5 @@ def run( PyTorchDistributed, num_nodes=2, cloud_compute=L.CloudCompute("cpu"), # 4 x V100 - ), - log_level="debug", + ) ) From 830ac249235883fba182a60220233e6b10242785 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 18:06:19 +0000 Subject: [PATCH 31/34] update --- examples/app_multi_node/train_pytorch_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/app_multi_node/train_pytorch_spawn.py 
b/examples/app_multi_node/train_pytorch_spawn.py index c09978dbb4d0f..d29ec83562ffb 100644 --- a/examples/app_multi_node/train_pytorch_spawn.py +++ b/examples/app_multi_node/train_pytorch_spawn.py @@ -47,6 +47,6 @@ def run( PyTorchSpawnMultiNode( PyTorchDistributed, num_nodes=2, - cloud_compute=L.CloudCompute("cpu"), # 4 x V100 + cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 ) ) From e51d57e707d3128ae22579b7713e4413811c9e53 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 18:06:54 +0000 Subject: [PATCH 32/34] update --- src/lightning_app/components/multi_node/pytorch_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_app/components/multi_node/pytorch_spawn.py b/src/lightning_app/components/multi_node/pytorch_spawn.py index dc25bfd1bdaca..3119ffc51e0b5 100644 --- a/src/lightning_app/components/multi_node/pytorch_spawn.py +++ b/src/lightning_app/components/multi_node/pytorch_spawn.py @@ -47,9 +47,9 @@ def __call__( def dispatch_run(local_rank, cls, work, delta_queue, *args, **kwargs): if local_rank == 0: if isinstance(delta_queue, dict): - delta_queue = WorkRunExecutor.process_queue(delta_queue) - work._request_queue = WorkRunExecutor.process_queue(work._request_queue) - work._response_queue = WorkRunExecutor.process_queue(work._response_queue) + delta_queue = cls.process_queue(delta_queue) + work._request_queue = cls.process_queue(work._request_queue) + work._response_queue = cls.process_queue(work._response_queue) state_observer = WorkStateObserver(work, delta_queue=delta_queue) state_observer.start() From 1402f34f3dbfaf66804f6b38525590ecc863a8d2 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 18:18:02 +0000 Subject: [PATCH 33/34] update --- examples/app_multi_node/train_lite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/app_multi_node/train_lite.py b/examples/app_multi_node/train_lite.py index 80bda48c0e1f9..8e546b270a693 100644 --- a/examples/app_multi_node/train_lite.py +++ b/examples/app_multi_node/train_lite.py @@ -22,9 +22,10 @@ def run(self): # 3. Train the model for 1000 steps. for step in range(1000): model.zero_grad() - x = torch.randn(64, 32).to(lite.device) + x = torch.tensor([0.8]).to(lite.device) + target = torch.tensor([1.0]).to(lite.device) output = model(x) - loss = criterion(output, torch.ones_like(output)) + loss = criterion(output, target) print(f"global_rank: {lite.global_rank} step: {step} loss: {loss}") lite.backward(loss) optimizer.step() From 376a4a3972310d0464ba73c9d009b0b858684dd2 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 10 Nov 2022 18:27:14 +0000 Subject: [PATCH 34/34] update --- MANIFEST.in | 7 ------- 1 file changed, 7 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 7453da2135ca1..10af40c3dd1cf 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,10 +4,3 @@ exclude __pycache__ include .actions/setup_tools.py include .actions/assistant.py include *.cff # citation info -recursive-include src/lightning *.md -recursive-include requirements *.txt -recursive-include src/lightning/app/ui * -recursive-include src/lightning/cli/*-template * -prune src/lightning_app -prune src/lightning_lite -prune src/pytorch_lightning
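
Note on the queue-serialization pattern introduced in patch 29/34 above: live queue objects (RedisQueue, HTTPQueue) hold connections that cannot safely cross a torch.multiprocessing.spawn boundary, so the executor now converts them to plain dicts with to_dict(), ships the dicts to the spawned workers, and local rank 0 rebuilds real queues via from_dict()/process_queue() before attaching the WorkStateObserver. The sketch below is only an illustration of that pattern in plain Python, not the library's implementation; FileBackedQueue and _child are hypothetical names invented for this example.

# Minimal sketch, assuming only the to_dict()/from_dict() round-trip idea from the patches above.
import multiprocessing as mp


class FileBackedQueue:
    """Toy stand-in for a queue that owns an unpicklable live handle."""

    def __init__(self, path: str):
        self.path = path
        self._handle = open(path, "a+")  # live handle -> not safely picklable across spawn

    def put(self, item: str) -> None:
        self._handle.write(item + "\n")
        self._handle.flush()

    def to_dict(self) -> dict:
        # Only plain, picklable configuration crosses the process boundary.
        return {"type": "file", "path": self.path}

    @classmethod
    def from_dict(cls, state: dict) -> "FileBackedQueue":
        state = dict(state)
        state.pop("type", None)
        return cls(**state)


def _child(local_rank: int, queue_state: dict) -> None:
    # Mirrors dispatch_run: only rank 0 rebuilds the queue from its serialized
    # form and uses it to push state deltas back to the parent side.
    if local_rank == 0:
        queue = FileBackedQueue.from_dict(queue_state)
        queue.put(f"hello from rank {local_rank}")


if __name__ == "__main__":
    parent_queue = FileBackedQueue("deltas.log")
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=_child, args=(rank, parent_queue.to_dict())) for rank in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()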