Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions examples/app_argparse/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import argparse

import lightning as L


class Work(L.LightningWork):
def __init__(self, cloud_compute):
super().__init__(cloud_compute=cloud_compute)

def run(self):
pass


class Flow(L.LightningFlow):
def __init__(self, cloud_compute):
super().__init__()
self.work = Work(cloud_compute)

def run(self):
assert self.work.cloud_compute.name == "gpu", self.work.cloud_compute.name
self._exit()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--use_gpu", action="store_true", default=False, help="Whether to use GPU in the cloud")
hparams = parser.parse_args()
app = L.LightningApp(Flow(L.CloudCompute("gpu" if hparams.use_gpu else "cpu")))
19 changes: 11 additions & 8 deletions src/lightning_app/cli/lightning_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
from pathlib import Path
from typing import Tuple, Union
from typing import List, Tuple, Union

import click
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -109,8 +109,17 @@ def run():
@click.option("--blocking", "blocking", type=bool, default=False)
@click.option("--open-ui", type=bool, default=True, help="Decide whether to launch the app UI in a web browser")
@click.option("--env", type=str, default=[], multiple=True, help="Env variables to be set for the app.")
@click.option("--app_args", type=str, default=[], multiple=True, help="Collection of arguments for the app.")
def run_app(
file: str, cloud: bool, without_server: bool, no_cache: bool, name: str, blocking: bool, open_ui: bool, env: tuple
file: str,
cloud: bool,
without_server: bool,
no_cache: bool,
name: str,
blocking: bool,
open_ui: bool,
env: tuple,
app_args: List[str],
):
"""Run an app from a file."""
_run_app(file, cloud, without_server, no_cache, name, blocking, open_ui, env)
Expand Down Expand Up @@ -263,10 +272,4 @@ def _prepare_file(file: str) -> str:
if exists:
return file

if not exists and file == "quick_start.py":
from lightning_app.demo.quick_start import app

logger.info(f"For demo purposes, Lightning will run the {app.__file__} file.")
return app.__file__

raise FileNotFoundError(f"The provided file {file} hasn't been found.")
10 changes: 8 additions & 2 deletions src/lightning_app/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,20 @@ def restart_work(self, work_name: str):


@requires("click")
def application_testing(lightning_app_cls: Type[LightningTestApp], command_line: List[str] = []) -> Any:
def application_testing(
lightning_app_cls: Type[LightningTestApp] = LightningTestApp, command_line: List[str] = []
) -> Any:
from unittest import mock

from click.testing import CliRunner

with mock.patch("lightning.LightningApp", lightning_app_cls):
original = sys.argv
sys.argv = command_line
runner = CliRunner()
return runner.invoke(run_app, command_line, catch_exceptions=False)
result = runner.invoke(run_app, command_line, catch_exceptions=False)
sys.argv = original
return result


class SingleWorkFlow(LightningFlow):
Expand Down
46 changes: 45 additions & 1 deletion src/lightning_app/utilities/load_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import traceback
import types
from contextlib import contextmanager
from typing import Dict, List, TYPE_CHECKING, Union

from lightning_app.utilities.exceptions import MisconfigurationException
Expand All @@ -26,7 +27,8 @@ def load_app_from_file(filepath: str) -> "LightningApp":
code = _create_code(filepath)
module = _create_fake_main_module(filepath)
try:
exec(code, module.__dict__)
with _patch_sys_argv():
exec(code, module.__dict__)
except Exception:
# we want to format the exception as if no frame was on top.
exp, val, tb = sys.exc_info()
Expand Down Expand Up @@ -113,6 +115,48 @@ def _create_fake_main_module(script_path):
return module


@contextmanager
def _patch_sys_argv():
"""This function modifies the ``sys.argv`` by extracting the arguments after ``--app_args`` and removed
everything else before executing the user app script.

The command: ``lightning run app app.py --without-server --app_args --use_gpu --env ...`` will be converted into
``app.py --use_gpu``
"""
from lightning_app.cli.lightning_cli import run_app

original_argv = sys.argv
# 1: Remove the CLI command
if sys.argv[:3] == ["lightning", "run", "app"]:
sys.argv = sys.argv[3:]

if "--app_args" not in sys.argv:
# 2: If app_args wasn't used, there is no arguments, so we assign the shorten arguments.
new_argv = sys.argv[:1]
else:
# 3: Collect all the arguments from the CLI
options = [p.opts[0] for p in run_app.params[1:] if p.opts[0] != "--app_args"]
argv_slice = sys.argv
# 4: Find the index of `app_args`
first_index = argv_slice.index("--app_args") + 1
# 5: Find the next argument from the CLI if any.
matches = [
argv_slice.index(opt) for opt in options if opt in argv_slice and argv_slice.index(opt) >= first_index
]
if not matches:
last_index = len(argv_slice)
else:
last_index = min(matches)
# 6: last_index is either the fully command or the latest match from the CLI options.
new_argv = [argv_slice[0]] + argv_slice[first_index:last_index]

# 7: Patch the command
sys.argv = new_argv
yield
# 8: Restore the command
sys.argv = original_argv


def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict:
from lightning_app import LightningWork

Expand Down
67 changes: 67 additions & 0 deletions tests/tests_app_examples/test_argparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import sys

from lightning_app import _PACKAGE_ROOT
from lightning_app.testing.testing import application_testing
from lightning_app.utilities.load_app import _patch_sys_argv


def test_app_argparse_example():
original_argv = sys.argv

command_line = [
os.path.join(os.path.dirname(os.path.dirname(_PACKAGE_ROOT)), "examples/app_argparse/app.py"),
"--app_args",
"--use_gpu",
"--without-server",
]
result = application_testing(command_line=command_line)
assert result.exit_code == 0, result.__dict__
assert sys.argv == original_argv


def test_patch_sys_argv():
original_argv = sys.argv

sys.argv = expected = ["lightning", "run", "app", "app.py"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]

assert sys.argv == expected

sys.argv = expected = ["lightning", "run", "app", "app.py", "--without-server", "--env", "name=something"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]

assert sys.argv == expected

sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]

assert sys.argv == expected

sys.argv = expected = ["lightning", "run", "app", "app.py", "--app_args", "--env", "name=something"]
with _patch_sys_argv():
assert sys.argv == ["app.py"]

assert sys.argv == expected

sys.argv = expected = [
"lightning",
"run",
"app",
"app.py",
"--without-server",
"--app_args",
"--use_gpu",
"--name=hello",
"--env",
"name=something",
]
with _patch_sys_argv():
assert sys.argv == ["app.py", "--use_gpu", "--name=hello"]

assert sys.argv == expected

sys.argv = original_argv