Batch Concurrency (#358)
* Restored concurrency_modifier support
* Concurrent calls now use batched job takes instead of multiple parallel singular job takes
* Using an asyncio.Queue to track queued jobs and their processing
* JobScaler refactored to properly coordinate job takes and processing (see the sketch below)
deanq authored Sep 26, 2024
1 parent ca2bc00 commit f1e6450
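
For context, a rough sketch of the coordination the last two bullets describe: batch job takes sized by the concurrency modifier, tracked in an asyncio.Queue. The real implementation is JobScaler in runpod/serverless/modules/rp_scale.py, which is not among the diffs loaded below; the class name, control flow, and sleep here are illustrative only.

import asyncio
from typing import Callable, Optional

from runpod.http_client import ClientSession
from runpod.serverless.modules.rp_job import get_job  # changed in this commit


class JobScalerSketch:
    """Illustrative stand-in for the refactored JobScaler."""

    def __init__(self, concurrency_modifier: Optional[Callable[[int], int]] = None):
        # Without a user-supplied modifier, concurrency stays where it is.
        self.concurrency_modifier = concurrency_modifier or (lambda c: c)
        self.current_concurrency = 1
        self.jobs_queue: asyncio.Queue = asyncio.Queue()

    async def get_jobs(self, session: ClientSession):
        while True:
            # Ask user code for the desired concurrency, then take one
            # batched job request sized to the free capacity, instead of
            # several parallel single-job takes.
            self.current_concurrency = self.concurrency_modifier(self.current_concurrency)
            free_slots = self.current_concurrency - self.jobs_queue.qsize()

            if free_slots > 0:
                for job in await get_job(session, num_jobs=free_slots) or []:
                    await self.jobs_queue.put(job)

            await asyncio.sleep(0)  # yield to the tasks draining the queue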
Showing 16 changed files with 444 additions and 348 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI-pytests.yml
@@ -32,4 +32,4 @@ jobs:
           pip install '.[test]'
       - name: Run Tests
-        run: pytest --cov-config=.coveragerc --timeout=120 --timeout_method=thread --cov=runpod --cov-report=xml --cov-report=term-missing --cov-fail-under=98 -W error -p no:cacheprovider -p no:unraisableexception
+        run: pytest
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ test_logging.py
 /example_project
 IGNORE.py
 /quick-test
+.DS_Store

 # Byte-compiled / optimized / DLL files
 __pycache__/
16 changes: 13 additions & 3 deletions examples/serverless/concurrent_handler.py
@@ -11,8 +11,18 @@ async def async_generator_handler(job):
     return job


+# --------------------------- Concurrency Modifier --------------------------- #
+def concurrency_modifier(current_concurrency=1):
+    """
+    Concurrency modifier.
+    """
+    desired_concurrency = current_concurrency
+
+    # Do some logic to determine the desired concurrency.
+
+    return desired_concurrency
+
+
 runpod.serverless.start(
-    {
-        "handler": async_generator_handler,
-    }
+    {"handler": async_generator_handler, "concurrency_modifier": concurrency_modifier}
 )
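
The example above returns the current value unchanged. As one hypothetical illustration of what the modifier hook allows (the thresholds and the random probe are made up; substitute a real signal such as queue depth or GPU memory):

import random  # stand-in for a real utilization probe

MAX_CONCURRENCY = 8
MIN_CONCURRENCY = 1


def concurrency_modifier(current_concurrency: int) -> int:
    # Replace this stub with a real signal before using it in production.
    request_rate = random.randint(0, 100)

    if request_rate > 50 and current_concurrency < MAX_CONCURRENCY:
        return current_concurrency + 1  # scale up while there is headroom
    if request_rate <= 50 and current_concurrency > MIN_CONCURRENCY:
        return current_concurrency - 1  # scale back down when traffic falls
    return current_concurrency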
2 changes: 0 additions & 2 deletions runpod/serverless/modules/rp_fastapi.py
@@ -1,7 +1,5 @@
 """ Used to launch the FastAPI web server when worker is running in API mode. """

-# pylint: disable=too-few-public-methods, line-too-long
-
 import os
 import threading
 import uuid
5 changes: 1 addition & 4 deletions runpod/serverless/modules/rp_http.py
@@ -11,7 +11,7 @@
 from runpod.http_client import ClientSession
 from runpod.serverless.modules.rp_logger import RunPodLogger

-from .worker_state import WORKER_ID, Jobs
+from .worker_state import WORKER_ID

 JOB_DONE_URL_TEMPLATE = str(
     os.environ.get("RUNPOD_WEBHOOK_POST_OUTPUT", "JOB_DONE_URL")
@@ -24,7 +24,6 @@
 JOB_STREAM_URL = JOB_STREAM_URL_TEMPLATE.replace("$RUNPOD_POD_ID", WORKER_ID)

 log = RunPodLogger()
-job_list = Jobs()


 async def _transmit(client_session: ClientSession, url, job_data):
@@ -49,7 +48,6 @@ async def _transmit(client_session: ClientSession, url, job_data):
         await client_response.text()


-# pylint: disable=too-many-arguments, disable=line-too-long
 async def _handle_result(
     session: ClientSession, job_data, job, url_template, log_message, is_stream=False
 ):
@@ -79,7 +77,6 @@ async def _handle_result(
             url_template == JOB_DONE_URL
             and job_data.get("status", None) != "IN_PROGRESS"
         ):
-            job_list.remove_job(job["id"])
             log.info("Finished.", job["id"])
145 changes: 61 additions & 84 deletions runpod/serverless/modules/rp_job.py
@@ -2,29 +2,27 @@
 Job related helpers.
 """

-# pylint: disable=too-many-branches
-
 import asyncio
 import inspect
 import json
 import os
 import traceback
-from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List

 from runpod.http_client import ClientSession
 from runpod.serverless.modules.rp_logger import RunPodLogger

 from ...version import __version__ as runpod_version
 from .rp_tips import check_return_size
-from .worker_state import WORKER_ID, Jobs
+from .worker_state import WORKER_ID, JobsQueue

 JOB_GET_URL = str(os.environ.get("RUNPOD_WEBHOOK_GET_JOB")).replace("$ID", WORKER_ID)

 log = RunPodLogger()
-job_list = Jobs()
+job_list = JobsQueue()


-def _job_get_url():
+def _job_get_url(batch_size: int = 1):
     """
     Prepare the URL for making a 'get' request to the serverless API (sls).
@@ -34,89 +32,68 @@ def _job_get_url():
     Returns:
         str: The prepared URL for the 'get' request to the serverless API.
     """
-    job_in_progress = "1" if job_list.get_job_list() else "0"
-    return JOB_GET_URL + f"&job_in_progress={job_in_progress}"
+    job_in_progress = "1" if job_list.get_job_count() else "0"
+
+    if batch_size > 1:
+        job_take_url = JOB_GET_URL.replace("/job-take/", "/job-take-batch/")
+        job_take_url += f"&batch_size={batch_size}&batch_strategy=LMove"
+    else:
+        job_take_url = JOB_GET_URL
+
+    return job_take_url + f"&job_in_progress={job_in_progress}"


-async def get_job(session: ClientSession, retry=True) -> Optional[Dict[str, Any]]:
+async def get_job(
+    session: ClientSession, num_jobs: int = 1
+) -> Optional[List[Dict[str, Any]]]:
     """
-    Get the job from the queue.
-    Will continue trying to get a job until one is available.
+    Get a job from the job-take API.
+    `num_jobs = 1` will query the legacy singular job-take API.
+    `num_jobs > 1` will query the batch job-take API.

     Args:
-        session (ClientSession): The async http client to use for the request.
-        retry (bool): Whether to retry if no job is available.
+        session (ClientSession): The aiohttp ClientSession to use for the request.
+        num_jobs (int): The number of jobs to get.
     """
-    next_job = None
-
-    while next_job is None:
-        try:
-            async with session.get(_job_get_url()) as response:
-                if response.status == 204:
-                    log.debug("No content, no job to process.")
-                    if retry is False:
-                        break
-                    continue
-
-                if response.status == 400:
-                    log.debug(
-                        "Received 400 status, expected when FlashBoot is enabled."
-                    )
-                    if retry is False:
-                        break
-                    continue
-
-                if response.status != 200:
-                    log.error(f"Failed to get job, status code: {response.status}")
-                    if retry is False:
-                        break
-                    continue
-
-                received_request = await response.json()
-                log.debug(f"Request Received | {received_request}")
-
-                # Check if the job is valid
-                job_id = received_request.get("id", None)
-                job_input = received_request.get("input", None)
-
-                if None in [job_id, job_input]:
-                    missing_fields = []
-                    if job_id is None:
-                        missing_fields.append("id")
-                    if job_input is None:
-                        missing_fields.append("input")
-
-                    log.error(f"Job has missing field(s): {', '.join(missing_fields)}.")
-                else:
-                    next_job = received_request
-
-        except asyncio.TimeoutError:
-            log.debug("Timeout error, retrying.")
-            if retry is False:
-                break
-
-        except Exception as err:  # pylint: disable=broad-except
-            err_type = type(err).__name__
-            err_message = str(err)
-            err_traceback = traceback.format_exc()
-            log.error(
-                f"Failed to get job. | Error Type: {err_type} | Error Message: {err_message}"
-            )
-            log.error(f"Traceback: {err_traceback}")
-
-        if next_job is None:
-            log.debug("No job available, waiting for the next one.")
-            if retry is False:
-                break
-
-            await asyncio.sleep(1)
-        else:
-            job_list.add_job(next_job["id"])
-            log.debug("Request ID added.", next_job["id"])
-
-    return next_job
+    try:
+        async with session.get(_job_get_url(num_jobs)) as response:
+            if response.status == 204:
+                log.debug("No content, no job to process.")
+                return
+
+            if response.status == 400:
+                log.debug("Received 400 status, expected when FlashBoot is enabled.")
+                return
+
+            if response.status != 200:
+                log.error(f"Failed to get job, status code: {response.status}")
+                return
+
+            jobs = await response.json()
+            log.debug(f"Request Received | {jobs}")
+
+            # legacy job-take API
+            if isinstance(jobs, dict):
+                if "id" not in jobs or "input" not in jobs:
+                    raise Exception("Job has missing field(s): id or input.")
+                return [jobs]
+
+            # batch job-take API
+            if isinstance(jobs, list):
+                return jobs
+
+    except asyncio.TimeoutError:
+        log.debug("Timeout error, retrying.")
+
+    except Exception as error:
+        log.error(
+            f"Failed to get job. | Error Type: {type(error).__name__} | Error Message: {str(error)}"
+        )
+
+        return None
+
+    # empty
+    return []


 async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:
@@ -164,7 +141,7 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]:

         check_return_size(run_result)  # Checks the size of the return body.

-    except Exception as err:  # pylint: disable=broad-except
+    except Exception as err:
         error_info = {
             "error_type": str(type(err)),
             "error_message": str(err),
@@ -209,7 +186,7 @@ async def run_job_generator(
             log.debug(f"Generator output: {output_partial}", job["id"])
             yield {"output": output_partial}

-    except Exception as err:  # pylint: disable=broad-except
+    except Exception as err:
         log.error(err, job["id"])
         yield {"error": f"handler: {str(err)} \ntraceback: {traceback.format_exc()}"}
     finally:
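
Taken together: _job_get_url() now routes batch_size > 1 to the batch endpoint, and get_job() makes a single attempt and always hands back a list ([] when nothing is queued, None on a request failure) instead of blocking and retrying internally. A minimal usage sketch; the URL shapes in the comments assume the usual /v2/$ID/job-take/ template supplied via RUNPOD_WEBHOOK_GET_JOB, and ClientSession is assumed to construct with defaults like aiohttp's.

import asyncio

from runpod.http_client import ClientSession
from runpod.serverless.modules.rp_job import get_job

# Illustrative URL shapes:
# _job_get_url(1) -> .../v2/$ID/job-take/$WORKER_ID?...&job_in_progress=0
# _job_get_url(4) -> .../v2/$ID/job-take-batch/$WORKER_ID?...&batch_size=4&batch_strategy=LMove&job_in_progress=0


async def take_jobs():
    async with ClientSession() as session:
        jobs = await get_job(session, num_jobs=4)
        for job in jobs or []:  # [] on empty, None on request failure
            print(job["id"], job["input"])


asyncio.run(take_jobs())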
2 changes: 1 addition & 1 deletion runpod/serverless/modules/rp_logger.py
@@ -15,7 +15,7 @@
 from typing import Optional

 MAX_MESSAGE_LENGTH = 4096
-LOG_LEVELS = ["NOTSET", "DEBUG", "TRACE", "INFO", "WARN", "ERROR"]
+LOG_LEVELS = ["NOTSET", "TRACE", "DEBUG", "INFO", "WARN", "ERROR"]


 def _validate_log_level(log_level):
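
This swap puts TRACE below DEBUG in severity, which matters if the logger derives severity from list position — an assumption here, but consistent with the conventional ordering TRACE < DEBUG:

LOG_LEVELS = ["NOTSET", "TRACE", "DEBUG", "INFO", "WARN", "ERROR"]

# Assuming levels are compared by index, a DEBUG threshold now correctly
# filters out the more verbose TRACE messages:
threshold = LOG_LEVELS.index("DEBUG")
print(LOG_LEVELS.index("TRACE") >= threshold)  # False -- TRACE is suppressed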
6 changes: 3 additions & 3 deletions runpod/serverless/modules/rp_ping.py
@@ -12,11 +12,11 @@

 from runpod.http_client import SyncClientSession
 from runpod.serverless.modules.rp_logger import RunPodLogger
-from runpod.serverless.modules.worker_state import WORKER_ID, Jobs
+from runpod.serverless.modules.worker_state import WORKER_ID, JobsQueue
 from runpod.version import __version__ as runpod_version

 log = RunPodLogger()
-jobs = Jobs()  # Contains the list of jobs that are currently running.
+jobs = JobsQueue()  # Contains the list of jobs that are currently running.


 class Heartbeat:
@@ -96,7 +96,7 @@ def _send_ping(self):
         )

         log.debug(
-            f"Heartbeat Sent | URL: {self.PING_URL} | Status: {result.status_code}"
+            f"Heartbeat Sent | URL: {result.url} | Status: {result.status_code}"
         )

     except requests.RequestException as err:
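
Logging result.url instead of self.PING_URL records the final URL, including the query string that requests appends, rather than the bare template. A standalone illustration of that requests behavior (the endpoint and params here are stand-ins):

import requests

result = requests.get(
    "https://httpbin.org/get",  # stand-in for PING_URL
    params={"job_id": "abc123", "runpod_version": "1.7.0"},
    timeout=5,
)

# Response.url carries the fully resolved query string:
print(result.url)  # https://httpbin.org/get?job_id=abc123&runpod_version=1.7.0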
[Diffs for the remaining 8 changed files did not load.]
