diff --git a/pyproject.toml b/pyproject.toml
index c0bac18466c85..5508e98981877 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -141,6 +141,7 @@ exclude = [
     "src/lightning/app/cli/component-template",
     "src/lightning/app/cli/pl-app-template",
     "src/lightning/app/cli/react-ui-template",
+    "src/lightning/app/launcher",
 ]
 install_types = "True"
 non_interactive = "True"
diff --git a/src/lightning/app/cli/lightning_cli.py b/src/lightning/app/cli/lightning_cli.py
index 43c9e82ff477f..9bf8877fac975 100644
--- a/src/lightning/app/cli/lightning_cli.py
+++ b/src/lightning/app/cli/lightning_cli.py
@@ -39,6 +39,7 @@
 )
 from lightning.app.cli.connect.data import connect_data
 from lightning.app.cli.lightning_cli_delete import delete
+from lightning.app.cli.lightning_cli_launch import launch
 from lightning.app.cli.lightning_cli_list import get_list
 from lightning.app.core.constants import ENABLE_APP_COMMENT_COMMAND_EXECUTION, get_lightning_cloud_url
 from lightning.app.runners.cloud import CloudRuntime
@@ -324,6 +325,7 @@ def open(path: str, name: str) -> None:
 _main.add_command(get_list)
 _main.add_command(delete)
+_main.add_command(launch)
 _main.add_command(cmd_install.install)
diff --git a/src/lightning/app/cli/lightning_cli_launch.py b/src/lightning/app/cli/lightning_cli_launch.py
new file mode 100644
index 0000000000000..8cf56453d86f9
--- /dev/null
+++ b/src/lightning/app/cli/lightning_cli_launch.py
@@ -0,0 +1,127 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Tuple
+
+import click
+
+from lightning.app.core.constants import APP_SERVER_HOST, APP_SERVER_PORT
+from lightning.app.launcher.launcher import (
+    run_lightning_flow,
+    run_lightning_work,
+    serve_frontend,
+    start_application_server,
+    start_flow_and_servers,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@click.group(name="launch", hidden=True)
+def launch() -> None:
+    """Launch your application."""
+
+
+@launch.command("server", hidden=True)
+@click.argument("file", type=click.Path(exists=True))
+@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
+@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
+@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
+def run_server(file: str, queue_id: str, host: str, port: int) -> None:
+    """Takes the application file as input, builds the application object, and uses it to run the
+    application server.
+
+    This is used by the cloud runners to start the status server for the application.
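+
+    Example (hypothetical values; in practice the cloud runner supplies the queue id and port)::
+
+        lightning launch server app.py --queue-id 1234 --host 0.0.0.0 --port 7501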
+    """
+    logger.debug(f"Run Server: {file} {queue_id} {host} {port}")
+    start_application_server(file, host, port, queue_id=queue_id)
+
+
+@launch.command("flow", hidden=True)
+@click.argument("file", type=click.Path(exists=True))
+@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
+@click.option("--base-url", help="Base url at which the app server is hosted", default="")
+def run_flow(file: str, queue_id: str, base_url: str) -> None:
+    """Takes the application file as input, builds the application object, proxies all the work components, and
+    then runs the application flow defined in the root component.
+
+    It does exactly what a single-process dispatcher would do, but with proxied work components.
+    """
+    logger.debug(f"Run Flow: {file} {queue_id} {base_url}")
+    run_lightning_flow(file, queue_id=queue_id, base_url=base_url)
+
+
+@launch.command("work", hidden=True)
+@click.argument("file", type=click.Path(exists=True))
+@click.option("--work-name", type=str)
+@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
+def run_work(file: str, work_name: str, queue_id: str) -> None:
+    """Unlike the other entrypoints, this command takes the file path or module details for a work component and
+    runs it by fetching its state from the queues."""
+    logger.debug(f"Run Work: {file} {work_name} {queue_id}")
+    run_lightning_work(
+        file=file,
+        work_name=work_name,
+        queue_id=queue_id,
+    )
+
+
+@launch.command("frontend", hidden=True)
+@click.argument("file", type=click.Path(exists=True))
+@click.option("--flow-name")
+@click.option("--host")
+@click.option("--port", type=int)
def run_frontend(file: str, flow_name: str, host: str, port: int) -> None:
+    """Serve the frontend specified by the given flow."""
+    logger.debug(f"Run Frontend: {file} {flow_name} {host} {port}")
+    serve_frontend(file=file, flow_name=flow_name, host=host, port=port)
+
+
+@launch.command("flow-and-servers", hidden=True)
+@click.argument("file", type=click.Path(exists=True))
+@click.option("--queue-id", help="ID for identifying queue", default="", type=str)
+@click.option("--base-url", help="Base url at which the app server is hosted", default="")
+@click.option("--host", help="Application running host", default=APP_SERVER_HOST, type=str)
+@click.option("--port", help="Application running port", default=APP_SERVER_PORT, type=int)
+@click.option(
+    "--flow-port",
+    help="Pair of flow name and frontend port",
+    type=(str, int),
+    multiple=True,
+)
+def run_flow_and_servers(
+    file: str,
+    base_url: str,
+    queue_id: str,
+    host: str,
+    port: int,
+    flow_port: Tuple[Tuple[str, int]],
+) -> None:
+    """Takes the application file as input, builds the application object, and uses it to run the
+    application flow defined in the root component, the application server, and all the flow frontends.
+ + This is used by the cloud runners to start the flow, the status server and all frontends for the application + """ + logger.debug(f"Run Flow: {file} {queue_id} {base_url}") + logger.debug(f"Run Server: {file} {queue_id} {host} {port}.") + logger.debug(f"Run Frontend's: {flow_port}") + start_flow_and_servers( + entrypoint_file=file, + base_url=base_url, + queue_id=queue_id, + host=host, + port=port, + flow_names_and_ports=flow_port, + ) diff --git a/src/lightning/app/launcher/__init__.py b/src/lightning/app/launcher/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/lightning/app/launcher/launcher.py b/src/lightning/app/launcher/launcher.py new file mode 100644 index 0000000000000..8f00731161dfc --- /dev/null +++ b/src/lightning/app/launcher/launcher.py @@ -0,0 +1,439 @@ +import inspect +import logging +import os +import signal +import sys +import time +import traceback +from functools import partial +from multiprocessing import Process +from typing import Callable, Dict, List, Optional, Tuple, TypedDict + +ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER = bool(int(os.getenv("ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER", "0"))) + +if True: # Avoid Module level import not at top of file + from lightning.app import LightningFlow + from lightning.app.core import constants + from lightning.app.core.api import start_server + from lightning.app.core.queues import MultiProcessQueue, QueuingSystem + from lightning.app.storage.orchestrator import StorageOrchestrator + from lightning.app.utilities.app_commands import run_app_commands + from lightning.app.utilities.cloud import _sigterm_flow_handler + from lightning.app.utilities.component import _set_flow_context, _set_frontend_context + from lightning.app.utilities.enum import AppStage + from lightning.app.utilities.exceptions import ExitAppException + from lightning.app.utilities.load_app import extract_metadata_from_app, load_app_from_file + from lightning.app.utilities.proxies import WorkRunner + from lightning.app.utilities.redis import check_if_redis_running + +if ENABLE_MULTIPLE_WORKS_IN_DEFAULT_CONTAINER: + from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend as CloudBackend +else: + from lightning.app.launcher.lightning_backend import CloudBackend + +if True: # Avoid Module level import not at top of file + from lightning.app.utilities.app_helpers import convert_print_to_logger_info + from lightning.app.utilities.packaging.lightning_utils import enable_debugging + +if hasattr(constants, "get_cloud_queue_type"): + CLOUD_QUEUE_TYPE = constants.get_cloud_queue_type() or "redis" +else: + CLOUD_QUEUE_TYPE = "redis" + +logger = logging.getLogger(__name__) + + +class FlowRestAPIQueues(TypedDict): + api_publish_state_queue: MultiProcessQueue + api_response_queue: MultiProcessQueue + + +@convert_print_to_logger_info +@enable_debugging +def start_application_server( + entrypoint_file: str, host: str, port: int, queue_id: str, queues: Optional[FlowRestAPIQueues] = None +): + logger.debug(f"Run Lightning Work {entrypoint_file} {host} {port} {queue_id}") + queue_system = QueuingSystem(CLOUD_QUEUE_TYPE) + + wait_for_queues(queue_system) + + kwargs = { + "api_delta_queue": queue_system.get_api_delta_queue(queue_id=queue_id), + } + + # Note: Override the queues if provided + if isinstance(queues, Dict): + kwargs.update(queues) + else: + kwargs.update( + { + "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id), + "api_response_queue": 
queue_system.get_api_response_queue(queue_id=queue_id), + } + ) + + app = load_app_from_file(entrypoint_file) + + from lightning.app.api.http_methods import _add_tags_to_api, _validate_api + from lightning.app.utilities.app_helpers import is_overridden + from lightning.app.utilities.commands.base import _commands_to_api, _prepare_commands + + apis = [] + if is_overridden("configure_api", app.root): + apis = app.root.configure_api() + _validate_api(apis) + _add_tags_to_api(apis, ["app_api"]) + + if is_overridden("configure_commands", app.root): + commands = _prepare_commands(app) + apis += _commands_to_api(commands) + + start_server( + host=host, + port=port, + apis=apis, + **kwargs, + spec=extract_metadata_from_app(app), + ) + + +@convert_print_to_logger_info +@enable_debugging +def run_lightning_work( + file: str, + work_name: str, + queue_id: str, +): + """This staticmethod runs the specified work in the current process. + + It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud + specific logic is being implemented here + """ + logger.debug(f"Run Lightning Work {file} {work_name} {queue_id}") + + queues = QueuingSystem(CLOUD_QUEUE_TYPE) + wait_for_queues(queues) + + caller_queue = queues.get_caller_queue(work_name=work_name, queue_id=queue_id) + readiness_queue = queues.get_readiness_queue(queue_id=queue_id) + delta_queue = queues.get_delta_queue(queue_id=queue_id) + error_queue = queues.get_error_queue(queue_id=queue_id) + + request_queues = queues.get_orchestrator_request_queue(work_name=work_name, queue_id=queue_id) + response_queues = queues.get_orchestrator_response_queue(work_name=work_name, queue_id=queue_id) + copy_request_queues = queues.get_orchestrator_copy_request_queue(work_name=work_name, queue_id=queue_id) + copy_response_queues = queues.get_orchestrator_copy_response_queue(work_name=work_name, queue_id=queue_id) + + run_app_commands(file) + + load_app_from_file(file) + + queue = queues.get_work_queue(work_name=work_name, queue_id=queue_id) + work = queue.get() + + extras = {} + + if hasattr(work, "_run_executor_cls"): + extras["run_executor_cls"] = work._run_executor_cls + + WorkRunner( + work=work, + work_name=work_name, + caller_queue=caller_queue, + delta_queue=delta_queue, + readiness_queue=readiness_queue, + error_queue=error_queue, + request_queue=request_queues, + response_queue=response_queues, + copy_request_queue=copy_request_queues, + copy_response_queue=copy_response_queues, + **extras, + )() + + +@convert_print_to_logger_info +@enable_debugging +def run_lightning_flow(entrypoint_file: str, queue_id: str, base_url: str, queues: Optional[FlowRestAPIQueues] = None): + _set_flow_context() + + logger.debug(f"Run Lightning Flow {entrypoint_file} {queue_id} {base_url}") + + app = load_app_from_file(entrypoint_file) + app.backend = CloudBackend(entrypoint_file, queue_id=queue_id) + + queue_system = app.backend.queues + app.backend.update_lightning_app_frontend(app) + wait_for_queues(queue_system) + + app.backend.resolve_url(app, base_url) + if app.root_path != "": + app._update_index_file() + app.backend._prepare_queues(app) + + # Note: Override the queues if provided + if queues: + app.api_publish_state_queue = queues["api_publish_state_queue"] + app.api_response_queue = queues["api_response_queue"] + + LightningFlow._attach_backend(app.root, app.backend) + + app.should_publish_changes_to_api = True + + storage_orchestrator = StorageOrchestrator( + app, + app.request_queues, + app.response_queues, + 
app.copy_request_queues, + app.copy_response_queues, + ) + storage_orchestrator.setDaemon(True) + storage_orchestrator.start() + + # refresh the layout with the populated urls. + app._update_layout() + + # register a signal handler to clean all works. + if sys.platform != "win32": + signal.signal(signal.SIGTERM, partial(_sigterm_flow_handler, app=app)) + + if "apis" in inspect.signature(start_server).parameters: + from lightning.app.utilities.commands.base import _prepare_commands + + _prepare_commands(app) + + # Once the bootstrapping is done, running the rank 0 + # app with all the components inactive + try: + app._run() + except ExitAppException: + pass + except Exception: + app.stage = AppStage.FAILED + print(traceback.format_exc()) + + storage_orchestrator.join(0) + app.backend.stop_all_works(app.works) + + exit_code = 1 if app.stage == AppStage.FAILED else 0 + print(f"Finishing the App with exit_code: {str(exit_code)}...") + + if not exit_code: + app.backend.stop_app(app) + + sys.exit(exit_code) + + +@convert_print_to_logger_info +@enable_debugging +def serve_frontend(file: str, flow_name: str, host: str, port: int): + """This staticmethod runs the specified frontend for a given flow in a new process. + + It is organized under cloud runtime to indicate that it will be used by the cloud runner but otherwise, no cloud + specific logic is being implemented here. + """ + _set_frontend_context() + logger.debug(f"Run Serve Frontend {file} {flow_name} {host} {port}") + app = load_app_from_file(file) + if flow_name not in app.frontends: + raise ValueError(f"Could not find frontend for flow with name {flow_name}.") + frontend = app.frontends[flow_name] + assert frontend.flow.name == flow_name + + frontend.start_server(host, port) + + +def start_server_in_process(target: Callable, args: Tuple = (), kwargs: Dict = {}) -> Process: + p = Process(target=target, args=args, kwargs=kwargs) + p.start() + return p + + +def format_row(elements, col_widths, padding=1): + elements = [el.ljust(w - padding * 2) for el, w in zip(elements, col_widths)] + pad = " " * padding + elements = [f"{pad}{el}{pad}" for el in elements] + return f'|{"|".join(elements)}|' + + +def tabulate(data, headers): + data = [[str(el) for el in row] for row in data] + col_widths = [len(el) for el in headers] + for row in data: + col_widths = [max(len(el), curr) for el, curr in zip(row, col_widths)] + col_widths = [w + 2 for w in col_widths] + seps = ["-" * w for w in col_widths] + lines = [format_row(headers, col_widths), format_row(seps, col_widths, padding=0)] + [ + format_row(row, col_widths) for row in data + ] + return "\n".join(lines) + + +def manage_server_processes(processes: List[Tuple[str, Process]]) -> None: + if not processes: + return + + sigterm_called = [False] + + def _sigterm_handler(*_): + sigterm_called[0] = True + + if sys.platform != "win32": + signal.signal(signal.SIGTERM, _sigterm_handler) + + # Since frontends run user code, any of them could fail. In that case, + # we want to fail all of them, as well as the application server, and + # exit the command with an error status code. + + exitcode = 0 + + while True: + # We loop until + # 1. Get a sigterm + # 2. All the children died but all with exit code 0 + # 3. At-least one of the child died with non-zero exit code + + # sleeping quickly at the starting of every loop + # moving this to the end of the loop might result in some flaky tests + time.sleep(1) + + if sigterm_called[0]: + print("Got SIGTERM. 
Exiting execution!!!") + break + if all(not p.is_alive() and p.exitcode == 0 for _, p in processes): + print("All the components are inactive with exitcode 0. Exiting execution!!!") + break + if any((not p.is_alive() and p.exitcode != 0) for _, p in processes): + print("Found dead components with non-zero exit codes, exiting execution!!! Components: ") + print( + tabulate( + [(name, p.exitcode) for name, p in processes if not p.is_alive() and p.exitcode != 0], + headers=["Name", "Exit Code"], + ) + ) + exitcode = 1 + break + + # sleeping for the last set of logs to reach stdout + time.sleep(2) + + # Cleanup + for _, p in processes: + if p.is_alive(): + os.kill(p.pid, signal.SIGTERM) + + # Give processes time to terminate + for _, p in processes: + p.join(5) + + # clean the remaining ones. + if any(p.is_alive() for _, p in processes): + for _, p in processes: + if p.is_alive(): + os.kill(p.pid, signal.SIGKILL) + + # this sleep is just a precaution - signals might take a while to get raised. + time.sleep(1) + sys.exit(1) + + sys.exit(exitcode) + + +def _get_frontends_from_app(entrypoint_file): + """This function is used to get the frontends from the app. It will be used to start the frontends in a + separate process if the backend cannot provide flow_names_and_ports. This is useful if the app cannot be loaded + locally to set the frontend before dispatching to the cloud. The backend exposes by default 10 ports from 8081 + if the app.spec.frontends is not set. + + NOTE: frontend_name are sorted to ensure that they get consistent ports. + + :param entrypoint_file: The entrypoint file for the app + :return: A list of tuples of the form (frontend_name, port_number) + """ + app = load_app_from_file(entrypoint_file) + + frontends = [] + # This value of the port should be synced with the port value in the backend. + # If you change this value, you should also change the value in the backend. + flow_frontends_starting_port = 8081 + for frontend in sorted(app.frontends.keys()): + frontends.append((frontend, flow_frontends_starting_port)) + flow_frontends_starting_port += 1 + + return frontends + + +@convert_print_to_logger_info +@enable_debugging +def start_flow_and_servers( + entrypoint_file: str, + base_url: str, + queue_id: str, + host: str, + port: int, + flow_names_and_ports: Tuple[Tuple[str, int]], +): + processes: List[Tuple[str, Process]] = [] + + # Queues between Flow and its Rest API are using multiprocessing to: + # - reduce redis load + # - increase UI responsiveness and RPS + queue_system = QueuingSystem.MULTIPROCESS + queues = { + "api_publish_state_queue": queue_system.get_api_state_publish_queue(queue_id=queue_id), + "api_response_queue": queue_system.get_api_response_queue(queue_id=queue_id), + } + + # In order to avoid running this function 3 seperate times while executing the + # `run_lightning_flow`, `start_application_server`, & `serve_frontend` functions + # in a subprocess we extract this to the top level. If we intend to make changes + # to be able to start these components in seperate containers, the implementation + # will have to move a call to this function within the initialization process. 
+ run_app_commands(entrypoint_file) + + flow_process = start_server_in_process( + run_lightning_flow, + args=( + entrypoint_file, + queue_id, + base_url, + ), + kwargs={"queues": queues}, + ) + processes.append(("Flow", flow_process)) + + server_process = start_server_in_process( + target=start_application_server, + args=( + entrypoint_file, + host, + port, + queue_id, + ), + kwargs={"queues": queues}, + ) + processes.append(("Server", server_process)) + + if not flow_names_and_ports: + flow_names_and_ports = _get_frontends_from_app(entrypoint_file) + + for name, fe_port in flow_names_and_ports: + frontend_process = start_server_in_process(target=serve_frontend, args=(entrypoint_file, name, host, fe_port)) + processes.append((name, frontend_process)) + + manage_server_processes(processes) + + +def wait_for_queues(queue_system: QueuingSystem) -> None: + queue_check_start_time = int(time.time()) + + if hasattr(queue_system, "get_queue"): + while not queue_system.get_queue("healthz").is_running: + if (int(time.time()) - queue_check_start_time) % 10 == 0: + logger.warning("Waiting for http queues to start...") + time.sleep(1) + else: + while not check_if_redis_running(): + if (int(time.time()) - queue_check_start_time) % 10 == 0: + logger.warning("Waiting for redis queues to start...") + time.sleep(1) diff --git a/src/lightning/app/launcher/lightning_backend.py b/src/lightning/app/launcher/lightning_backend.py new file mode 100644 index 0000000000000..1e3c096e45cf1 --- /dev/null +++ b/src/lightning/app/launcher/lightning_backend.py @@ -0,0 +1,523 @@ +import inspect +import json +import logging +import os +import random +import string +import urllib +from time import monotonic, sleep, time +from typing import List, Optional + +from lightning_cloud.openapi import ( + AppinstancesIdBody, + Externalv1LightningappInstance, + Externalv1Lightningwork, + V1BuildSpec, + V1Drive, + V1DriveSpec, + V1DriveStatus, + V1DriveType, + V1Flowserver, + V1LightningappInstanceState, + V1LightningappRestartPolicy, + V1LightningworkClusterDriver, + V1LightningworkDrives, + V1LightningworkSpec, + V1LightningworkState, + V1ListLightningworkResponse, + V1Metadata, + V1NetworkConfig, + V1PackageManager, + V1PythonDependencyInfo, + V1SourceType, + V1UserRequestedComputeConfig, +) +from lightning_cloud.openapi.rest import ApiException + +from lightning.app import LightningApp, LightningWork +from lightning.app.core.queues import QueuingSystem +from lightning.app.runners.backends.backend import Backend +from lightning.app.storage import Drive, Mount +from lightning.app.utilities.enum import make_status, WorkStageStatus, WorkStopReasons +from lightning.app.utilities.exceptions import LightningPlatformException +from lightning.app.utilities.network import _check_service_url_is_ready, LightningClient + +logger = logging.getLogger(__name__) + +from lightning_cloud.openapi import SpecLightningappInstanceIdWorksBody, WorksIdBody # noqa: E402 + +LIGHTNING_STOP_TIMEOUT = int(os.getenv("LIGHTNING_STOP_TIMEOUT", 2 * 60)) + + +def cloud_work_stage_to_work_status_stage(stage: V1LightningworkState) -> str: + """Maps the Work stage names from the cloud backend to the status names in the Lightning framework.""" + mapping = { + V1LightningworkState.STOPPED: WorkStageStatus.STOPPED, + V1LightningworkState.PENDING: WorkStageStatus.PENDING, + V1LightningworkState.NOT_STARTED: WorkStageStatus.PENDING, + V1LightningworkState.IMAGE_BUILDING: WorkStageStatus.PENDING, + V1LightningworkState.RUNNING: WorkStageStatus.RUNNING, + 
V1LightningworkState.FAILED: WorkStageStatus.FAILED, + } + if stage not in mapping: + raise ValueError(f"Cannot map the lightning-cloud work state {stage} to the lightning status stage.") + return mapping[stage] + + +class CloudBackend(Backend): + def __init__( + self, + entrypoint_file, + queue_id: Optional[str] = None, + status_update_interval: int = 5, + ) -> None: + # TODO: Properly handle queue_id in the cloud. + super().__init__(entrypoint_file, queues=QueuingSystem("http"), queue_id=queue_id) + self._status_update_interval = status_update_interval + self._last_time_updated = None + self.client = LightningClient(retry=True) + self.base_url: Optional[str] = None + + @staticmethod + def _work_to_spec(work: LightningWork) -> V1LightningworkSpec: + work_requirements = "\n".join(work.cloud_build_config.requirements) + + build_spec = V1BuildSpec( + commands=work.cloud_build_config.build_commands(), + python_dependencies=V1PythonDependencyInfo( + package_manager=V1PackageManager.PIP, packages=work_requirements + ), + image=work.cloud_build_config.image, + ) + + drive_specs: List[V1LightningworkDrives] = [] + for drive_attr_name, drive in [ + (k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive) + ]: + if drive.protocol == "lit://": + drive_type = V1DriveType.NO_MOUNT_S3 + source_type = V1SourceType.S3 + else: + drive_type = V1DriveType.UNSPECIFIED + source_type = V1SourceType.UNSPECIFIED + + drive_specs.append( + V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata(name=f"{work.name}.{drive_attr_name}"), + spec=V1DriveSpec( + drive_type=drive_type, + source_type=source_type, + source=f"{drive.protocol}{drive.id}", + ), + status=V1DriveStatus(), + ), + mount_location=str(drive.root_folder), + ), + ) + + # this should really be part of the work.cloud_compute struct, but to save + # time we are not going to modify the backend in this set of PRs & instead + # use the same s3 drives API which we used before. 
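+        # (Each Mount below is translated into an INDEXED_S3 drive spec attached
+        # at its mount_path; see _create_mount_drive_spec at the bottom of this file.)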
+ if work.cloud_compute.mounts is not None: + if isinstance(work.cloud_compute.mounts, Mount): + drive_specs.append( + _create_mount_drive_spec( + work_name=work.name, + mount=work.cloud_compute.mounts, + ) + ) + else: + for mount in work.cloud_compute.mounts: + drive_specs.append( + _create_mount_drive_spec( + work_name=work.name, + mount=mount, + ) + ) + + if hasattr(work.cloud_compute, "interruptible"): + preemptible = work.cloud_compute.interruptible + else: + preemptible = work.cloud_compute.preemptible + + colocation_group_id = None + if hasattr(work.cloud_compute, "colocation_group_id"): + colocation_group_id = work.cloud_compute.colocation_group_id + + user_compute_config = V1UserRequestedComputeConfig( + name=work.cloud_compute.name, + count=1, + disk_size=work.cloud_compute.disk_size, + preemptible=preemptible, + shm_size=work.cloud_compute.shm_size, + affinity_identifier=colocation_group_id, + ) + + random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5)) # noqa: S311 + + return V1LightningworkSpec( + build_spec=build_spec, + drives=drive_specs, + user_requested_compute_config=user_compute_config, + network_config=[V1NetworkConfig(name=random_name, port=work.port)], + desired_state=V1LightningworkState.RUNNING, + restart_policy=V1LightningappRestartPolicy.NEVER, + cluster_driver=V1LightningworkClusterDriver.DIRECT, + ) + + def create_work(self, app: LightningApp, work: LightningWork) -> None: + app_id = self._get_app_id() + project_id = self._get_project_id() + list_response: V1ListLightningworkResponse = self.client.lightningwork_service_list_lightningwork( + project_id=project_id, app_id=app_id + ) + external_specs: List[Externalv1Lightningwork] = list_response.lightningworks + + # Find THIS work in the list of all registered works + external_spec = None + for es in external_specs: + if es.name == work.name: + external_spec = es + break + + if external_spec is None: + spec = self._work_to_spec(work) + try: + fn = SpecLightningappInstanceIdWorksBody.__init__ + params = list(inspect.signature(fn).parameters) + extras = {} + if "display_name" in params: + extras["display_name"] = getattr(work, "display_name", "") + + external_spec = self.client.lightningwork_service_create_lightningwork( + project_id=project_id, + spec_lightningapp_instance_id=app_id, + body=SpecLightningappInstanceIdWorksBody( + name=work.name, + spec=spec, + **extras, + ), + ) + # overwriting spec with return value + spec = external_spec.spec + except ApiException as e: + # We might get exceed quotas, or be out of credits. 
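+                # The platform encodes the human-readable reason in the JSON error
+                # body; surface it to the user as a LightningPlatformException.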
+ message = json.loads(e.body).get("message") + raise LightningPlatformException(message) from None + elif external_spec.spec.desired_state == V1LightningworkState.RUNNING: + spec = external_spec.spec + work._port = spec.network_config[0].port + else: + # Signal the LightningWorkState to go into state RUNNING + spec = external_spec.spec + + # getting the updated spec but ignoring everything other than port & drives + new_spec = self._work_to_spec(work) + + spec.desired_state = V1LightningworkState.RUNNING + spec.network_config[0].port = new_spec.network_config[0].port + spec.drives = new_spec.drives + spec.user_requested_compute_config = new_spec.user_requested_compute_config + spec.build_spec = new_spec.build_spec + spec.env = new_spec.env + try: + self.client.lightningwork_service_update_lightningwork( + project_id=project_id, + id=external_spec.id, + spec_lightningapp_instance_id=app_id, + body=WorksIdBody(spec), + ) + except ApiException as e: + # We might get exceed quotas, or be out of credits. + message = json.loads(e.body).get("message") + raise LightningPlatformException(message) from None + + # Replace the undefined url and host by the known one. + work._host = "0.0.0.0" # noqa: S104 + work._future_url = f"{self._get_proxy_scheme()}://{spec.network_config[0].host}" + + # removing the backend to avoid the threadlock error + _backend = work._backend + work._backend = None + app.work_queues[work.name].put(work) + work._backend = _backend + + logger.info(f"Starting work {work.name}") + logger.debug(f"With the following external spec: {external_spec}") + + def update_work_statuses(self, works: List[LightningWork]) -> None: + """Pulls the status of each Work instance in the cloud. + + Normally, the Lightning frameworks communicates statuses through the queues, but while the Work instance is + being provisionied, the queues don't exist yet and hence we need to make API calls directly to the backend to + fetch the status and update it in the states. + """ + if not works: + return + + # TODO: should this run in a timer thread instead? + if self._last_time_updated is not None and monotonic() - self._last_time_updated < self._status_update_interval: + return + + cloud_work_specs = self._get_cloud_work_specs(self.client) + local_works = works + for cloud_work_spec in cloud_work_specs: + for local_work in local_works: + # TODO (tchaton) Better resolve pending status after succeeded + + # 1. Skip if the work isn't the current one. + if local_work.name != cloud_work_spec.name: + continue + + # 2. Logic for idle timeout + self._handle_idle_timeout( + local_work.cloud_compute.idle_timeout, + local_work, + cloud_work_spec, + ) + + # 3. Map the cloud phase to the local one + cloud_stage = cloud_work_stage_to_work_status_stage( + cloud_work_spec.status.phase, + ) + + # 4. Detect if the work failed during pending phase + if local_work.status.stage == WorkStageStatus.PENDING and cloud_stage in WorkStageStatus.FAILED: + if local_work._raise_exception: + raise Exception(f"The work {local_work.name} failed during pending phase.") + logger.error(f"The work {local_work.name} failed during pending phase.") + + # 5. Skip the pending and running as this is already handled by Lightning. 
+ if cloud_stage in (WorkStageStatus.PENDING, WorkStageStatus.RUNNING): + continue + + # TODO: Add the logic for wait_timeout + if local_work.status.stage != cloud_stage: + latest_hash = local_work._calls["latest_call_hash"] + if latest_hash is None: + continue + local_work._calls[latest_hash]["statuses"].append(make_status(cloud_stage)) + + self._last_time_updated = monotonic() + + def stop_all_works(self, works: List[LightningWork]) -> None: + """Stop resources for all LightningWorks in this app. + + The Works are stopped rather than deleted so that they can be inspected for debugging. + """ + cloud_works = self._get_cloud_work_specs(self.client) + + for cloud_work in cloud_works: + self._stop_work(cloud_work) + + def all_works_stopped(works: List[Externalv1Lightningwork]) -> bool: + for work in works: + # deleted work won't be in the request hence only checking for stopped & failed + if work.status.phase not in ( + V1LightningworkState.STOPPED, + V1LightningworkState.FAILED, + ): + return False + return True + + t0 = time() + while not all_works_stopped(self._get_cloud_work_specs(self.client)): + # Wait a little.. + print("Waiting for works to stop...") + sleep(3) + + # Break if we reached timeout. + if time() - t0 > LIGHTNING_STOP_TIMEOUT: + break + + def resolve_url(self, app, base_url: Optional[str] = None) -> None: + if not self.base_url: + self.base_url = base_url + + for flow in app.flows: + if self.base_url: + # Replacing the path with complete URL + if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")): + raise ValueError( + "Base URL doesn't have a valid scheme, expected it to start with 'http://' or 'https://' " + ) + if isinstance(flow._layout, dict) and "target" not in flow._layout: + # FIXME: Why _check_service_url_is_ready doesn't work ? + frontend_url = urllib.parse.urljoin(self.base_url, flow.name + "/") + flow._layout["target"] = frontend_url + + for work in app.works: + if ( + work._url == "" + and work.status.stage + in ( + WorkStageStatus.RUNNING, + WorkStageStatus.SUCCEEDED, + ) + and work._internal_ip != "" + and _check_service_url_is_ready(f"http://{work._internal_ip}:{work._port}") + ): + work._url = work._future_url + + @staticmethod + def _get_proxy_scheme() -> str: + return os.environ.get("LIGHTNING_PROXY_SCHEME", "https") + + @staticmethod + def _get_app_id() -> str: + return os.environ["LIGHTNING_CLOUD_APP_ID"] + + @staticmethod + def _get_project_id() -> str: + return os.environ["LIGHTNING_CLOUD_PROJECT_ID"] + + @staticmethod + def _get_cloud_work_specs(client: LightningClient) -> List[Externalv1Lightningwork]: + list_response: V1ListLightningworkResponse = client.lightningwork_service_list_lightningwork( + project_id=CloudBackend._get_project_id(), + app_id=CloudBackend._get_app_id(), + ) + return list_response.lightningworks + + def _handle_idle_timeout(self, idle_timeout: float, work: LightningWork, resp: Externalv1Lightningwork) -> None: + if idle_timeout is None: + return + + if work.status.stage != WorkStageStatus.SUCCEEDED: + return + + if time() > (idle_timeout + work.status.timestamp): + logger.info(f"Idle Timeout {idle_timeout} has triggered. 
Stopping gracefully the {work.name}.") + latest_hash = work._calls["latest_call_hash"] + status = make_status(WorkStageStatus.STOPPED, reason=WorkStopReasons.PENDING) + work._calls[latest_hash]["statuses"].append(status) + self._stop_work(resp) + logger.debug(f"Stopping work: {resp.id}") + + def _register_queues(self, app, work): + super()._register_queues(app, work) + kw = {"queue_id": self.queue_id, "work_name": work.name} + app.work_queues.update({work.name: self.queues.get_work_queue(**kw)}) + + def stop_work(self, app: LightningApp, work: LightningWork) -> None: + cloud_works = self._get_cloud_work_specs(self.client) + for cloud_work in cloud_works: + if work.name == cloud_work.name: + self._stop_work(cloud_work) + + def _stop_work(self, work_resp: Externalv1Lightningwork) -> None: + spec: V1LightningworkSpec = work_resp.spec + if spec.desired_state == V1LightningworkState.DELETED: + # work is set to be deleted. Do nothing + return + if spec.desired_state == V1LightningworkState.STOPPED: + # work is set to be stopped already. Do nothing + return + if work_resp.status.phase == V1LightningworkState.FAILED: + # work is already failed. Do nothing + return + spec.desired_state = V1LightningworkState.STOPPED + self.client.lightningwork_service_update_lightningwork( + project_id=CloudBackend._get_project_id(), + id=work_resp.id, + spec_lightningapp_instance_id=CloudBackend._get_app_id(), + body=WorksIdBody(spec), + ) + print(f"Stopping {work_resp.name} ...") + + def delete_work(self, app: LightningApp, work: LightningWork) -> None: + cloud_works = self._get_cloud_work_specs(self.client) + for cloud_work in cloud_works: + if work.name == cloud_work.name: + self._delete_work(cloud_work) + + def _delete_work(self, work_resp: Externalv1Lightningwork) -> None: + spec: V1LightningworkSpec = work_resp.spec + if spec.desired_state == V1LightningworkState.DELETED: + # work is set to be deleted. Do nothing + return + spec.desired_state = V1LightningworkState.DELETED + self.client.lightningwork_service_update_lightningwork( + project_id=CloudBackend._get_project_id(), + id=work_resp.id, + spec_lightningapp_instance_id=CloudBackend._get_app_id(), + body=WorksIdBody(spec), + ) + print(f"Deleting {work_resp.name} ...") + + def update_lightning_app_frontend(self, app: "lightning.LightningApp"): # noqa: F821 + """Used to create frontend's if the app couldn't be loaded locally.""" + if not len(app.frontends.keys()): + return + + external_app_spec: "Externalv1LightningappInstance" = ( + self.client.lightningapp_instance_service_get_lightningapp_instance( + project_id=CloudBackend._get_project_id(), + id=CloudBackend._get_app_id(), + ) + ) + + frontend_specs = external_app_spec.spec.flow_servers + spec = external_app_spec.spec + if len(frontend_specs) != len(app.frontends.keys()): + frontend_specs: List[V1Flowserver] = [] + for flow_name in sorted(app.frontends.keys()): + frontend_spec = V1Flowserver(name=flow_name) + frontend_specs.append(frontend_spec) + + spec.flow_servers = frontend_specs + spec.enable_app_server = True + + logger.info("Found new frontends. 
Updating the app spec.") + + self.client.lightningapp_instance_service_update_lightningapp_instance( + project_id=CloudBackend._get_project_id(), + id=CloudBackend._get_app_id(), + body=AppinstancesIdBody(spec=spec), + ) + + def stop_app(self, app: "lightning.LightningApp"): # noqa: F821 + """Used to mark the App has stopped if everything has fine.""" + + external_app_spec: "Externalv1LightningappInstance" = ( + self.client.lightningapp_instance_service_get_lightningapp_instance( + project_id=CloudBackend._get_project_id(), + id=CloudBackend._get_app_id(), + ) + ) + + spec = external_app_spec.spec + spec.desired_state = V1LightningappInstanceState.STOPPED + + self.client.lightningapp_instance_service_update_lightningapp_instance( + project_id=CloudBackend._get_project_id(), + id=CloudBackend._get_app_id(), + body=AppinstancesIdBody(spec=spec), + ) + + +def _create_mount_drive_spec(work_name: str, mount: "Mount") -> V1LightningworkDrives: + if mount.protocol == "s3://": + drive_type = V1DriveType.INDEXED_S3 + source_type = V1SourceType.S3 + else: + raise RuntimeError( + f"unknown mounts protocol `{mount.protocol}`. Please verify this " + f"drive type has been configured for use in the cloud dispatcher." + ) + + return V1LightningworkDrives( + drive=V1Drive( + metadata=V1Metadata( + name=work_name, + ), + spec=V1DriveSpec( + drive_type=drive_type, + source_type=source_type, + source=mount.source, + ), + status=V1DriveStatus(), + ), + mount_location=str(mount.mount_path), + ) diff --git a/src/lightning/app/launcher/lightning_hybrid_backend.py b/src/lightning/app/launcher/lightning_hybrid_backend.py new file mode 100644 index 0000000000000..5391aca1d1566 --- /dev/null +++ b/src/lightning/app/launcher/lightning_hybrid_backend.py @@ -0,0 +1,155 @@ +import os +from typing import Optional + +from lightning_cloud.openapi import AppinstancesIdBody, Externalv1LightningappInstance + +from lightning.app.core import constants +from lightning.app.core.queues import QueuingSystem +from lightning.app.launcher.lightning_backend import CloudBackend +from lightning.app.runners.backends.backend import Backend +from lightning.app.runners.backends.mp_process import MultiProcessingBackend +from lightning.app.utilities.network import LightningClient + +if hasattr(constants, "get_cloud_queue_type"): + CLOUD_QUEUE_TYPE = constants.get_cloud_queue_type() or "redis" +else: + CLOUD_QUEUE_TYPE = "redis" + + +class CloudHybridBackend(Backend): + def __init__(self, *args, **kwargs): + super().__init__(*args, queues=QueuingSystem(CLOUD_QUEUE_TYPE), **kwargs) + cloud_backend = CloudBackend(*args, **kwargs) + kwargs.pop("queue_id") + multiprocess_backend = MultiProcessingBackend(*args, **kwargs) + + self.backends = {"cloud": cloud_backend, "multiprocess": multiprocess_backend} + self.work_to_network_configs = {} + + def create_work(self, app, work) -> None: + backend = self._get_backend(work) + if isinstance(backend, MultiProcessingBackend): + self._prepare_work_creation(app, work) + backend.create_work(app, work) + + def _prepare_work_creation(self, app, work) -> None: + app_id = self._get_app_id() + project_id = self._get_project_id() + assert project_id + + client = LightningClient() + list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id) + lightning_app: Optional[Externalv1LightningappInstance] = None + + for lightningapp in list_apps_resp.lightningapps: + if lightningapp.id == app_id: + lightning_app = lightningapp + + assert lightning_app + + network_configs = 
lightning_app.spec.network_config + + index = len(self.work_to_network_configs) + + if work.name not in self.work_to_network_configs: + self.work_to_network_configs[work.name] = network_configs[index] + + # Enable Ingress and update the specs. + lightning_app.spec.network_config[index].enable = True + + client.lightningapp_instance_service_update_lightningapp_instance( + project_id=project_id, + id=lightning_app.id, + body=AppinstancesIdBody(name=lightning_app.name, spec=lightning_app.spec), + ) + + work_network_config = self.work_to_network_configs[work.name] + + work._host = "0.0.0.0" # noqa: S104 + work._port = work_network_config.port + work._future_url = f"{self._get_proxy_scheme()}://{work_network_config.host}" + + def update_work_statuses(self, works) -> None: + if works: + backend = self._get_backend(works[0]) + backend.update_work_statuses(works) + + def stop_all_works(self, works) -> None: + if works: + backend = self._get_backend(works[0]) + backend.stop_all_works(works) + + def resolve_url(self, app, base_url: Optional[str] = None) -> None: + works = app.works + if works: + backend = self._get_backend(works[0]) + backend.resolve_url(app, base_url) + + def update_lightning_app_frontend(self, app: "lightning.LightningApp"): # noqa: F821 + self.backends["cloud"].update_lightning_app_frontend(app) + + def stop_work(self, app, work) -> None: + backend = self._get_backend(work) + if isinstance(backend, MultiProcessingBackend): + self._prepare_work_stop(app, work) + backend.stop_work(app, work) + + def delete_work(self, app, work) -> None: + backend = self._get_backend(work) + if isinstance(backend, MultiProcessingBackend): + self._prepare_work_stop(app, work) + backend.delete_work(app, work) + + def _prepare_work_stop(self, app, work): + app_id = self._get_app_id() + project_id = self._get_project_id() + assert project_id + + client = LightningClient() + list_apps_resp = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project_id) + lightning_app: Optional[Externalv1LightningappInstance] = None + + for lightningapp in list_apps_resp.lightningapps: + if lightningapp.id == app_id: + lightning_app = lightningapp + + assert lightning_app + + network_config = self.work_to_network_configs[work.name] + + for nc in lightning_app.spec.network_config: + if nc.host == network_config.host: + nc.enable = False + + client.lightningapp_instance_service_update_lightningapp_instance( + project_id=project_id, + id=lightning_app.id, + body=AppinstancesIdBody(name=lightning_app.name, spec=lightning_app.spec), + ) + + del self.work_to_network_configs[work.name] + + def _register_queues(self, app, work): + backend = self._get_backend(work) + backend._register_queues(app, work) + + def _get_backend(self, work): + if work.cloud_compute.id == "default": + return self.backends["multiprocess"] + return self.backends["cloud"] + + @staticmethod + def _get_proxy_scheme() -> str: + return os.environ.get("LIGHTNING_PROXY_SCHEME", "https") + + @staticmethod + def _get_app_id() -> str: + return os.environ["LIGHTNING_CLOUD_APP_ID"] + + @staticmethod + def _get_project_id() -> str: + return os.environ["LIGHTNING_CLOUD_PROJECT_ID"] + + def stop_app(self, app: "lightning.LightningApp"): # noqa: F821 + """Used to mark the App has stopped if everything has fine.""" + self.backends["cloud"].stop_app(app) diff --git a/src/lightning/app/utilities/packaging/lightning_utils.py b/src/lightning/app/utilities/packaging/lightning_utils.py index c023e80776678..3852c941ed676 100644 --- 
a/src/lightning/app/utilities/packaging/lightning_utils.py +++ b/src/lightning/app/utilities/packaging/lightning_utils.py @@ -138,17 +138,6 @@ def _prepare_lightning_wheels_and_requirements(root: Path, package_name: str = " # Don't skip by default if (PACKAGE_LIGHTNING or is_lightning) and not bool(int(os.getenv("SKIP_LIGHTING_UTILITY_WHEELS_BUILD", "0"))): - # building and copying launcher wheel if installed in editable mode - launcher_project_path = get_dist_path_if_editable_install("lightning_launcher") - if launcher_project_path: - from lightning_launcher.__version__ import __version__ as launcher_version - - # todo: check why logging.info is missing in outputs - print(f"Packaged Lightning Launcher with your application. Version: {launcher_version}") - _prepare_wheel(launcher_project_path) - tar_name = _copy_tar(launcher_project_path, root) - tar_files.append(os.path.join(root, tar_name)) - # building and copying lightning-cloud wheel if installed in editable mode lightning_cloud_project_path = get_dist_path_if_editable_install("lightning_cloud") if lightning_cloud_project_path: diff --git a/tests/tests_app/cli/launch_data/app_v0/__init__.py b/tests/tests_app/cli/launch_data/app_v0/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_app/cli/launch_data/app_v0/app.py b/tests/tests_app/cli/launch_data/app_v0/app.py new file mode 100644 index 0000000000000..7a8a4f27ced46 --- /dev/null +++ b/tests/tests_app/cli/launch_data/app_v0/app.py @@ -0,0 +1,51 @@ +# v0_app.py +import os +from datetime import datetime +from time import sleep + +import lightning as L +from lightning.app.frontend.web import StaticWebFrontend + + +class Word(L.LightningFlow): + def __init__(self, letter): + super().__init__() + self.letter = letter + self.repeats = letter + + def run(self): + self.repeats += self.letter + + def configure_layout(self): + return StaticWebFrontend(serve_dir=os.path.join(os.path.dirname(__file__), f"ui/{self.letter}")) + + +class V0App(L.LightningFlow): + def __init__(self): + super().__init__() + self.aas = Word("a") + self.bbs = Word("b") + self.counter = 0 + + def run(self): + now = datetime.now() + now = now.strftime("%H:%M:%S") + log = {"time": now, "a": self.aas.repeats, "b": self.bbs.repeats} + print(log) + self.aas.run() + self.bbs.run() + + sleep(2.0) + self.counter += 1 + + def configure_layout(self): + tab1 = {"name": "Tab_1", "content": self.aas} + tab2 = {"name": "Tab_2", "content": self.bbs} + tab3 = { + "name": "Tab_3", + "content": "https://tensorboard.dev/experiment/8m1aX0gcQ7aEmH0J7kbBtg/#scalars", + } + return [tab1, tab2, tab3] + + +app = L.LightningApp(V0App()) diff --git a/tests/tests_app/cli/launch_data/app_v0/ui/a/index.html b/tests/tests_app/cli/launch_data/app_v0/ui/a/index.html new file mode 100644 index 0000000000000..6ddb9a5a1323c --- /dev/null +++ b/tests/tests_app/cli/launch_data/app_v0/ui/a/index.html @@ -0,0 +1 @@ +
<div>Hello from component A</div>
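As a side note on `_get_frontends_from_app` from `launcher.py` above: for this `app_v0` fixture, the sorted-name rule pairs the two `Word` frontends with consecutive ports. A minimal sketch of that pairing follows; the fixture path and the expected output are inferred from the code above, not asserted anywhere in this PR.

```python
# Sketch: reproduce _get_frontends_from_app's deterministic name -> port pairing
# for the app_v0 fixture. Names are sorted so every container computes the same
# assignment; 8081 must stay in sync with the backend's default port range.
from lightning.app.utilities.load_app import load_app_from_file

app = load_app_from_file("tests/tests_app/cli/launch_data/app_v0/app.py")

frontends = []
port = 8081
for name in sorted(app.frontends.keys()):  # e.g. ["root.aas", "root.bbs"]
    frontends.append((name, port))
    port += 1

print(frontends)  # expected: [("root.aas", 8081), ("root.bbs", 8082)]
```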
diff --git a/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html b/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html new file mode 100644 index 0000000000000..3bfd9e24cb7f7 --- /dev/null +++ b/tests/tests_app/cli/launch_data/app_v0/ui/b/index.html @@ -0,0 +1 @@ +
<div>Hello from component B</div>
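One note before the CLI tests: the `--flow-port` option on `launch flow-and-servers` relies on click's composite types, where `type=(str, int)` combined with `multiple=True` collects each repeated pair into a tuple of `(name, port)` tuples. A standalone toy command illustrating just that parsing behavior (not part of the PR):

```python
# Toy command demonstrating click's composite option type used by --flow-port.
import click


@click.command()
@click.option("--flow-port", type=(str, int), multiple=True)
def echo_flow_ports(flow_port):
    # `--flow-port root.aas 6001 --flow-port root.bbs 6002` parses to
    # (("root.aas", 6001), ("root.bbs", 6002))
    for name, port in flow_port:
        click.echo(f"{name} -> {port}")


if __name__ == "__main__":
    echo_flow_ports()
```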
diff --git a/tests/tests_app/cli/test_cmd_launch.py b/tests/tests_app/cli/test_cmd_launch.py new file mode 100644 index 0000000000000..b1fdf89ac9606 --- /dev/null +++ b/tests/tests_app/cli/test_cmd_launch.py @@ -0,0 +1,327 @@ +import os +import signal +import time +from functools import partial +from multiprocessing import Process +from pathlib import Path +from unittest import mock +from unittest.mock import ANY, MagicMock, Mock + +from click.testing import CliRunner + +from lightning.app.cli.lightning_cli_launch import run_flow, run_flow_and_servers, run_frontend, run_server +from lightning.app.core.queues import QueuingSystem +from lightning.app.frontend.web import StaticWebFrontend +from lightning.app.launcher import launcher +from lightning.app.runners.runtime import load_app_from_file +from lightning.app.testing.helpers import _RunIf, EmptyWork +from lightning.app.utilities.app_commands import run_app_commands +from lightning.app.utilities.network import find_free_network_port +from tests_app import _PROJECT_ROOT + +_FILE_PATH = os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py") + + +def test_run_frontend(monkeypatch): + """Test that the CLI can be used to start the frontend server of a particular LightningFlow using the cloud + dispatcher. + + This CLI call is made by Lightning AI and is not meant to be invoked by the user directly. + """ + runner = CliRunner() + + port = find_free_network_port() + + start_server_mock = Mock() + monkeypatch.setattr(StaticWebFrontend, "start_server", start_server_mock) + + result = runner.invoke( + run_frontend, + [ + str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py"), + "--flow-name", + "root.aas", + "--host", + "localhost", + "--port", + port, + ], + ) + assert result.exit_code == 0 + start_server_mock.assert_called_once() + start_server_mock.assert_called_with("localhost", port) + + +class MockRedisQueue: + _MOCKS = {} + + def __init__(self, name: str, default_timeout: float): + self.name = name + self.default_timeout = default_timeout + self.queue = [None] # adding a dummy element. 
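+        # Pre-seeding [None] lets the first get() return immediately via
+        # queue.pop(0) instead of blocking the way a real Redis queue would.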
+ + self._MOCKS[name] = MagicMock() + + def put(self, item): + self._MOCKS[self.name].put(item) + self.queue.put(item) + + def get(self, timeout: int = None): + self._MOCKS[self.name].get(timeout=timeout) + return self.queue.pop(0) + + @property + def is_running(self): + self._MOCKS[self.name].is_running() + return True + + +@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue) +@mock.patch("lightning.app.launcher.launcher.check_if_redis_running", MagicMock(return_value=True)) +@mock.patch("lightning.app.launcher.launcher.start_server") +def test_run_server(start_server_mock): + runner = CliRunner() + result = runner.invoke( + run_server, + [ + _FILE_PATH, + "--queue-id", + "1", + "--host", + "http://127.0.0.1:7501/view", + "--port", + "6000", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + start_server_mock.assert_called_once_with( + host="http://127.0.0.1:7501/view", + port=6000, + api_publish_state_queue=ANY, + api_delta_queue=ANY, + api_response_queue=ANY, + spec=ANY, + apis=ANY, + ) + kwargs = start_server_mock._mock_call_args.kwargs + assert isinstance(kwargs["api_publish_state_queue"], MockRedisQueue) + assert kwargs["api_publish_state_queue"].name.startswith("1") + assert isinstance(kwargs["api_delta_queue"], MockRedisQueue) + assert kwargs["api_delta_queue"].name.startswith("1") + + +def mock_server(should_catch=False, sleep=1000): + if should_catch: + + def _sigterm_handler(*_): + time.sleep(100) + + signal.signal(signal.SIGTERM, _sigterm_handler) + + time.sleep(sleep) + + +def run_forever_process(): + while True: + time.sleep(1) + + +def run_for_2_seconds_and_raise(): + time.sleep(2) + raise RuntimeError("existing") + + +def exit_successfully_immediately(): + return + + +def start_servers(should_catch=False, sleep=1000): + processes = [ + ( + "p1", + launcher.start_server_in_process(target=partial(mock_server, should_catch=should_catch, sleep=sleep)), + ), + ( + "p2", + launcher.start_server_in_process(target=partial(mock_server, sleep=sleep)), + ), + ( + "p3", + launcher.start_server_in_process(target=partial(mock_server, sleep=sleep)), + ), + ] + + launcher.manage_server_processes(processes) + + +@_RunIf(skip_windows=True) +def test_manage_server_processes(): + p = Process(target=partial(start_servers, sleep=0.5)) + p.start() + p.join() + + assert p.exitcode == 0 + + p = Process(target=start_servers) + p.start() + p.join(0.5) + p.terminate() + p.join() + + assert p.exitcode in [-15, 0] + + p = Process(target=partial(start_servers, should_catch=True)) + p.start() + p.join(0.5) + p.terminate() + p.join() + + assert p.exitcode in [-15, 1] + + +def start_processes(**functions): + processes = [] + for name, fn in functions.items(): + processes.append((name, launcher.start_server_in_process(fn))) + launcher.manage_server_processes(processes) + + +@_RunIf(skip_windows=True) +def test_manage_server_processes_one_process_gets_killed(capfd): + functions = {"p1": run_forever_process, "p2": run_for_2_seconds_and_raise} + p = Process(target=start_processes, kwargs=functions) + p.start() + + for _ in range(40): + time.sleep(1) + if p.exitcode is not None: + break + assert p.exitcode == 1 + captured = capfd.readouterr() + assert ( + "Found dead components with non-zero exit codes, exiting execution!!! 
Components: \n" + "| Name | Exit Code |\n|------|-----------|\n| p2 | 1 |\n" in captured.out + ) + + +@_RunIf(skip_windows=True) +def test_manage_server_processes_all_processes_exits_with_zero_exitcode(capfd): + functions = { + "p1": exit_successfully_immediately, + "p2": exit_successfully_immediately, + } + p = Process(target=start_processes, kwargs=functions) + p.start() + + for _ in range(40): + time.sleep(1) + if p.exitcode is not None: + break + assert p.exitcode == 0 + captured = capfd.readouterr() + assert "All the components are inactive with exitcode 0. Exiting execution!!!" in captured.out + + +@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock()) +@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue) +@mock.patch("lightning.app.launcher.launcher.manage_server_processes", Mock()) +def test_run_flow_and_servers(monkeypatch): + runner = CliRunner() + + start_server_mock = Mock() + monkeypatch.setattr(launcher, "start_server_in_process", start_server_mock) + + runner.invoke( + run_flow_and_servers, + [ + str(Path(__file__).parent / "launch_data" / "app_v0" / "app.py"), + "--base-url", + "https://some.url", + "--queue-id", + "1", + "--host", + "http://127.0.0.1:7501/view", + "--port", + 6000, + "--flow-port", + "root.aas", + 6001, + "--flow-port", + "root.bbs", + 6002, + ], + catch_exceptions=False, + ) + + start_server_mock.assert_called() + assert start_server_mock.call_count == 4 + + +@mock.patch("lightning.app.core.queues.RedisQueue", MockRedisQueue) +@mock.patch("lightning.app.launcher.launcher.WorkRunner") +def test_run_work(mock_work_runner, monkeypatch): + run_app_commands(_FILE_PATH) + app = load_app_from_file(_FILE_PATH) + names = [w.name for w in app.works] + + mocked_queue = MagicMock() + mocked_queue.get.return_value = EmptyWork() + monkeypatch.setattr( + QueuingSystem, + "get_work_queue", + MagicMock(return_value=mocked_queue), + ) + + assert names == [ + "root.flow_a_1.work_a", + "root.flow_a_2.work_a", + "root.flow_b.work_b", + ] + + for name in names: + launcher.run_lightning_work( + file=_FILE_PATH, + work_name=name, + queue_id="1", + ) + kwargs = mock_work_runner._mock_call_args.kwargs + assert isinstance(kwargs["work"], EmptyWork) + assert kwargs["work_name"] == name + assert isinstance(kwargs["caller_queue"], MockRedisQueue) + assert kwargs["caller_queue"].name.startswith("1") + assert isinstance(kwargs["delta_queue"], MockRedisQueue) + assert kwargs["delta_queue"].name.startswith("1") + assert isinstance(kwargs["readiness_queue"], MockRedisQueue) + assert kwargs["readiness_queue"].name.startswith("1") + assert isinstance(kwargs["error_queue"], MockRedisQueue) + assert kwargs["error_queue"].name.startswith("1") + assert isinstance(kwargs["request_queue"], MockRedisQueue) + assert kwargs["request_queue"].name.startswith("1") + assert isinstance(kwargs["response_queue"], MockRedisQueue) + assert kwargs["response_queue"].name.startswith("1") + assert isinstance(kwargs["copy_request_queue"], MockRedisQueue) + assert kwargs["copy_request_queue"].name.startswith("1") + assert isinstance(kwargs["copy_response_queue"], MockRedisQueue) + assert kwargs["copy_response_queue"].name.startswith("1") + + MockRedisQueue._MOCKS["healthz"].is_running.assert_called() + + +@mock.patch("lightning.app.core.queues.QueuingSystem", MagicMock()) +@mock.patch("lightning.app.launcher.launcher.StorageOrchestrator", MagicMock()) +@mock.patch("lightning.app.LightningApp._run") +@mock.patch("lightning.app.launcher.launcher.CloudBackend") +def 
test_run_flow(mock_cloud_backend, mock_lightning_app_run): + runner = CliRunner() + + base_url = "https://lightning.ai/me/apps" + + result = runner.invoke( + run_flow, + [_FILE_PATH, "--queue-id=1", f"--base-url={base_url}"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + mock_lightning_app_run.assert_called_once() + assert len(mock_cloud_backend._mock_mock_calls) == 13 diff --git a/tests/tests_app/launcher/test_lightning_backend.py b/tests/tests_app/launcher/test_lightning_backend.py new file mode 100644 index 0000000000000..d5b89af7a3cea --- /dev/null +++ b/tests/tests_app/launcher/test_lightning_backend.py @@ -0,0 +1,809 @@ +import json +import os +from copy import copy +from datetime import datetime +from unittest import mock +from unittest.mock import ANY, MagicMock, Mock + +import pytest +from lightning_cloud.openapi import Body5, V1DriveType, V1LightningworkState, V1SourceType +from lightning_cloud.openapi.rest import ApiException + +from lightning.app import BuildConfig, CloudCompute, LightningWork +from lightning.app.launcher.lightning_backend import CloudBackend +from lightning.app.storage import Drive, Mount +from lightning.app.testing.helpers import EmptyWork +from lightning.app.utilities.enum import WorkFailureReasons, WorkStageStatus +from lightning.app.utilities.exceptions import LightningPlatformException + + +class WorkWithDrive(LightningWork): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.drive = None + + def run(self): + pass + + +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_no_update_when_no_works(client_mock): + cloud_backend = CloudBackend("") + cloud_backend._get_cloud_work_specs = Mock() + client_mock.assert_called_once() + cloud_backend.update_work_statuses(works=[]) + cloud_backend._get_cloud_work_specs.assert_not_called() + + +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_no_update_when_all_work_has_started(client_mock): + cloud_backend = CloudBackend("") + cloud_backend._get_cloud_work_specs = MagicMock() + client_mock.assert_called_once() + started_mock = MagicMock() + started_mock.has_started = True + + # all works have started + works = [started_mock, started_mock] + cloud_backend.update_work_statuses(works=works) + cloud_backend._get_cloud_work_specs.assert_called_once() + + +@mock.patch("lightning.app.launcher.lightning_backend.monotonic") +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_no_update_within_interval(client_mock, monotonic_mock): + cloud_backend = CloudBackend("", status_update_interval=2) + cloud_backend._get_cloud_work_specs = Mock() + client_mock.assert_called_once() + cloud_backend._last_time_updated = 1 + monotonic_mock.return_value = 2 + + stopped_mock = Mock() + stopped_mock.has_started = False + + # not all works have started + works = [stopped_mock, stopped_mock] + + cloud_backend.update_work_statuses(works=works) + cloud_backend._get_cloud_work_specs.assert_not_called() + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.monotonic") +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_within_interval(client_mock, monotonic_mock): + cloud_backend = CloudBackend("", status_update_interval=2) + cloud_backend._last_time_updated = 1 + # pretend a lot of time has passed since the last update + 
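+    # (8 - 1 = 7 seconds elapsed, comfortably above status_update_interval=2,
+    # so update_work_statuses will hit the backend this time.)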
monotonic_mock.return_value = 8 + + stopped_mock1 = Mock() + stopped_mock1.has_started = False + stopped_mock1.name = "root.mock1" + stopped_mock2 = Mock() + stopped_mock2.has_started = False + stopped_mock2.name = "root.mock2" + + spec1 = Mock() + spec1.name = "root.mock1" + spec2 = Mock() + spec2.name = "root.mock2" + + # not all works have started + works = [stopped_mock1, stopped_mock2] + + cloud_backend.update_work_statuses(works=works) + client_mock().lightningwork_service_list_lightningwork.assert_called_with(project_id="project_id", app_id="app_id") + + # TODO: assert calls on the work mocks + # ... + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_stop_all_works(mock_client): + work_a = EmptyWork() + work_a._name = "root.work_a" + work_a._calls = { + "latest_call_hash": "some_call_hash", + "some_call_hash": { + "statuses": [ + { + "stage": WorkStageStatus.FAILED, + "timestamp": int(datetime.now().timestamp()), + "reason": WorkFailureReasons.USER_EXCEPTION, + }, + ] + }, + } + + work_b = EmptyWork() + work_b._name = "root.work_b" + work_b._calls = { + "latest_call_hash": "some_call_hash", + "some_call_hash": { + "statuses": [{"stage": WorkStageStatus.RUNNING, "timestamp": int(datetime.now().timestamp()), "reason": ""}] + }, + } + + cloud_backend = CloudBackend("") + + spec1 = Mock() + spec1.name = "root.work_a" + spec1.spec.desired_state = V1LightningworkState.RUNNING + spec1.status.phase = V1LightningworkState.FAILED + spec2 = Mock() + spec2.name = "root.work_b" + spec2.spec.desired_state = V1LightningworkState.RUNNING + + class BackendMock: + def __init__(self): + self.called = 0 + + def _get_cloud_work_specs(self, *_): + value = [spec1, spec2] if not self.called else [] + self.called += 1 + return value + + cloud_backend._get_cloud_work_specs = BackendMock()._get_cloud_work_specs + cloud_backend.stop_all_works([work_a, work_b]) + + mock_client().lightningwork_service_update_lightningwork.assert_called_with( + project_id="project_id", + id=ANY, + spec_lightningapp_instance_id="app_id", + body=ANY, + ) + assert spec1.spec.desired_state == V1LightningworkState.RUNNING + assert spec2.spec.desired_state == V1LightningworkState.STOPPED + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_stop_work(mock_client): + work = EmptyWork() + work._name = "root.work" + work._calls = { + "latest_call_hash": "some_call_hash", + "some_call_hash": { + "statuses": [ + { + "stage": WorkStageStatus.RUNNING, + "timestamp": int(datetime.now().timestamp()), + "reason": "", + }, + ] + }, + } + + cloud_backend = CloudBackend("") + spec1 = Mock() + spec1.name = "root.work" + spec1.spec.desired_state = V1LightningworkState.RUNNING + + spec2 = Mock() + spec2.name = "root.work_b" + spec2.spec.desired_state = V1LightningworkState.RUNNING + + class BackendMock: + def __init__(self): + self.called = 0 + + def _get_cloud_work_specs(self, *_): + value = [spec1, spec2] if not self.called else [] + self.called += 1 + return value + + cloud_backend._get_cloud_work_specs = BackendMock()._get_cloud_work_specs + cloud_backend.stop_work(MagicMock(), work) + + mock_client().lightningwork_service_update_lightningwork.assert_called_with( + project_id="project_id", + id=ANY, + 
spec_lightningapp_instance_id="app_id", + body=ANY, + ) + assert spec1.spec.desired_state == V1LightningworkState.STOPPED + assert spec2.spec.desired_state == V1LightningworkState.RUNNING + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_create_work_where_work_does_not_exists(mock_client): + cloud_backend = CloudBackend("") + non_matching_spec = Mock() + app = MagicMock() + work = EmptyWork(port=1111) + work._name = "name" + + def lightningwork_service_create_lightningwork( + project_id: str = None, + spec_lightningapp_instance_id: str = None, + body: "Body5" = None, + ): + assert project_id == "project_id" + assert spec_lightningapp_instance_id == "app_id" + assert len(body.spec.network_config) == 1 + assert body.spec.network_config[0].port == 1111 + assert not body.spec.network_config[0].host + body.spec.network_config[0].host = "x.lightning.ai" + return body + + response_mock = Mock() + response_mock.lightningworks = [non_matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork + + cloud_backend.create_work(app, work) + assert work._future_url == "https://x.lightning.ai" + app.work_queues["name"].put.assert_called_once_with(work) + + # testing whether the exception is raised correctly when the backend throws on work creation + http_resp = MagicMock() + error_message = "exception generated from test_create_work_where_work_does_not_exists test case" + http_resp.data = json.dumps({"message": error_message}) + mock_client().lightningwork_service_create_lightningwork = MagicMock() + mock_client().lightningwork_service_create_lightningwork.side_effect = ApiException(http_resp=http_resp) + with pytest.raises(LightningPlatformException, match=error_message): + cloud_backend.create_work(app, work) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_create_work_with_drives_where_work_does_not_exists(mock_client, tmpdir): + cloud_backend = CloudBackend("") + non_matching_spec = Mock() + app = MagicMock() + + mocked_drive = MagicMock(spec=Drive) + setattr(mocked_drive, "id", "foobar") + setattr(mocked_drive, "protocol", "lit://") + setattr(mocked_drive, "component_name", "test-work") + setattr(mocked_drive, "allow_duplicates", False) + setattr(mocked_drive, "root_folder", tmpdir) + # deepcopy on a MagicMock instance will return an empty magicmock instance. 
To + # overcome this we set the __deepcopy__ method `return_value` to equal what + # should be the results of the deepcopy operation (an instance of the original class) + mocked_drive.__deepcopy__.return_value = copy(mocked_drive) + + work = WorkWithDrive(port=1111) + work._name = "test-work-name" + work.drive = mocked_drive + + def lightningwork_service_create_lightningwork( + project_id: str = None, + spec_lightningapp_instance_id: str = None, + body: "Body5" = None, + ): + assert project_id == "project_id" + assert spec_lightningapp_instance_id == "app_id" + assert len(body.spec.network_config) == 1 + assert body.spec.network_config[0].port == 1111 + assert not body.spec.network_config[0].host + body.spec.network_config[0].host = "x.lightning.ai" + assert len(body.spec.drives) == 1 + assert body.spec.drives[0].drive.spec.drive_type == V1DriveType.NO_MOUNT_S3 + assert body.spec.drives[0].drive.spec.source_type == V1SourceType.S3 + assert body.spec.drives[0].drive.spec.source == "lit://foobar" + assert body.spec.drives[0].drive.metadata.name == "test-work-name.drive" + for v in body.spec.drives[0].drive.status.to_dict().values(): + assert v is None + + return body + + response_mock = Mock() + response_mock.lightningworks = [non_matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork + + cloud_backend.create_work(app, work) + assert work._future_url == "https://x.lightning.ai" + app.work_queues["test-work-name"].put.assert_called_once_with(work) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + "LIGHTNING_PROXY_SCHEME": "http", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_create_work_proxy_http(mock_client, tmpdir): + cloud_backend = CloudBackend("") + non_matching_spec = Mock() + app = MagicMock() + + mocked_drive = MagicMock(spec=Drive) + setattr(mocked_drive, "id", "foobar") + setattr(mocked_drive, "protocol", "lit://") + setattr(mocked_drive, "component_name", "test-work") + setattr(mocked_drive, "allow_duplicates", False) + setattr(mocked_drive, "root_folder", tmpdir) + # deepcopy on a MagicMock instance will return an empty magicmock instance. 
To + # overcome this we set the __deepcopy__ method `return_value` to equal what + # should be the results of the deepcopy operation (an instance of the original class) + mocked_drive.__deepcopy__.return_value = copy(mocked_drive) + + work = WorkWithDrive(port=1111) + work._name = "test-work-name" + work.drive = mocked_drive + + def lightningwork_service_create_lightningwork( + project_id: str = None, + spec_lightningapp_instance_id: str = None, + body: "Body5" = None, + ): + assert project_id == "project_id" + assert spec_lightningapp_instance_id == "app_id" + assert len(body.spec.network_config) == 1 + assert body.spec.network_config[0].port == 1111 + assert not body.spec.network_config[0].host + body.spec.network_config[0].host = "x.lightning.ai" + assert len(body.spec.drives) == 1 + assert body.spec.drives[0].drive.spec.drive_type == V1DriveType.NO_MOUNT_S3 + assert body.spec.drives[0].drive.spec.source_type == V1SourceType.S3 + assert body.spec.drives[0].drive.spec.source == "lit://foobar" + assert body.spec.drives[0].drive.metadata.name == "test-work-name.drive" + for v in body.spec.drives[0].drive.status.to_dict().values(): + assert v is None + + return body + + response_mock = Mock() + response_mock.lightningworks = [non_matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + mock_client().lightningwork_service_create_lightningwork = lightningwork_service_create_lightningwork + + cloud_backend.create_work(app, work) + assert work._future_url == "http://x.lightning.ai" + app.work_queues["test-work-name"].put.assert_called_once_with(work) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_work_with_changed_compute_config_with_mounts(mock_client): + cloud_backend = CloudBackend("") + matching_spec = Mock() + app = MagicMock() + work = EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1")) + work._name = "work_name" + + matching_spec.spec = cloud_backend._work_to_spec(work) + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + matching_spec.name = "work_name" + + response_mock = Mock() + response_mock.lightningworks = [matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + + cloud_backend.create_work(app, work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.desired_state + == V1LightningworkState.RUNNING + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.user_requested_compute_config.name + == "cpu-small" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.build_spec.image + == "image1" + ) + + # resetting the values changed in the previous step + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock() + + # new work with same name but different compute config + mount = Mount(source="s3://foo/", mount_path="/foo") + work = EmptyWork(cloud_compute=CloudCompute("gpu", mounts=mount), cloud_build_config=BuildConfig(image="image2")) + work._name = "work_name" + cloud_backend.create_work(app, 
work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.desired_state + == V1LightningworkState.RUNNING + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.user_requested_compute_config.name + == "gpu" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"] + .spec.drives[0] + .mount_location + == "/foo" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"] + .spec.drives[0] + .drive.spec.source + == "s3://foo/" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.build_spec.image + == "image2" + ) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_create_work_where_work_already_exists(mock_client): + cloud_backend = CloudBackend("") + matching_spec = Mock() + app = MagicMock() + work = EmptyWork(port=1111) + work._name = "work_name" + work._backend = cloud_backend + + matching_spec.spec = cloud_backend._work_to_spec(work) + matching_spec.spec.network_config[0].host = "x.lightning.ai" + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + matching_spec.name = "work_name" + + response_mock = Mock() + response_mock.lightningworks = [matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + + cloud_backend.create_work(app, work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.desired_state + == V1LightningworkState.RUNNING + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"] + .spec.network_config[0] + .port + == 1111 + ) + assert work._future_url == "https://x.lightning.ai" + app.work_queues["work_name"].put.assert_called_once_with(work) + + # resetting the values changed in the previous step + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock() + app.work_queues["work_name"].put.reset_mock() + + # changing the port + work._port = 2222 + cloud_backend.create_work(app, work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs["body"] + .spec.network_config[0] + .port + == 2222 + ) + app.work_queues["work_name"].put.assert_called_once_with(work) + + # testing whether the exception is raised correctly when the backend throws on work creation + # resetting the values changed in the previous step + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + http_resp = MagicMock() + error_message = "exception generated from test_create_work_where_work_already_exists test case" + http_resp.data = json.dumps({"message": error_message}) + mock_client().lightningwork_service_update_lightningwork = MagicMock() + mock_client().lightningwork_service_update_lightningwork.side_effect = ApiException(http_resp=http_resp) + with pytest.raises(LightningPlatformException, match=error_message): + cloud_backend.create_work(app, work) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + 
"LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_create_work_will_have_none_backend(mockclient): + def queue_put_mock(work): + # because we remove backend before pushing to queue + assert work._backend is None + + cloud_backend = CloudBackend("") + app = MagicMock() + work = EmptyWork() + # attaching backend - this will be removed by the queue + work._backend = cloud_backend + app.work_queues["work_name"].put = queue_put_mock + cloud_backend.create_work(app, work) + # make sure the work still have the backend attached + assert work._backend == cloud_backend + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_work_with_changed_compute_config_and_build_spec(mock_client): + cloud_backend = CloudBackend("") + matching_spec = Mock() + app = MagicMock() + work = EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1")) + work._name = "work_name" + + matching_spec.spec = cloud_backend._work_to_spec(work) + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + matching_spec.name = "work_name" + + response_mock = Mock() + response_mock.lightningworks = [matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + + cloud_backend.create_work(app, work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.desired_state + == V1LightningworkState.RUNNING + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.user_requested_compute_config.name + == "cpu-small" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.build_spec.image + == "image1" + ) + + # resetting the values changed in the previous step + matching_spec.spec.desired_state = V1LightningworkState.STOPPED + cloud_backend.client.lightningwork_service_update_lightningwork.reset_mock() + + # new work with same name but different compute config + work = EmptyWork(cloud_compute=CloudCompute("gpu"), cloud_build_config=BuildConfig(image="image2")) + work._name = "work_name" + cloud_backend.create_work(app, work) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.desired_state + == V1LightningworkState.RUNNING + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.user_requested_compute_config.name + == "gpu" + ) + assert ( + cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args.kwargs[ + "body" + ].spec.build_spec.image + == "image2" + ) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_work_with_changed_spec_while_work_running(mock_client): + cloud_backend = CloudBackend("") + matching_spec = Mock() + app = MagicMock() + work = 
EmptyWork(cloud_compute=CloudCompute("default"), cloud_build_config=BuildConfig(image="image1")) + work._name = "work_name" + + matching_spec.spec = cloud_backend._work_to_spec(work) + matching_spec.spec.desired_state = V1LightningworkState.RUNNING + matching_spec.name = "work_name" + + response_mock = Mock() + response_mock.lightningworks = [matching_spec] + mock_client().lightningwork_service_list_lightningwork.return_value = response_mock + + cloud_backend.create_work(app, work) + + # asserting the method is not called + cloud_backend.client.lightningwork_service_update_lightningwork.assert_not_called() + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_lightning_app_frontend_new_frontends(mock_client): + cloud_backend = CloudBackend("") + cloud_backend.client = mock_client + mocked_app = MagicMock() + mocked_app.frontends.keys.return_value = ["frontend2", "frontend1"] + app_instance_mock = MagicMock() + app_instance_mock.spec.flow_servers = [] + update_lightning_app_instance_mock = MagicMock() + mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = app_instance_mock + mock_client.lightningapp_instance_service_update_lightningapp_instance.return_value = ( + update_lightning_app_instance_mock + ) + cloud_backend.update_lightning_app_frontend(mocked_app) + assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 1 + + # frontends should be sorted + assert ( + mock_client.lightningapp_instance_service_update_lightningapp_instance.call_args.kwargs["body"] + .spec.flow_servers[0] + .name + == "frontend1" + ) + assert ( + mock_client.lightningapp_instance_service_update_lightningapp_instance.call_args.kwargs["body"] + .spec.flow_servers[1] + .name + == "frontend2" + ) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_update_lightning_app_frontend_existing_frontends(mock_client): + cloud_backend = CloudBackend("") + cloud_backend.client = mock_client + mocked_app = MagicMock() + mocked_app.frontends.keys.return_value = ["frontend2", "frontend1"] + app_instance_mock = MagicMock() + app_instance_mock.spec.flow_servers = ["frontend2", "frontend1"] + update_lightning_app_instance_mock = MagicMock() + mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = app_instance_mock + mock_client.lightningapp_instance_service_update_lightningapp_instance.return_value = ( + update_lightning_app_instance_mock + ) + cloud_backend.update_lightning_app_frontend(mocked_app) + + # the app spec already has the frontends, so no update should be called + assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 0 + assert mock_client.lightningapp_instance_service_update_lightningapp_instance.call_count == 0 + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.utilities.network.create_swagger_client", MagicMock()) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def 
test_stop_app(mock_client): + cloud_backend = CloudBackend("") + external_spec = MagicMock() + mock_client.lightningapp_instance_service_get_lightningapp_instance.return_value = external_spec + cloud_backend.client = mock_client + mocked_app = MagicMock() + cloud_backend.stop_app(mocked_app) + spec = mock_client.lightningapp_instance_service_update_lightningapp_instance._mock_call_args.kwargs["body"].spec + assert spec.desired_state == "LIGHTNINGAPP_INSTANCE_STATE_STOPPED" + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_failed_works_during_pending(client_mock): + cloud_backend = CloudBackend("") + cloud_work = MagicMock() + cloud_work.name = "a" + cloud_work.status.phase = V1LightningworkState.FAILED + cloud_backend._get_cloud_work_specs = MagicMock(return_value=[cloud_work]) + + local_work = MagicMock() + local_work.status.stage = "pending" + local_work.name = "a" + local_work._raise_exception = True + + with pytest.raises(Exception, match="The work a failed during pending phase."): + # all works have started + cloud_backend.update_work_statuses(works=[local_work]) + + +@mock.patch.dict( + os.environ, + { + "LIGHTNING_CLOUD_PROJECT_ID": "project_id", + "LIGHTNING_CLOUD_APP_ID": "app_id", + }, +) +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_work_delete(client_mock): + cloud_backend = CloudBackend("") + cloud_work = MagicMock() + cloud_work.name = "a" + cloud_work.status.phase = V1LightningworkState.RUNNING + cloud_backend._get_cloud_work_specs = MagicMock(return_value=[cloud_work]) + + local_work = MagicMock() + local_work.status.stage = "running" + local_work.name = "a" + local_work._raise_exception = True + cloud_backend.delete_work(None, local_work) + call = cloud_backend.client.lightningwork_service_update_lightningwork._mock_call_args_list[0] + assert call.kwargs["body"].spec.desired_state == V1LightningworkState.DELETED diff --git a/tests/tests_app/launcher/test_lightning_hydrid.py b/tests/tests_app/launcher/test_lightning_hydrid.py new file mode 100644 index 0000000000000..695d3216a2316 --- /dev/null +++ b/tests/tests_app/launcher/test_lightning_hydrid.py @@ -0,0 +1,14 @@ +from unittest import mock + +from lightning.app import CloudCompute +from lightning.app.launcher.lightning_hybrid_backend import CloudHybridBackend + + +@mock.patch("lightning.app.launcher.lightning_backend.LightningClient") +def test_backend_selection(client_mock): + cloud_backend = CloudHybridBackend("", queue_id="") + work = mock.MagicMock() + work.cloud_compute = CloudCompute() + assert cloud_backend._get_backend(work) == cloud_backend.backends["multiprocess"] + work.cloud_compute = CloudCompute("gpu") + assert cloud_backend._get_backend(work) == cloud_backend.backends["cloud"] diff --git a/tests/tests_app/launcher/test_running_flow.py b/tests/tests_app/launcher/test_running_flow.py new file mode 100644 index 0000000000000..6002238ac2ba7 --- /dev/null +++ b/tests/tests_app/launcher/test_running_flow.py @@ -0,0 +1,134 @@ +import logging +import os +import signal +import sys +from unittest import mock +from unittest.mock import MagicMock, Mock + +import pytest +import requests + +from lightning.app.launcher import launcher, lightning_backend +from lightning.app.utilities.app_helpers import convert_print_to_logger_info +from lightning.app.utilities.enum import AppStage +from lightning.app.utilities.exceptions 
import ExitAppException + + +def _make_mocked_network_config(key, host): + network_config = Mock() + network_config.name = key + network_config.host = host + return network_config + + +@mock.patch("lightning.app.core.queues.QueuingSystem", mock.MagicMock()) +@mock.patch("lightning.app.launcher.launcher.check_if_redis_running", MagicMock(return_value=True)) +def test_running_flow(monkeypatch): + app = MagicMock() + flow = MagicMock() + work = MagicMock() + work.run.__name__ = "run" + flow._layout = {} + flow.name = "flowname" + work.name = "workname" + app.flows = [flow] + flow.works.return_value = [work] + + def load_app_from_file(file): + assert file == "file.py" + return app + + class BackendMock: + def __init__(self, return_value): + self.called = 0 + self.return_value = return_value + + def _get_cloud_work_specs(self, *_): + value = self.return_value if not self.called else [] + self.called += 1 + return value + + cloud_work_spec = Mock() + cloud_work_spec.name = "workname" + cloud_work_spec.spec.network_config = [ + _make_mocked_network_config("key1", "x.lightning.ai"), + ] + monkeypatch.setattr(launcher, "load_app_from_file", load_app_from_file) + monkeypatch.setattr(launcher, "start_server", MagicMock()) + monkeypatch.setattr(lightning_backend, "LightningClient", MagicMock()) + lightning_backend.CloudBackend._get_cloud_work_specs = BackendMock( + return_value=[cloud_work_spec] + )._get_cloud_work_specs + monkeypatch.setattr(lightning_backend.CloudBackend, "_get_project_id", MagicMock()) + monkeypatch.setattr(lightning_backend.CloudBackend, "_get_app_id", MagicMock()) + queue_system = MagicMock() + queue_system.REDIS = MagicMock() + monkeypatch.setattr(launcher, "QueuingSystem", queue_system) + monkeypatch.setattr(launcher, "StorageOrchestrator", MagicMock()) + + response = MagicMock() + response.status_code = 200 + monkeypatch.setattr(requests, "get", MagicMock(return_value=response)) + + # testing with correct base URL + with pytest.raises(SystemExit, match="0"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="http://localhost:8080") + assert flow._layout["target"] == "http://localhost:8080/flowname/" + + app._run.assert_called_once() + + # testing with invalid base URL + with pytest.raises(ValueError, match="Base URL doesn't have a valid scheme"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080") + + app.flows = [] + + def run_patch(): + raise Exception + + app._run = run_patch + + with pytest.raises(SystemExit, match="1"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080") + + def run_patch(): + app.stage = AppStage.FAILED + + app._run = run_patch + + with pytest.raises(SystemExit, match="1"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080") + + def run_patch(): + raise ExitAppException + + if sys.platform == "win32": + return + + app.stage = AppStage.STOPPING + + app._run = run_patch + with pytest.raises(SystemExit, match="0"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080") + + def run_method(): + os.kill(os.getpid(), signal.SIGTERM) + + app._run = run_method + monkeypatch.setattr(lightning_backend.CloudBackend, "resolve_url", MagicMock()) + with pytest.raises(SystemExit, match="0"): + launcher.run_lightning_flow("file.py", queue_id="", base_url="localhost:8080") + assert app.stage == AppStage.STOPPING + + +def test_replace_print_to_info(caplog, monkeypatch): + monkeypatch.setattr("lightning.app._logger", logging.getLogger()) + + 
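# convert_print_to_logger_info is expected to patch print() inside the decorated
+    # function so the output is routed to logger.info - hence caplog captures "1" below + 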
@convert_print_to_logger_info + def fn_captured(value): + print(value) + + with caplog.at_level(logging.INFO): + fn_captured(1) + + assert caplog.messages == ["1"] diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 6b58d79ab3c79..af569b5ccc6ab 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -1540,9 +1540,6 @@ def test_reopen(self, monkeypatch, capsys): project_id="test-project-id", cloudspace_id="cloudspace_id", body=mock.ANY ) - out, _ = capsys.readouterr() - assert "will not overwrite the files in your CloudSpace." in out - def test_not_enabled(self, monkeypatch, capsys): """Tests that an error is printed and the call exits if the feature isn't enabled for the user.""" mock_client = mock.MagicMock() diff --git a/tests/tests_app/test_imports.py b/tests/tests_app/test_imports.py index e40ab369c747d..f80d497cacae5 100644 --- a/tests/tests_app/test_imports.py +++ b/tests/tests_app/test_imports.py @@ -44,6 +44,7 @@ def test_import_depth( "lightning.app.cli", "lightning.app.components.serve.types", "lightning.app.core", + "lightning.app.launcher", "lightning.app.runners", "lightning.app.utilities", ] diff --git a/tests/tests_data/datasets/test_env.py b/tests/tests_data/datasets/test_env.py index 6be415cb7e021..55e33282153f8 100644 --- a/tests/tests_data/datasets/test_env.py +++ b/tests/tests_data/datasets/test_env.py @@ -6,6 +6,7 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv, Environment from lightning.fabric import Fabric +from tests_fabric.helpers.runif import RunIf @pytest.mark.parametrize( @@ -109,6 +110,7 @@ def env_auto_test(fabric: Fabric, num_workers): pass +@RunIf(skip_windows=True) @pytest.mark.parametrize("num_workers", [0, 1, 2]) @pytest.mark.parametrize("dist_world_size", [1, 2]) def test_env_auto(num_workers, dist_world_size):
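A closing note for reviewers: the launcher CLI tests above all follow the same recipe - point the command at an app file that exists on disk, mock out the heavy launcher entrypoint, and drive the command in-process with Click's CliRunner. A minimal self-contained sketch of that recipe against the new `launch server` command follows; the temporary file and the patch target (`start_application_server` as imported into `lightning_cli_launch`) are illustrative assumptions, not part of this diff:

import tempfile
from unittest import mock

from click.testing import CliRunner

from lightning.app.cli.lightning_cli_launch import run_server

# The FILE argument must point at an existing file, so create a throwaway one;
# patching the server entrypoint keeps the invocation free of side effects.
with tempfile.NamedTemporaryFile(suffix=".py") as f:
    with mock.patch("lightning.app.cli.lightning_cli_launch.start_application_server") as server_mock:
        result = CliRunner().invoke(
            run_server,
            [f.name, "--queue-id", "1", "--host", "0.0.0.0", "--port", "7501"],
            catch_exceptions=False,
        )

assert result.exit_code == 0
# file, host and port are forwarded positionally; queue_id goes through as a kwarg
server_mock.assert_called_once_with(f.name, "0.0.0.0", 7501, queue_id="1")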