Skip to content

Commit

Permalink
Avoid using the same port number for autoscaler works (#15966)
Browse files Browse the repository at this point in the history
* dont hardcode port in python server
* add another chglog
  • Loading branch information
akihironitta authored Dec 9, 2022
1 parent 346e936 commit a72d268
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
21 changes: 12 additions & 9 deletions examples/app_server_with_auto_scaler/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# ! pip install torch torchvision
from typing import Any, List

import torch
Expand All @@ -22,10 +23,10 @@ class BatchResponse(BaseModel):
class PyTorchServer(L.app.components.PythonServer):
def __init__(self, *args, **kwargs):
super().__init__(
port=L.app.utilities.network.find_free_network_port(),
input_type=BatchRequestModel,
output_type=BatchResponse,
cloud_compute=L.CloudCompute("gpu"),
*args,
**kwargs,
)

def setup(self):
Expand Down Expand Up @@ -57,30 +58,32 @@ def scale(self, replicas: int, metrics: dict) -> int:
"""The default scaling logic that users can override."""
# scale out if the number of pending requests exceeds max batch size.
max_requests_per_work = self.max_batch_size
pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
replicas + metrics["pending_works"]
)
if pending_requests_per_running_or_pending_work >= max_requests_per_work:
pending_requests_per_work = metrics["pending_requests"] / (replicas + metrics["pending_works"])
if pending_requests_per_work >= max_requests_per_work:
return replicas + 1

# scale in if the number of pending requests is below 25% of max_requests_per_work
min_requests_per_work = max_requests_per_work * 0.25
pending_requests_per_running_work = metrics["pending_requests"] / replicas
if pending_requests_per_running_work < min_requests_per_work:
pending_requests_per_work = metrics["pending_requests"] / replicas
if pending_requests_per_work < min_requests_per_work:
return replicas - 1

return replicas


app = L.LightningApp(
MyAutoScaler(
# work class and args
PyTorchServer,
min_replicas=2,
cloud_compute=L.CloudCompute("gpu"),
# autoscaler specific args
min_replicas=1,
max_replicas=4,
autoscale_interval=10,
endpoint="predict",
input_type=RequestModel,
output_type=Any,
timeout_batching=1,
max_batch_size=8,
)
)
4 changes: 2 additions & 2 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Changed the default port of `PythonServer` from `7777` to a free port at runtime ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))

- Remove the `AutoScaler` dependency `aiohttp` from the base requirements ([#15971](https://github.com/Lightning-AI/lightning/pull/15971))

Expand All @@ -30,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed `AutoScaler` failing due to port collision across works ([#15966](https://github.com/Lightning-AI/lightning/pull/15966))


## [1.8.4] - 2022-12-08
Expand Down
3 changes: 2 additions & 1 deletion src/lightning_app/components/auto_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ def workers(self) -> List[LightningWork]:
def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
# TODO: Remove `start_with_flow=False` for faster initialization on the cloud
return self._work_cls(*self._work_args, **self._work_kwargs, start_with_flow=False)
self._work_kwargs.update(dict(start_with_flow=False))
return self._work_cls(*self._work_args, **self._work_kwargs)

def add_work(self, work) -> str:
"""Adds a new LightningWork instance.
Expand Down
6 changes: 1 addition & 5 deletions src/lightning_app/components/serve/python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,13 @@ class PythonServer(LightningWork, abc.ABC):
@requires(["torch", "lightning_api_access"])
def __init__( # type: ignore
self,
host: str = "127.0.0.1",
port: int = 7777,
input_type: type = _DefaultInputData,
output_type: type = _DefaultOutputData,
**kwargs,
):
"""The PythonServer Class enables to easily get your machine learning server up and running.
Arguments:
host: Address to be used for running the server.
port: Port to be used to running the server.
input_type: Optional `input_type` to be provided. This needs to be a pydantic BaseModel class.
The default data type is good enough for the basic usecases and it expects the data
to be a json object that has one key called `payload`
Expand Down Expand Up @@ -129,7 +125,7 @@ def predict(self, request):
...
>>> app = LightningApp(SimpleServer())
"""
super().__init__(parallel=True, host=host, port=port, **kwargs)
super().__init__(parallel=True, **kwargs)
if not issubclass(input_type, BaseModel):
raise TypeError("input_type must be a pydantic BaseModel class")
if not issubclass(output_type, BaseModel):
Expand Down

0 comments on commit a72d268

Please sign in to comment.