First attempt at a parametrized JobCreate
javiermtorres committed Jan 24, 2025
1 parent 79a92ce commit d8ae072
Showing 6 changed files with 102 additions and 86 deletions.
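
For orientation: the point of this commit is to replace the per-job request models (JobEvalCreate, JobEvalLiteCreate, JobInferenceCreate, JobAnnotateCreate) with a single JobCreate whose job-specific fields move into a nested job_config. The schema module itself is not among the diffs that rendered below, so here is a minimal sketch of what it might look like, inferred from the imports, call sites, and tests in this commit; the defaults, enum wire values, and discriminator wiring are assumptions, not the committed code.

```python
# Hypothetical reconstruction of the new lumigator_schemas.jobs shapes.
# Class and field names are taken from the diff; everything else is a guess.
from enum import Enum
from typing import Annotated, Literal
from uuid import UUID

from pydantic import BaseModel, Field


class JobType(str, Enum):
    INFERENCE = "INFERENCE"  # string values are assumptions
    EVALUATION = "EVALUATION"
    EVALUATION_LITE = "EVALUATION_LITE"


class JobInferenceConfig(BaseModel):
    job_type: Literal[JobType.INFERENCE] = JobType.INFERENCE
    model: str
    model_url: str | None = None
    system_prompt: str | None = None
    output_field: str | None = None
    # the remaining fields read by _get_job_params (task, accelerator,
    # revision, use_fast, trust_remote_code, torch_dtype, max_tokens,
    # frequency_penalty, temperature, top_p) would live here as well


class JobEvalConfig(BaseModel):
    job_type: Literal[JobType.EVALUATION] = JobType.EVALUATION
    model: str
    model_url: str | None = None
    system_prompt: str | None = None
    skip_inference: bool = False


class JobEvalLiteConfig(BaseModel):
    job_type: Literal[JobType.EVALUATION_LITE] = JobType.EVALUATION_LITE
    model: str


# JobAnnotateConfig (used by the annotate route) would follow the same pattern.


class JobCreate(BaseModel):
    """One request model for every job route; job_config decides the job type."""

    name: str
    description: str = ""
    dataset: UUID
    max_samples: int = -1  # -1 reads as "use the whole dataset" in the tests below
    job_config: Annotated[
        JobInferenceConfig | JobEvalConfig | JobEvalLiteConfig,
        Field(discriminator="job_type"),
    ]
```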
@@ -9,7 +9,7 @@
 )
 from lumigator_schemas.extras import ListingResponse
 from lumigator_schemas.jobs import (
-    JobEvalCreate,
+    JobCreate,
 )

 from backend.api.deps import JobServiceDep
@@ -21,7 +21,7 @@
 def create_experiment(
     service: JobServiceDep, request: ExperimentCreate, background_tasks: BackgroundTasks
 ) -> ExperimentResponse:
-    return service.create_job(JobEvalCreate.model_validate(request.model_dump()), background_tasks)
+    return service.create_job(JobCreate.model_validate(request.model_dump()), background_tasks)


 @router.get("/{experiment_id}")
24 changes: 13 additions & 11 deletions lumigator/python/mzai/backend/backend/api/routes/jobs.py
@@ -10,10 +10,8 @@
 from lumigator_schemas.extras import ListingResponse
 from lumigator_schemas.jobs import (
     Job,
-    JobAnnotateCreate,
-    JobEvalCreate,
-    JobEvalLiteCreate,
-    JobInferenceCreate,
+    JobAnnotateConfig,
+    JobCreate,
     JobLogsResponse,
     JobResponse,
     JobResultDownloadResponse,
@@ -33,7 +31,7 @@
 @router.post("/inference/", status_code=status.HTTP_201_CREATED)
 def create_inference_job(
     service: JobServiceDep,
-    job_create_request: JobInferenceCreate,
+    job_create_request: JobCreate,
     request: Request,
     response: Response,
     background_tasks: BackgroundTasks,
@@ -49,7 +47,7 @@ def create_inference_job(
 @router.post("/annotate/", status_code=status.HTTP_201_CREATED)
 def create_annotation_job(
     service: JobServiceDep,
-    job_create_request: JobAnnotateCreate,
+    job_create_request: JobCreate,
     request: Request,
     response: Response,
     background_tasks: BackgroundTasks,
@@ -58,12 +56,16 @@ def create_annotation_job(
     reference model should be used to generate annotations.
     See more: https://blog.mozilla.ai/lets-build-an-app-for-evaluating-llms/
     """
-    inference_job_create_request = JobInferenceCreate(
-        **job_create_request.dict(),
+    inference_job_create_config = JobAnnotateConfig(
+        **job_create_request.job_config.dict(),
         model="hf://facebook/bart-large-cnn",
         output_field="ground_truth",
     )
-    inference_job_create_request.store_to_dataset = True
+    inference_job_create_config.store_to_dataset = True
+    inference_job_create_request_dict = job_create_request.model_dump()
+    inference_job_create_request_dict["job_config"] = inference_job_create_config
+
+    inference_job_create_request = JobCreate(**inference_job_create_request_dict)
     job_response = service.create_job(inference_job_create_request, background_tasks)

     url = request.url_for(get_job.__name__, job_id=job_response.id)
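
The dump-and-rebuild above works because model_dump() hands back a plain dict whose "job_config" key can be overwritten before re-validating. If these schemas are Pydantic v2 models, a one-step alternative exists; this is an editor's sketch, not part of the commit:

```python
# Sketch (assumes Pydantic v2): swap the nested config in one call.
# Caveat: model_copy(update=...) does not re-validate the updated field.
inference_job_create_request = job_create_request.model_copy(
    update={"job_config": inference_job_create_config}
)
```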
@@ -75,7 +77,7 @@ def create_annotation_job(
 @router.post("/evaluate/", status_code=status.HTTP_201_CREATED)
 def create_evaluation_job(
     service: JobServiceDep,
-    job_create_request: JobEvalCreate,
+    job_create_request: JobCreate,
     request: Request,
     response: Response,
     background_tasks: BackgroundTasks,
@@ -93,7 +95,7 @@
 @router.post("/eval_lite/", status_code=status.HTTP_201_CREATED)
 def create_evaluation_lite_job(
     service: JobServiceDep,
-    job_create_request: JobEvalLiteCreate,
+    job_create_request: JobCreate,
     request: Request,
     response: Response,
     background_tasks: BackgroundTasks,
7 changes: 3 additions & 4 deletions lumigator/python/mzai/backend/backend/services/experiments.py
@@ -7,8 +7,7 @@
 from lumigator_schemas.experiments import ExperimentCreate, ExperimentResponse
 from lumigator_schemas.extras import ListingResponse
 from lumigator_schemas.jobs import (
-    JobEvalLiteCreate,
-    JobInferenceCreate,
+    JobCreate,
     JobStatus,
 )

@@ -98,7 +97,7 @@ def _run_eval(

         # submit the job
         self._job_service.create_job(
-            JobEvalLiteCreate.model_validate(job_eval_dict),
+            JobCreate.model_validate(job_eval_dict),
             background_tasks,
             experiment_id=experiment_id,
         )
@@ -132,7 +131,7 @@ def create_experiment(

         # submit inference job first
         job_response = self._job_service.create_job(
-            JobInferenceCreate.model_validate(job_inference_dict),
+            JobCreate.model_validate(job_inference_dict),
             background_tasks,
             experiment_id=experiment_record.id,
         )
82 changes: 40 additions & 42 deletions lumigator/python/mzai/backend/backend/services/jobs.py
@@ -17,16 +17,16 @@
 from lumigator_schemas.extras import ListingResponse
 from lumigator_schemas.jobs import (
     JobConfig,
-    JobEvalCreate,
-    JobEvalLiteCreate,
-    JobInferenceCreate,
+    JobCreate,
+    JobEvalConfig,
+    JobEvalLiteConfig,
+    JobInferenceConfig,
     JobResponse,
     JobResultDownloadResponse,
     JobResultResponse,
     JobStatus,
     JobType,
 )
-from pydantic import BaseModel
 from ray.job_submission import JobSubmissionClient
 from s3fs import S3FileSystem
@@ -37,6 +37,8 @@
 from backend.services.datasets import DatasetService
 from backend.settings import settings

+JobSpecificRestrictedConfig = type[JobEvalConfig | JobEvalLiteConfig | JobInferenceConfig]
+

 class JobService:
     # set storage path
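
Note the type[...] in the new alias: it describes the config classes themselves rather than instances, so a caller can pass a class around and validate raw payloads against it. A hypothetical use, assuming the configs expose Pydantic's model_validate and that "INFERENCE" is the enum's wire value:

```python
# Hypothetical helper, not in this commit.
def parse_job_config(config_cls: JobSpecificRestrictedConfig, payload: dict):
    return config_cls.model_validate(payload)


cfg = parse_job_config(
    JobInferenceConfig,
    {"job_type": "INFERENCE", "model": "hf://facebook/bart-large-cnn"},
)
```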
@@ -129,7 +131,7 @@ def _results_to_binary_file(self, results: str, fields: list[str]) -> BytesIO:

         return bin_data

-    def _add_dataset_to_db(self, job_id: UUID, request: JobInferenceCreate, s3: S3FileSystem):
+    def _add_dataset_to_db(self, job_id: UUID, request: JobCreate, s3: S3FileSystem):
         loguru.logger.info("Adding a new dataset entry to the database...")

         # Get the dataset from the S3 bucket
@@ -167,9 +169,7 @@ def _add_dataset_to_db(self, job_id: UUID, request: JobInferenceCreate, s3: S3FileSystem):
             f"Dataset '{dataset_filename}' with ID '{dataset_record.id}' added to the database."
         )

-    def _validate_evaluation_results(
-        self, job_id: UUID, request: JobEvalLiteCreate, s3: S3FileSystem
-    ):
+    def _validate_evaluation_results(self, job_id: UUID, request: JobCreate, s3: S3FileSystem):
         """Handles the evaluation result for a given job.
         Args:
@@ -230,14 +230,14 @@ def _get_config_template(self, job_type: str, model_name: str) -> str:

         return config_template

-    def _set_model_type(self, request: BaseModel) -> str:
+    def _set_model_type(self, request: JobCreate) -> str:
         """Sets model URL based on protocol address"""
-        if request.model.startswith("oai://"):
+        if request.job_config.model.startswith("oai://"):
             model_url = settings.OAI_API_URL
-        elif request.model.startswith("mistral://"):
+        elif request.job_config.model.startswith("mistral://"):
             model_url = settings.MISTRAL_API_URL
         else:
-            model_url = request.model_url
+            model_url = request.job_config.model_url

         return model_url
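
To make the resolution concrete, a hypothetical call mirroring the unit tests at the bottom of this commit (the model name is illustrative):

```python
# Illustration only: the scheme prefix of job_config.model picks the base URL.
request = JobCreate(
    name="url-resolution-demo",
    description="hypothetical",
    dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
    job_config=JobInferenceConfig(job_type=JobType.INFERENCE, model="oai://gpt-4o-mini"),
)
assert job_service._set_model_type(request) == settings.OAI_API_URL
# A "mistral://..." model resolves to settings.MISTRAL_API_URL instead, and any
# other scheme falls back to job_config.model_url.
```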

@@ -249,7 +249,7 @@ def _validate_config(self, job_type: str, config_template: str, config_params: dict):
         else:
             loguru.logger.info(f"Validation for job type {job_type} not yet supported.")

-    def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:
+    def _get_job_params(self, job_type: str, record, request: JobCreate) -> dict:
         # get dataset S3 path from UUID
         dataset_s3_path = self._dataset_service.get_dataset_s3_path(request.dataset)
@@ -263,19 +263,19 @@ def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:
             job_params = {
                 "job_id": record.id,
                 "job_name": request.name,
-                "model_uri": request.model,
+                "model_uri": request.job_config.model,
                 "dataset_path": dataset_s3_path,
                 "max_samples": request.max_samples,
                 "storage_path": self.storage_path,
                 "model_url": self._set_model_type(request),
-                "system_prompt": request.system_prompt,
-                "skip_inference": request.skip_inference,
+                "system_prompt": request.job_config.system_prompt,
+                "skip_inference": request.job_config.skip_inference,
             }
         elif job_type == JobType.EVALUATION_LITE:
             job_params = {
                 "job_id": record.id,
                 "job_name": request.name,
-                "model_uri": request.model,
+                "model_uri": request.job_config.model,
                 "dataset_path": dataset_s3_path,
                 "max_samples": request.max_samples,
                 "storage_path": self.storage_path,
@@ -284,49 +284,44 @@ def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:
             job_params = {
                 "job_id": record.id,
                 "job_name": request.name,
-                "model_uri": request.model,
+                "model_uri": request.job_config.model,
                 "dataset_path": dataset_s3_path,
-                "task": request.task,
-                "accelerator": request.accelerator,
-                "revision": request.revision,
-                "use_fast": request.use_fast,
-                "trust_remote_code": request.trust_remote_code,
-                "torch_dtype": request.torch_dtype,
+                "task": request.job_config.task,
+                "accelerator": request.job_config.accelerator,
+                "revision": request.job_config.revision,
+                "use_fast": request.job_config.use_fast,
+                "trust_remote_code": request.job_config.trust_remote_code,
+                "torch_dtype": request.job_config.torch_dtype,
                 "max_samples": request.max_samples,
                 "storage_path": self.storage_path,
                 "model_url": self._set_model_type(request),
-                "system_prompt": request.system_prompt,
-                "output_field": request.output_field,
-                "max_tokens": request.max_tokens,
-                "frequency_penalty": request.frequency_penalty,
-                "temperature": request.temperature,
-                "top_p": request.top_p,
+                "system_prompt": request.job_config.system_prompt,
+                "output_field": request.job_config.output_field,
+                "max_tokens": request.job_config.max_tokens,
+                "frequency_penalty": request.job_config.frequency_penalty,
+                "temperature": request.job_config.temperature,
+                "top_p": request.job_config.top_p,
             }

         return job_params

     def create_job(
         self,
-        request: JobEvalCreate | JobEvalLiteCreate | JobInferenceCreate,
+        request: JobCreate,
         background_tasks: BackgroundTasks,
         experiment_id: UUID = None,
     ) -> JobResponse:
         """Creates a new evaluation workload to run on Ray and returns the response status."""
-        if isinstance(request, JobEvalCreate):
-            job_type = JobType.EVALUATION
-        elif isinstance(request, JobEvalLiteCreate):
-            job_type = JobType.EVALUATION_LITE
-        elif isinstance(request, JobInferenceCreate):
-            job_type = JobType.INFERENCE
-        else:
-            raise HTTPException(status.HTTP_501_NOT_IMPLEMENTED, "Job type not implemented.")
+        # Typing won't allow other job_type's
+        job_type = request.job_config.job_type

         # Create a db record for the job
         record = self.job_repo.create(
             name=request.name, description=request.description, experiment_id=experiment_id
         )

-        if isinstance(request, JobInferenceCreate) and not request.output_field:
-            request.output_field = "predictions"
+        # TODO defer to specific job
+        if job_type == JobType.INFERENCE and not request.job_config.output_field:
+            request.job_config.output_field = "predictions"

         # prepare configuration parameters, which depend both on the user inputs
@@ -343,12 +338,15 @@

         loguru.logger.info(f"template...{config_template, job_type, request.job_config.model}")

+        # The idea would be to remove this step...
         self._validate_config(job_type, config_template, config_params)

         # eval_config_args is used to map input configuration parameters with
         # command parameters provided via command line to the ray job.
         # To do this, we use a dict where keys are parameter names as they'd
         # appear on the command line and the values are the respective params.
+        # ...and use directly Job*Config(request.job_config.model_dump_json())
         job_config_args = {
             "--config": config_template.format(**config_params),
         }
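
The two comments added in this hunk sketch the follow-up: drop the string template and serialize the typed config directly. Spelled out, that end state might look like this; an assumption about intent, not code in this commit:

```python
# Possible follow-up hinted at above: the typed job_config becomes the Ray
# job's --config argument, with no template/format step in between.
job_config_args = {
    "--config": request.job_config.model_dump_json(),
}
```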
@@ -400,7 +398,7 @@
                 record.id,
                 self._add_dataset_to_db,
                 record.id,
-                JobInferenceCreate.model_validate(request),
+                JobCreate.model_validate(request),
                 self._dataset_service.s3_filesystem,
             )
         elif job_type == JobType.EVALUATION_LITE:
@@ -409,7 +407,7 @@
                 record.id,
                 self._validate_evaluation_results,
                 record.id,
-                JobEvalLiteCreate.model_validate(request),
+                JobCreate.model_validate(request),
                 self._dataset_service.s3_filesystem,
             )
         # FIXME The ray status is now _not enough_ to set the job status,
@@ -1,29 +1,35 @@
 import pytest
 from lumigator_schemas.jobs import (
-    JobInferenceCreate,
+    JobCreate,
+    JobInferenceConfig,
+    JobType,
 )

 from backend.services.jobs import JobService
 from backend.settings import settings


 def test_set_null_inference_job_params(job_record, job_service):
-    request = JobInferenceCreate(
+    request = JobCreate(
         name="test_run_hugging_face",
         description="Test run for Huggingface model",
-        model="hf://facebook/bart-large-cnn",
+        job_config=JobInferenceConfig(
+            job_type=JobType.INFERENCE, model="hf://facebook/bart-large-cnn"
+        ),
         dataset="cced289c-f869-4af1-9195-1d58e32d1cc1",
     )
     params = job_service._get_job_params("INFERENCE", job_record, request)
     assert params["max_samples"] == -1


 def test_set_explicit_inference_job_params(job_record, job_service):
-    request = JobInferenceCreate(
+    request = JobCreate(
         name="test_run_hugging_face",
         description="Test run for Huggingface model",
         max_samples=10,
-        model="hf://facebook/bart-large-cnn",
+        job_config=JobInferenceConfig(
+            job_type=JobType.INFERENCE, model="hf://facebook/bart-large-cnn"
+        ),
         dataset="cced289c-f869-4af1-9195-1d58e32d1cc1",
     )
     params = job_service._get_job_params("INFERENCE", job_record, request)
@@ -54,11 +60,14 @@ def test_set_explicit_inference_job_params(job_record, job_service):
     ],
 )
 def test_set_model(job_service, model, input_model_url, returned_model_url):
-    request = JobInferenceCreate(
+    request = JobCreate(
         name="test_run",
         description="Test run to verify how model URL is set",
-        model=model,
-        model_url=input_model_url,
+        job_config=JobInferenceConfig(
+            job_type=JobType.INFERENCE,
+            model=model,
+            model_url=input_model_url,
+        ),
         dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
     )
     model_url = job_service._set_model_type(request)
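
Taken together, a client call against the refactored inference route might look like the following; the job_type wire value and the service base URL are assumptions:

```python
import requests

# Hypothetical request against the new parametrized endpoint.
payload = {
    "name": "test_run_hugging_face",
    "description": "Test run for Huggingface model",
    "dataset": "cced289c-f869-4af1-9195-1d58e32d1cc1",
    "max_samples": 10,
    "job_config": {
        "job_type": "INFERENCE",  # guess at the enum's wire value
        "model": "hf://facebook/bart-large-cnn",
    },
}
resp = requests.post("http://localhost:8000/jobs/inference/", json=payload)  # URL assumed
resp.raise_for_status()
print(resp.json()["id"])  # id of the newly created job
```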