diff --git a/lumigator/backend/backend/services/jobs.py b/lumigator/backend/backend/services/jobs.py
index 5963b77e..51866d08 100644
--- a/lumigator/backend/backend/services/jobs.py
+++ b/lumigator/backend/backend/services/jobs.py
@@ -289,7 +289,7 @@ def retrieve_job_logs(self, job_id: UUID) -> JobLogsResponse:
         except json.JSONDecodeError as e:
             raise JobUpstreamError("ray", f"JSON decode error from {resp.text or ''}") from e

-    async def wait_for_job_complete(self, job_id, max_wait_time_sec=None):
+    async def wait_for_job_complete(self, job_id, max_wait_time_sec):
         """Waits for a job to complete, or until a maximum wait time is reached.

         :param job_id: The ID of the job to wait for.
@@ -306,7 +306,7 @@ async def wait_for_job_complete(self, job_id, max_wait_time_sec=None):
         # Wait for the job to complete
         elapsed_time = 0
         while job_status not in self.TERMINAL_STATUS:
-            if max_wait_time_sec and elapsed_time >= max_wait_time_sec:
+            if elapsed_time >= max_wait_time_sec:
                 loguru.logger.info(f"Job {job_id} did not complete within the maximum wait time.")
                 break
             await asyncio.sleep(5)
@@ -478,6 +478,7 @@ def create_job(
         # - annotation jobs do not run in workflows => they trigger dataset saving here at job level
         # As JobType.ANNOTATION is not used uniformly throughout our code yet, we rely on the already
         # existing `store_to_dataset` parameter to explicitly trigger this in the annotation case
+        # FIXME add timeout to job spec too (and override at workflow?)
         if job_type == JobType.INFERENCE and request.job_config.store_to_dataset:
             self.add_background_task(self._background_tasks, self.handle_inference_job, record.id, request)

diff --git a/lumigator/backend/backend/services/workflows.py b/lumigator/backend/backend/services/workflows.py
index 5a1b94ff..90570ecc 100644
--- a/lumigator/backend/backend/services/workflows.py
+++ b/lumigator/backend/backend/services/workflows.py
@@ -1,7 +1,11 @@
 import json
+from http import HTTPStatus
 from pathlib import Path
+from urllib.parse import urljoin
+from uuid import UUID

 import loguru
+import requests
 from fastapi import BackgroundTasks
 from lumigator_schemas.jobs import (
     JobCreate,
@@ -20,6 +24,9 @@
 from backend.repositories.jobs import JobRepository
 from backend.services.datasets import DatasetService
+from backend.services.exceptions.job_exceptions import (
+    JobUpstreamError,
+)
 from backend.services.exceptions.workflow_exceptions import (
     WorkflowNotFoundError,
     WorkflowValidationError,
 )
@@ -53,6 +60,17 @@ def __init__(
         # TODO: rely on https://github.com/ray-project/ray/blob/7c2a200ef84f17418666dad43017a82f782596a3/python/ray/dashboard/modules/job/common.py#L53
         self.TERMINAL_STATUS = [JobStatus.FAILED.value, JobStatus.SUCCEEDED.value]

+    # Maybe move to the job service?
+    def _stop_job(self, job_id: UUID):
+        resp = requests.post(urljoin(settings.RAY_JOBS_URL, f"{job_id}/stop"), timeout=5)  # 5-second request timeout
+        if resp.status_code == HTTPStatus.NOT_FOUND:
+            raise JobUpstreamError("ray", "job_id not found when stopping job") from None
+        elif resp.status_code != HTTPStatus.OK:
+            raise JobUpstreamError(
+                "ray",
+                f"Unexpected status code stopping job: {resp.status_code}, error: {resp.text or ''}",
+            ) from None
+
     async def _run_inference_eval_pipeline(
         self,
         workflow: WorkflowResponse,
@@ -87,9 +105,15 @@ async def _run_inference_eval_pipeline(
         self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.RUNNING)

         # wait for the inference job to complete
-        status = await self._job_service.wait_for_job_complete(inference_job.id, max_wait_time_sec=60 * 10)
+        status = await self._job_service.wait_for_job_complete(
+            inference_job.id, max_wait_time_sec=request.job_timeout_sec
+        )
         if status != JobStatus.SUCCEEDED:
             loguru.logger.error(f"Inference job {inference_job.id} failed")
+            try:
+                self._stop_job(inference_job.id)
+            except JobUpstreamError:
+                loguru.logger.error(f"Failed to stop inference job {inference_job.id}, continuing")
             self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.FAILED)
             raise Exception(f"Inference job {inference_job.id} failed")

@@ -126,10 +150,16 @@ async def _run_inference_eval_pipeline(
         )

         # wait for the evaluation job to complete
-        status = await self._job_service.wait_for_job_complete(evaluation_job.id, max_wait_time_sec=60 * 10)
+        status = await self._job_service.wait_for_job_complete(
+            evaluation_job.id, max_wait_time_sec=request.job_timeout_sec
+        )
         self._job_service._validate_results(evaluation_job.id, self._dataset_service.s3_filesystem)
         if status != JobStatus.SUCCEEDED:
             loguru.logger.error(f"Evaluation job {evaluation_job.id} failed")
+            try:
+                self._stop_job(evaluation_job.id)
+            except JobUpstreamError:
+                loguru.logger.error(f"Failed to stop evaluation job {evaluation_job.id}, continuing")
             self._tracking_client.update_workflow_status(workflow.id, WorkflowStatus.FAILED)
         try:
             loguru.logger.info("Handling evaluation result")
diff --git a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
index 0c9b58fb..15d0eafd 100644
--- a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
+++ b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py
@@ -256,6 +256,7 @@ def run_workflow(local_client: TestClient, dataset_id, experiment_id, workflow_n
                 "dataset": str(dataset_id),
                 "experiment_id": experiment_id,
                 "max_samples": 1,
+                "job_timeout_sec": 1,
             },
         ).json()
     )
diff --git a/lumigator/schemas/lumigator_schemas/workflows.py b/lumigator/schemas/lumigator_schemas/workflows.py
index 91608051..c0e071c6 100644
--- a/lumigator/schemas/lumigator_schemas/workflows.py
+++ b/lumigator/schemas/lumigator_schemas/workflows.py
@@ -1,7 +1,7 @@
 import datetime
 from uuid import UUID

-from pydantic import BaseModel
+from pydantic import BaseModel, NonNegativeInt

 from lumigator_schemas.jobs import (
     JobResults,
@@ -28,6 +28,7 @@ class WorkflowCreateRequest(BaseModel):
     system_prompt: str | None = None
     inference_output_field: str = "predictions"
     config_template: str | None = None
+    job_timeout_sec: NonNegativeInt = 60 * 10


 class WorkflowResponse(BaseModel, from_attributes=True):
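The new `job_timeout_sec` field bounds how long `wait_for_job_complete` polls before the workflow gives up and stops the job. A minimal sketch of a client request exercising it, assuming a local backend at `http://localhost:8000` and a `/api/v1/workflows` route; every payload key except `dataset`, `experiment_id`, `max_samples`, and `job_timeout_sec` (which appear in the schema or the test above) is illustrative:

```python
import requests

BASE_URL = "http://localhost:8000/api/v1"  # assumed local deployment URL

payload = {
    "name": "timeout-demo",  # hypothetical workflow name
    "dataset": "3fa85f64-5717-4562-b3fc-2c963f66afa6",  # hypothetical dataset UUID
    "experiment_id": "exp-123",  # hypothetical experiment id
    "max_samples": 1,
    # New in this patch: per-job wall-clock budget in seconds.
    # NonNegativeInt; defaults to 600 (60 * 10) when omitted.
    "job_timeout_sec": 120,
}

resp = requests.post(f"{BASE_URL}/workflows", json=payload, timeout=10)
resp.raise_for_status()
print(resp.json())
```

Note the behavioral change in `wait_for_job_complete`: the old `if max_wait_time_sec and ...` guard treated both `None` and `0` as "wait forever", whereas the parameter is now required and `0` breaks out of the polling loop on the first check, which is why the integration test can safely pass a tiny `"job_timeout_sec": 1`.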
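`_stop_job` wraps Ray's Jobs REST API, which exposes `POST /api/jobs/{submission_id}/stop`. A standalone sketch of the underlying call, assuming the Ray dashboard's default address (the real base URL comes from `settings.RAY_JOBS_URL`, whose exact value is an assumption here):

```python
import requests
from urllib.parse import urljoin

RAY_JOBS_URL = "http://127.0.0.1:8265/api/jobs/"  # assumed; mirrors settings.RAY_JOBS_URL
job_id = "raysubmit_tUAuCKsn4kq2bXvE"  # hypothetical Ray submission id

# Ray answers 200 with a JSON body ({"stopped": true/false}) for known jobs
# and 404 for unknown submission ids, which is what _stop_job maps to
# JobUpstreamError above.
resp = requests.post(urljoin(RAY_JOBS_URL, f"{job_id}/stop"), timeout=5)
resp.raise_for_status()
print(resp.json())
```

Because the base URL ends with a trailing slash, `urljoin` appends `{job_id}/stop` rather than replacing the last path segment, matching how the patch builds the stop URL.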