diff --git a/examples/app_server_with_auto_scaler/app.py b/examples/app_server_with_auto_scaler/app.py index 70799827776a8..453db2424b404 100644 --- a/examples/app_server_with_auto_scaler/app.py +++ b/examples/app_server_with_auto_scaler/app.py @@ -1,5 +1,5 @@ # ! pip install torch torchvision -from typing import Any, List +from typing import List import torch import torchvision @@ -8,16 +8,12 @@ import lightning as L -class RequestModel(BaseModel): - image: str # bytecode - - class BatchRequestModel(BaseModel): - inputs: List[RequestModel] + inputs: List[L.app.components.Image] class BatchResponse(BaseModel): - outputs: List[Any] + outputs: List[L.app.components.Number] class PyTorchServer(L.app.components.PythonServer): @@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int: max_replicas=4, autoscale_interval=10, endpoint="predict", - input_type=RequestModel, - output_type=Any, + input_type=L.app.components.Image, + output_type=L.app.components.Number, timeout_batching=1, max_batch_size=8, ) diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/auto_scaler.py index 13948ba50af89..d68b4e04ea336 100644 --- a/src/lightning_app/components/auto_scaler.py +++ b/src/lightning_app/components/auto_scaler.py @@ -280,6 +280,8 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe async def balance_api(inputs: self._input_type): return await self.process_request(inputs) + logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'") + uvicorn.run( fastapi_app, host=self.host, @@ -332,6 +334,51 @@ def send_request_to_update_servers(self, servers: List[str]): response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10) response.raise_for_status() + @staticmethod + def _get_sample_dict_from_datatype(datatype: Any) -> dict: + if hasattr(datatype, "_get_sample_data"): + return datatype._get_sample_data() + + datatype_props = datatype.schema()["properties"] + out: Dict[str, Any] = {} + for k, v in datatype_props.items(): + if v["type"] == "string": + out[k] = "data string" + elif v["type"] == "number": + out[k] = 0.0 + elif v["type"] == "integer": + out[k] = 0 + elif v["type"] == "boolean": + out[k] = False + else: + raise TypeError("Unsupported type") + return out + + def configure_layout(self) -> None: + try: + from lightning_api_access import APIAccessFrontend + except ModuleNotFoundError: + logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI") + return + + try: + request = self._get_sample_dict_from_datatype(self._input_type) + response = self._get_sample_dict_from_datatype(self._output_type) + except (AttributeError, TypeError): + return + + return APIAccessFrontend( + apis=[ + { + "name": self.__class__.__name__, + "url": f"{self.url}{self.endpoint}", + "method": "POST", + "request": request, + "response": response, + } + ] + ) + class AutoScaler(LightningFlow): """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in @@ -574,5 +621,5 @@ def autoscale(self) -> None: self._last_autoscale = time.time() def configure_layout(self): - tabs = [{"name": "Swagger", "content": self.load_balancer.url}] - return tabs + layout = self.load_balancer.configure_layout() + return layout if layout else super().configure_layout()