Skip to content

Commit

Permalink
Ignore num_nodes when running MultiNode components locally (#15806)
Browse files Browse the repository at this point in the history
(cherry picked from commit a970f09)
  • Loading branch information
awaelchli authored and Borda committed Nov 30, 2022
1 parent a33d9b8 commit d9f6317
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806))


### Deprecated
Expand Down
15 changes: 13 additions & 2 deletions src/lightning_app/components/multi_node/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import warnings
from typing import Any, Type

from lightning_app import structures
from lightning_app.core.flow import LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.utilities.cloud import is_running_in_cloud
from lightning_app.utilities.packaging.cloud_compute import CloudCompute


Expand Down Expand Up @@ -45,12 +47,21 @@ def run(
Arguments:
work_cls: The work to be executed
num_nodes: Number of nodes.
cloud_compute: The cloud compute object used in the cloud.
num_nodes: Number of nodes. Gets ignored when running locally. Launch the app with --cloud to run on
multiple cloud machines.
cloud_compute: The cloud compute object used in the cloud. The value provided here gets ignored when
running locally.
work_args: Arguments to be provided to the work on instantiation.
work_kwargs: Keywords arguments to be provided to the work on instantiation.
"""
super().__init__()
if num_nodes > 1 and not is_running_in_cloud():
num_nodes = 1
warnings.warn(
f"You set {type(self).__name__}(num_nodes={num_nodes}, ...)` but this app is running locally."
" We assume you are debugging and will ignore the `num_nodes` argument."
" To run on multiple nodes in the cloud, launch your app with `--cloud`."
)
self.ws = structures.List(
*[
work_cls(
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_app/components/multi_node/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from re import escape

import pytest
from tests_app.helpers.utils import no_warning_call

from lightning_app import CloudCompute, LightningWork
from lightning_app.components import MultiNode


def test_multi_node_warn_running_locally():
class Work(LightningWork):
def run(self):
pass

with pytest.warns(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))

with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))
Empty file.
30 changes: 30 additions & 0 deletions tests/tests_app/helpers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import re
from contextlib import contextmanager
from typing import Optional, Type

import pytest


@contextmanager
def no_warning_call(expected_warning: Type[Warning] = UserWarning, match: Optional[str] = None):
# TODO: Replace with `lightning_utilities.test.warning.no_warning_call`
# https://github.com/Lightning-AI/utilities/issues/57

with pytest.warns(None) as record:
yield

if match is None:
try:
w = record.pop(expected_warning)
except AssertionError:
# no warning raised
return
else:
for w in record.list:
if w.category is expected_warning and re.compile(match).search(w.message.args[0]):
break
else:
return

msg = "A warning" if expected_warning is None else f"`{expected_warning.__name__}`"
raise AssertionError(f"{msg} was raised: {w}")
7 changes: 5 additions & 2 deletions tests/tests_examples_app/public/test_multi_node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
from unittest import mock

import pytest
from tests_examples_app.public import _PATH_EXAMPLES
Expand All @@ -17,7 +18,8 @@ def on_before_run_once(self):


@pytest.mark.skip(reason="flaky")
def test_multi_node_example(monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_example(_, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
"app.py",
Expand Down Expand Up @@ -50,7 +52,8 @@ def on_before_run_once(self):
],
)
@pytest.mark.skipif(sys.platform == "win32", reason="flaky")
def test_multi_node_examples(app_name, monkeypatch):
@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", return_value=True)
def test_multi_node_examples(_, app_name, monkeypatch):
monkeypatch.chdir(os.path.join(_PATH_EXAMPLES, "app_multi_node"))
command_line = [
app_name,
Expand Down

0 comments on commit d9f6317

Please sign in to comment.