Add a timeout to the workflow API #937

Open · wants to merge 6 commits into main
5 changes: 3 additions & 2 deletions lumigator/backend/backend/services/jobs.py
@@ -289,7 +289,7 @@ def retrieve_job_logs(self, job_id: UUID) -> JobLogsResponse:
except json.JSONDecodeError as e:
raise JobUpstreamError("ray", f"JSON decode error from {resp.text or ''}") from e

- async def wait_for_job_complete(self, job_id, max_wait_time_sec=None):
+ async def wait_for_job_complete(self, job_id, max_wait_time_sec):
"""Waits for a job to complete, or until a maximum wait time is reached.

:param job_id: The ID of the job to wait for.
@@ -306,7 +306,7 @@ async def wait_for_job_complete(self, job_id, max_wait_time_sec=None):
# Wait for the job to complete
elapsed_time = 0
while job_status not in self.TERMINAL_STATUS:
- if max_wait_time_sec and elapsed_time >= max_wait_time_sec:
+ if elapsed_time >= max_wait_time_sec:
loguru.logger.info(f"Job {job_id} did not complete within the maximum wait time.")
break
await asyncio.sleep(5)
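
To make the new contract concrete, here is a minimal sketch of the whole method as the two hunks above imply it. The 5-second poll interval, the TERMINAL_STATUS check, and the timeout break mirror the diff; the get_job_status helper and the return value are assumptions for illustration.

import asyncio

import loguru

async def wait_for_job_complete(self, job_id, max_wait_time_sec):
    # Poll the job status every 5 seconds until it reaches a terminal
    # state or the (now required) timeout elapses.
    job_status = self.get_job_status(job_id)  # assumed status-fetch helper
    elapsed_time = 0
    while job_status not in self.TERMINAL_STATUS:
        if elapsed_time >= max_wait_time_sec:
            loguru.logger.info(f"Job {job_id} did not complete within the maximum wait time.")
            break
        await asyncio.sleep(5)
        elapsed_time += 5
        job_status = self.get_job_status(job_id)
    return job_status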
@@ -478,6 +478,7 @@ def create_job(
# - annotation jobs do not run in workflows => they trigger dataset saving here at job level
# As JobType.ANNOTATION is not used uniformly throughout our code yet, we rely on the already
# existing `store_to_dataset` parameter to explicitly trigger this in the annotation case
+ # FIXME add timeout to job spec too (and override at workflow?)
if job_type == JobType.INFERENCE and request.job_config.store_to_dataset:
self.add_background_task(self._background_tasks, self.handle_inference_job, record.id, request)

34 changes: 32 additions & 2 deletions lumigator/backend/backend/services/workflows.py
@@ -1,7 +1,11 @@
+ import json
+ from http import HTTPStatus
from pathlib import Path
+ from urllib.parse import urljoin
from uuid import UUID

import loguru
+ import requests
from fastapi import BackgroundTasks
from lumigator_schemas.jobs import (
JobCreate,
@@ -20,6 +24,9 @@

from backend.repositories.jobs import JobRepository
from backend.services.datasets import DatasetService
+ from backend.services.exceptions.job_exceptions import (
+ JobUpstreamError,
+ )
from backend.services.exceptions.workflow_exceptions import (
WorkflowNotFoundError,
WorkflowValidationError,
@@ -53,6 +60,17 @@ def __init__(
# TODO: rely on https://github.com/ray-project/ray/blob/7c2a200ef84f17418666dad43017a82f782596a3/python/ray/dashboard/modules/job/common.py#L53
self.TERMINAL_STATUS = [JobStatus.FAILED.value, JobStatus.SUCCEEDED.value]

+ # Maybe move to the job service?
+ def _stop_job(self, job_id: UUID):
+ resp = requests.post(urljoin(settings.RAY_JOBS_URL, f"{job_id}/stop"), timeout=5)  # 5-second request timeout
+ if resp.status_code == HTTPStatus.NOT_FOUND:
+ raise JobUpstreamError("ray", "job_id not found when stopping job") from None
+ elif resp.status_code != HTTPStatus.OK:
+ raise JobUpstreamError(
+ "ray",
+ f"Unexpected status code stopping job: {resp.status_code}, error: {resp.text or ''}",
+ ) from None
Contributor comment on lines +63 to +72:

Hmm. Yeah this does seem like something that should be in the job service.
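
If the method does move, a sketch of the job-service version might look like the following. The JobService placement and the stop_job name are hypothetical, and settings and JobUpstreamError would still need to be imported in that module; the endpoint call itself mirrors the diff above.

from http import HTTPStatus
from urllib.parse import urljoin
from uuid import UUID

import requests

class JobService:
    # Hypothetical new home for the stop logic, per the review comment.
    def stop_job(self, job_id: UUID) -> None:
        # Ask Ray to stop a running job via its HTTP jobs API.
        resp = requests.post(urljoin(settings.RAY_JOBS_URL, f"{job_id}/stop"), timeout=5)
        if resp.status_code == HTTPStatus.NOT_FOUND:
            raise JobUpstreamError("ray", "job_id not found when stopping job") from None
        if resp.status_code != HTTPStatus.OK:
            raise JobUpstreamError(
                "ray",
                f"Unexpected status code stopping job: {resp.status_code}, error: {resp.text or ''}",
            ) from None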


async def _run_inference_eval_pipeline(
self,
workflow: WorkflowResponse,
@@ -87,9 +105,15 @@ async def _run_inference_eval_pipeline(
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)

# wait for the inference job to complete
- status = await self._job_service.wait_for_job_complete(inference_job.id, max_wait_time_sec=60 * 10)
+ status = await self._job_service.wait_for_job_complete(
+ inference_job.id, max_wait_time_sec=request.job_timeout_sec
+ )
if status != JobStatus.SUCCEEDED:
loguru.logger.error(f"Inference job {inference_job.id} failed")
+ try:
+ self._stop_job(inference_job.id)
+ except JobUpstreamError:
+ loguru.logger.error(f"Failed to stop inference job {inference_job.id}, continuing")
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.FAILED)
raise Exception(f"Inference job {inference_job.id} failed")

@@ -126,10 +150,16 @@
)

# wait for the evaluation job to complete
- status = await self._job_service.wait_for_job_complete(evaluation_job.id, max_wait_time_sec=60 * 10)
+ status = await self._job_service.wait_for_job_complete(
+ evaluation_job.id, max_wait_time_sec=request.job_timeout_sec
+ )
self._job_service._validate_results(evaluation_job.id, self._dataset_service.s3_filesystem)
if status != JobStatus.SUCCEEDED:
loguru.logger.error(f"Evaluation job {evaluation_job.id} failed")
+ try:
+ self._stop_job(evaluation_job.id)
+ except JobUpstreamError:
+ loguru.logger.error(f"Failed to stop eval job {evaluation_job.id}, continuing")
self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.FAILED)
try:
loguru.logger.info("Handling evaluation result")
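Both branches of the pipeline above follow the same wait-then-stop pattern. A condensed sketch of the shared logic, where the helper name and the factoring are illustrative rather than part of the PR:

async def _wait_or_stop(self, job_id, timeout_sec, label):
    # Wait for the job; if it did not succeed (failure or timeout),
    # make a best-effort attempt to stop it on the Ray side.
    status = await self._job_service.wait_for_job_complete(job_id, max_wait_time_sec=timeout_sec)
    if status != JobStatus.SUCCEEDED:
        loguru.logger.error(f"{label} job {job_id} failed")
        try:
            self._stop_job(job_id)
        except JobUpstreamError:
            loguru.logger.error(f"Failed to stop {label} job {job_id}, continuing")
    return status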
@@ -256,6 +256,7 @@ def run_workflow(local_client: TestClient, dataset_id, experiment_id, workflow_n
"dataset": str(dataset_id),
"experiment_id": experiment_id,
"max_samples": 1,
"job_timeout_sec": 1,
},
).json()
)
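
For reference, a client request exercising the new field might look like this. The /workflows route is an assumption (the real path sits outside the visible diff), and the payload is trimmed to the fields shown above, so a real request would carry the remaining required fields too.

response = local_client.post(
    "/workflows",  # assumed route
    json={
        "dataset": str(dataset_id),
        "experiment_id": experiment_id,
        "max_samples": 1,
        "job_timeout_sec": 1,  # new field: give each job at most one second
    },
)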
3 changes: 2 additions & 1 deletion lumigator/schemas/lumigator_schemas/workflows.py
@@ -1,7 +1,7 @@
import datetime
from uuid import UUID

- from pydantic import BaseModel
+ from pydantic import BaseModel, NonNegativeInt

from lumigator_schemas.jobs import (
JobResults,
@@ -28,6 +28,7 @@ class WorkflowCreateRequest(BaseModel):
system_prompt: str | None = None
inference_output_field: str = "predictions"
config_template: str | None = None
+ job_timeout_sec: NonNegativeInt = 60 * 10


class WorkflowResponse(BaseModel, from_attributes=True):
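
NonNegativeInt gives the new field basic validation for free: an omitted value falls back to the 10-minute default, zero is accepted, and negatives are rejected. A quick illustration, with the model trimmed to the relevant field:

from pydantic import BaseModel, NonNegativeInt, ValidationError

class WorkflowCreateRequest(BaseModel):
    # Trimmed: only the field added by this PR.
    job_timeout_sec: NonNegativeInt = 60 * 10

print(WorkflowCreateRequest().job_timeout_sec)  # 600, the default
print(WorkflowCreateRequest(job_timeout_sec=0).job_timeout_sec)  # 0 is allowed
try:
    WorkflowCreateRequest(job_timeout_sec=-1)
except ValidationError:
    print("negative timeouts are rejected")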