Mlflow implementation of Tracking Interface #768

Merged · 58 commits · Feb 5, 2025

Commits
acfb6c6
A non-working example to act as talking point
njbrake Jan 23, 2025
3ead888
Explain the new experiments endpoint
njbrake Jan 23, 2025
5c944fd
fix the tags
njbrake Jan 23, 2025
7d5825c
Fix typo
njbrake Jan 23, 2025
01ad9e7
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 24, 2025
fedd4d1
Create schemas and deps to handle the concept of a run
njbrake Jan 24, 2025
9c54621
rename runs to workflows
njbrake Jan 24, 2025
0baafb2
Reorganizing and making exception for functionality that doesn't yet …
njbrake Jan 24, 2025
2ae8375
Merge the remaining endpoints into the route
njbrake Jan 24, 2025
02adc20
keep the "new" name so it's clear what I'm doing
njbrake Jan 24, 2025
3b812b8
dont use the workflows endpoints for the old tests yet
njbrake Jan 24, 2025
2c96096
fix the unit test
njbrake Jan 24, 2025
0dd476d
Mlflow interface created
njbrake Jan 27, 2025
ea60bec
messed up naming dup
njbrake Jan 27, 2025
55bb75a
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
201e3ae
Not yet working, ckpt
njbrake Jan 28, 2025
69dc098
Address PR comments
njbrake Jan 28, 2025
55a760f
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
81e8665
Merge branch 'main' into brake/route_rename_proposal
peteski22 Jan 28, 2025
6df8ae3
Updated workflow service based on changes to job and experiment servi…
peteski22 Jan 28, 2025
cf8ecda
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
093c368
reinstate the experiments exception mapping
njbrake Jan 28, 2025
a134ce3
merge fixes
njbrake Jan 28, 2025
aa299a6
Merge remote-tracking branch 'origin/main' into brake/route_rename_pr…
njbrake Jan 29, 2025
81a9840
fix linting issue
njbrake Jan 29, 2025
2cdff95
Fix my bad merge
njbrake Jan 29, 2025
00a4301
Fix unit tests after merge
njbrake Jan 29, 2025
016980f
Merge remote-tracking branch 'origin/main' into brake/route_rename_pr…
njbrake Jan 29, 2025
bbf2498
Docs cleanup
njbrake Jan 29, 2025
de12c3e
Merge branch 'brake/route_rename_proposal' into 741-tracking-interface
njbrake Jan 29, 2025
19347df
Mlflow only deployed in dev mode
njbrake Jan 29, 2025
76b3ea7
Merge remote-tracking branch 'origin/main' into 741-tracking-interface
njbrake Jan 29, 2025
a3a1e03
update the terminology to be "job" instead of the mlflow specific "ru…
njbrake Jan 29, 2025
6dfa06d
Merge remote-tracking branch 'origin/741-tracking-interface' into 527…
njbrake Jan 29, 2025
99a9825
Separate job handling logic from job management logic
njbrake Jan 29, 2025
4e03d3a
Show how you can dynamically add workflows
njbrake Jan 29, 2025
166375c
Clean up the routes!
njbrake Jan 29, 2025
a0be305
Merge remote-tracking branch 'origin/main' into 741-tracking-interface
njbrake Jan 30, 2025
fec2f2b
uv lock fix
njbrake Jan 30, 2025
ea738e2
Merge remote-tracking branch 'origin/741-tracking-interface' into 527…
njbrake Jan 30, 2025
a1fe055
Fix bad merge
njbrake Jan 30, 2025
38d73f9
Delete functionality
njbrake Jan 30, 2025
9bb13bd
Support log retrieval
njbrake Jan 30, 2025
eb980a2
log retrieval working
njbrake Jan 30, 2025
caf2e6b
mlflow is now a part of the deployment
njbrake Jan 30, 2025
ed9fcd8
Merge remote-tracking branch origin/741-tracking-interface into 527-m…
njbrake Jan 30, 2025
8363487
Clean up for PR
njbrake Jan 31, 2025
9ef25db
Don't expose workflows in the schema just yet
njbrake Jan 31, 2025
0f85ac3
Merge remote-tracking branch 'origin/main' into 527-mlflow-implementa…
njbrake Jan 31, 2025
f8a2168
Updates based on PR comments
njbrake Jan 31, 2025
98eb0f3
merge complete
njbrake Jan 31, 2025
5a5ddb6
Merge remote-tracking branch 'origin/main' into 527-mlflow-implementa…
njbrake Feb 4, 2025
6156529
Merge branch 'main' into 527-mlflow-implementation
njbrake Feb 5, 2025
30e8028
Merge remote-tracking branch 'origin/main' into 527-mlflow-implementa…
njbrake Feb 5, 2025
55350c1
Updates based on PR comments
njbrake Feb 5, 2025
84cc381
shift task addition back into service layer
njbrake Feb 5, 2025
1663f44
Merge branch 'main' into 527-mlflow-implementation
njbrake Feb 5, 2025
663dcee
missing f-string
njbrake Feb 5, 2025
26 changes: 0 additions & 26 deletions .devcontainer/docker-compose.override.yaml
@@ -11,12 +11,6 @@ services:
- database_volume:/mzai/backend/local.db
ports:
- "5678:5678"
environment:
- MLFLOW_TRACKING_URI
depends_on:
mlflow:
condition: "service_started"
required: false
develop:
watch:
- path: lumigator/backend/
@@ -36,23 +30,3 @@ services:
- .venv/
- path: lumigator/backend/pyproject.toml
action: rebuild

mlflow:
image: ghcr.io/mlflow/mlflow:v2.0.1
environment:
- MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
- BACKEND_STORE_URI=sqlite:///mlflow.db
- ARTIFACT_ROOT=s3://mlflow
ports:
- "8001:5000"
depends_on:
minio:
condition: service_healthy
command: mlflow server --backend-store-uri ${BACKEND_STORE_URI} --default-artifact-root ${ARTIFACT_ROOT} --host 0.0.0.0
extra_hosts:
- "localhost:host-gateway"
profiles:
- local
2 changes: 2 additions & 0 deletions Makefile
@@ -181,6 +181,7 @@ test-backend-unit:
RAY_HEAD_NODE_HOST=localhost \
RAY_DASHBOARD_PORT=8265 \
SQLALCHEMY_DATABASE_URL=sqlite:////tmp/local.db \
MLFLOW_TRACKING_URI=http://localhost:8001 \
PYTHONPATH=../jobs:$$PYTHONPATH \
uv run $(DEBUGPY_ARGS) -m pytest -s -o python_files="backend/tests/unit/*/test_*.py backend/tests/unit/test_*.py"

@@ -191,6 +192,7 @@ test-backend-integration:
RAY_HEAD_NODE_HOST=localhost \
RAY_DASHBOARD_PORT=8265 \
SQLALCHEMY_DATABASE_URL=sqlite:////tmp/local.db \
MLFLOW_TRACKING_URI=http://localhost:8001 \
RAY_WORKER_GPUS="0.0" \
RAY_WORKER_GPUS_FRACTION="0.0" \
INFERENCE_PIP_REQS=../jobs/inference/requirements_cpu.txt \
24 changes: 24 additions & 0 deletions docker-compose.yaml
@@ -117,6 +117,9 @@ services:
ray:
condition: "service_started"
required: false
mlflow:
condition: "service_started"
required: false
ports:
- 8000:8000
environment:
@@ -145,6 +148,7 @@
- RAY_WORKER_GPUS=$RAY_WORKER_GPUS
- RAY_WORKER_GPUS_FRACTION=$RAY_WORKER_GPUS_FRACTION
- LUMI_API_CORS_ALLOWED_ORIGINS
- MLFLOW_TRACKING_URI
# NOTE: to keep AWS_ENDPOINT_URL as http://localhost:9000 both on the host system
# and inside containers, we map localhost to the host gateway IP.
# This currently works properly, but might be the cause of networking
@@ -174,6 +178,26 @@ services:
ports:
- 80:80

mlflow:
image: ghcr.io/mlflow/mlflow:v2.0.1
environment:
- MLFLOW_TRACKING_URI=${MLFLOW_TRACKING_URI}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY}
- AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION}
- BACKEND_STORE_URI=sqlite:///mlflow.db
- ARTIFACT_ROOT=s3://mlflow
ports:
- "8001:5000"
depends_on:
minio:
condition: service_healthy
command: mlflow server --backend-store-uri ${BACKEND_STORE_URI} --default-artifact-root ${ARTIFACT_ROOT} --host 0.0.0.0
extra_hosts:
- "localhost:host-gateway"
profiles:
- local

volumes:
minio-data:
database_volume:
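The mlflow service is now part of the main compose stack: it keeps run metadata in a local sqlite file, writes artifacts to the minio-backed s3://mlflow bucket, and is reachable from the host on port 8001. A minimal connectivity sketch, assuming MLFLOW_TRACKING_URI=http://localhost:8001 as set in the Makefile test targets:

import os

import mlflow

# Point the client at the compose-managed server; the fallback URI is an
# assumption taken from the Makefile test targets.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8001"))

# Create a throwaway experiment and log one metric. If this succeeds, the
# server, its sqlite backend store, and the s3://mlflow artifact root all work.
mlflow.set_experiment("compose-smoke-test")
with mlflow.start_run() as run:
    mlflow.log_metric("ping", 1.0)
    print(f"Logged run {run.info.run_id} via {mlflow.get_tracking_uri()}")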
21 changes: 16 additions & 5 deletions lumigator/backend/backend/api/deps.py
@@ -10,14 +10,14 @@

from backend.db import session_manager
from backend.repositories.datasets import DatasetRepository
from backend.repositories.experiments import ExperimentRepository
from backend.repositories.jobs import JobRepository, JobResultRepository
from backend.services.completions import MistralCompletionService, OpenAICompletionService
from backend.services.datasets import DatasetService
from backend.services.experiments import ExperimentService
from backend.services.jobs import JobService
from backend.services.workflows import WorkflowService
from backend.settings import settings
from backend.tracking import tracking_client_manager


def get_db_session() -> Generator[Session, None, None]:
@@ -28,6 +28,14 @@ def get_db_session() -> Generator[Session, None, None]:
DBSessionDep = Annotated[Session, Depends(get_db_session)]


def get_tracking_client() -> Generator[Session, None, None]:
with tracking_client_manager.connect() as client:
yield client


TrackingClientDep = Annotated[Session, Depends(get_tracking_client)]


def get_s3_client() -> Generator[S3Client, None, None]:
return boto3.client("s3", endpoint_url=settings.S3_ENDPOINT_URL)

@@ -64,22 +72,25 @@ def get_job_service(session: DBSessionDep, dataset_service: DatasetServiceDep) -

def get_experiment_service(
session: DBSessionDep,
tracking_client: TrackingClientDep,
job_service: JobServiceDep,
dataset_service: DatasetServiceDep,
) -> ExperimentService:
job_repo = JobRepository(session)
experiment_repo = ExperimentRepository(session)
return ExperimentService(experiment_repo, job_repo, job_service, dataset_service)
return ExperimentService(job_repo, job_service, dataset_service, tracking_client)


ExperimentServiceDep = Annotated[ExperimentService, Depends(get_experiment_service)]


def get_workflow_service(
session: DBSessionDep, job_service: JobServiceDep, dataset_service: DatasetServiceDep
session: DBSessionDep,
tracking_client: TrackingClientDep,
job_service: JobServiceDep,
dataset_service: DatasetServiceDep,
) -> WorkflowService:
job_repo = JobRepository(session)
return WorkflowService(job_repo, job_service, dataset_service)
return WorkflowService(job_repo, job_service, dataset_service, tracking_client=tracking_client)


WorkflowServiceDep = Annotated[WorkflowService, Depends(get_workflow_service)]
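get_tracking_client mirrors get_db_session: routes receive a tracking client drawn from a context-managed connection. The tracking_client_manager itself lives in backend.tracking, which this diff doesn't show; the sketch below is a guess at its shape, assuming MLflow's MlflowClient underneath, and every name except tracking_client_manager and connect is hypothetical:

import contextlib
import os
from collections.abc import Generator

from mlflow.tracking import MlflowClient


class TrackingClientManager:
    """Hypothetical manager mirroring session_manager's connect() contract."""

    def __init__(self, tracking_uri: str):
        self._tracking_uri = tracking_uri

    @contextlib.contextmanager
    def connect(self) -> Generator[MlflowClient, None, None]:
        # MlflowClient is cheap to build; a real implementation might pool,
        # cache, or wrap errors before handing the client to the route layer.
        yield MlflowClient(tracking_uri=self._tracking_uri)


tracking_client_manager = TrackingClientManager(os.environ["MLFLOW_TRACKING_URI"])

Note that get_tracking_client above is annotated as yielding Session even though it yields a tracking client, so the type hints in the diff are looser than the runtime behavior.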
1 change: 0 additions & 1 deletion lumigator/backend/backend/api/router.py
@@ -12,7 +12,6 @@
api_router.include_router(experiments.router, prefix="/experiments", tags=[Tags.EXPERIMENTS])
api_router.include_router(completions.router, prefix="/completions", tags=[Tags.COMPLETIONS])
api_router.include_router(models.router, prefix="/models", tags=[Tags.MODELS])
# TODO: Workflows route is not yet ready so it is excluded from the OpenAPI schema
api_router.include_router(
workflows.router, prefix="/workflows", tags=[Tags.WORKFLOWS], include_in_schema=False
)
23 changes: 11 additions & 12 deletions lumigator/backend/backend/api/routes/experiments.py
@@ -5,9 +5,11 @@
from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentIdResponse,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
GetExperimentResponse,
)
from lumigator_schemas.extras import ListingResponse
from lumigator_schemas.jobs import (
@@ -80,12 +82,12 @@ def get_experiment_result_download(
# TODO: Eventually this route will become the / route,
# but right now it is a placeholder while we build up the Workflows routes
# It's not included in the OpenAPI schema for now so it's not visible in the docs
@router.post("/new", status_code=status.HTTP_201_CREATED, include_in_schema=False)
@router.post("/new", status_code=status.HTTP_201_CREATED, include_in_schema=True)
def create_experiment_id(
service: ExperimentServiceDep, request: ExperimentIdCreate
) -> ExperimentResponse:
) -> ExperimentIdResponse:
"""Create an experiment ID."""
return ExperimentResponse.model_validate(service.create_experiment(request).model_dump())
return ExperimentIdResponse.model_validate(service.create_experiment(request).model_dump())


# TODO: FIXME this should not need the /all suffix.
Expand All @@ -103,15 +105,12 @@ def list_experiments_new(


@router.get("/new/{experiment_id}", include_in_schema=False)
def get_experiment_new(service: ExperimentServiceDep, experiment_id: UUID) -> ExperimentResponse:
def get_experiment_new(service: ExperimentServiceDep, experiment_id: str) -> GetExperimentResponse:
"""Get an experiment by ID."""
return ExperimentResponse.model_validate(service.get_experiment(experiment_id).model_dump())
return GetExperimentResponse.model_validate(service.get_experiment(experiment_id).model_dump())


@router.get("/new/{experiment_id}/workflows", include_in_schema=False)
def get_workflows(service: ExperimentServiceDep, experiment_id: UUID) -> ListingResponse[UUID]:
"""TODO: this endpoint should handle passing in an experiment id and the returning a list
of all the workflows associated with that experiment. Until workflows are stored and associated
with experiments, this is not yet implemented.
"""
raise NotImplementedError
@router.delete("/new/{experiment_id}", include_in_schema=False)
def delete_experiment_new(service: ExperimentServiceDep, experiment_id: str) -> None:
"""Delete an experiment by ID."""
service.delete_experiment(experiment_id)
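The /new routes now cover the full lifecycle: create returns an ExperimentIdResponse, get returns a GetExperimentResponse keyed by a string ID (presumably because MLflow experiment IDs are strings rather than database UUIDs), and delete removes the experiment. A hedged round-trip sketch; the mount point and the ExperimentIdCreate payload fields are assumptions, since neither appears in this diff:

import httpx

BASE = "http://localhost:8000/api/v1/experiments"  # mount point assumed

with httpx.Client() as client:
    # Create: payload field names are guesses, not the verified schema.
    created = client.post(f"{BASE}/new", json={"name": "demo", "description": "smoke test"})
    experiment_id = created.json()["id"]

    # Get: the path parameter is now a plain string, not a UUID.
    print(client.get(f"{BASE}/new/{experiment_id}").json())

    # Delete: returns no body.
    client.delete(f"{BASE}/new/{experiment_id}")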
27 changes: 18 additions & 9 deletions lumigator/backend/backend/api/routes/jobs.py
@@ -52,7 +52,11 @@ def create_inference_job(
response: Response,
background_tasks: BackgroundTasks,
) -> JobResponse:
job_response = service.create_job(job_create_request, background_tasks)
job_response = service.create_job(job_create_request)

service.add_background_task(
background_tasks, service.handle_inference_job, job_response.id, job_create_request
)

url = request.url_for(get_job.__name__, job_id=job_response.id)
response.headers[HttpHeaders.LOCATION] = f"{url}"
@@ -78,7 +82,14 @@ def create_annotation_job(
output_field="ground_truth",
)
inference_job_create_request.store_to_dataset = True
job_response = service.create_job(inference_job_create_request, background_tasks)
job_response = service.create_job(inference_job_create_request)

service.add_background_task(
background_tasks,
service.handle_inference_job,
job_response.id,
inference_job_create_request,
)

url = request.url_for(get_job.__name__, job_id=job_response.id)
response.headers[HttpHeaders.LOCATION] = f"{url}"
@@ -92,9 +103,8 @@ def create_evaluation_job(
job_create_request: JobEvalCreate,
request: Request,
response: Response,
background_tasks: BackgroundTasks,
) -> JobResponse:
job_response = service.create_job(job_create_request, background_tasks)
job_response = service.create_job(job_create_request)

url = request.url_for(get_job.__name__, job_id=job_response.id)
response.headers[HttpHeaders.LOCATION] = f"{url}"
@@ -110,9 +120,8 @@ def create_evaluation_lite_job(
job_create_request: JobEvalLiteCreate,
request: Request,
response: Response,
background_tasks: BackgroundTasks,
) -> JobResponse:
job_response = service.create_job(job_create_request, background_tasks)
job_response = service.create_job(job_create_request)

url = request.url_for(get_job.__name__, job_id=job_response.id)
response.headers[HttpHeaders.LOCATION] = f"{url}"
@@ -184,7 +193,7 @@ def get_job(service: JobServiceDep, job_id: UUID) -> Job:

@router.get("/{job_id}/logs")
def get_job_logs(job_id: UUID) -> JobLogsResponse:
resp = requests.get(urljoin(settings.RAY_JOBS_URL, f"{job_id}/logs"))
resp = requests.get(urljoin(settings.RAY_JOBS_URL, f"{job_id}/logs"), timeout=5) # 5 seconds

if resp.status_code == HTTPStatus.NOT_FOUND:
loguru.logger.error(
@@ -244,7 +253,7 @@ def get_job_result_download(

def _get_all_ray_jobs() -> list[JobSubmissionResponse]:
"""Returns metadata that exists in the Ray cluster for all jobs."""
resp = requests.get(settings.RAY_JOBS_URL)
resp = requests.get(settings.RAY_JOBS_URL, timeout=5) # 5 seconds
if resp.status_code != HTTPStatus.OK:
loguru.logger.error(
f"Unexpected status code getting all jobs: {resp.status_code}, error: {resp.text or ''}"
@@ -267,7 +276,7 @@ def _get_all_ray_jobs() -> list[JobSubmissionResponse]:

def _get_ray_job(job_id: UUID) -> JobSubmissionResponse:
"""Returns metadata on the specified job if it exists in the Ray cluster."""
resp = requests.get(urljoin(settings.RAY_JOBS_URL, f"{job_id}"))
resp = requests.get(urljoin(settings.RAY_JOBS_URL, f"{job_id}"), timeout=5) # 5 seconds

if resp.status_code == HTTPStatus.NOT_FOUND:
loguru.logger.error(
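The route changes above decouple persistence from execution: create_job now only creates the job record, the inference routes schedule their handlers explicitly through service.add_background_task, and the evaluation routes stop taking BackgroundTasks at all. Only the call shape of add_background_task appears in this diff; a plausible sketch of the service-side helper, with the body an assumption:

from collections.abc import Callable

from fastapi import BackgroundTasks


class JobService:
    # Signature inferred from the call sites above; the body is a guess.
    def add_background_task(
        self, background_tasks: BackgroundTasks, task: Callable, *args
    ) -> None:
        """Keep scheduling in the service layer so routes stay thin and testable."""
        background_tasks.add_task(task, *args)

This matches the later commit "shift task addition back into service layer": routes decide when a handler runs, while the service owns how it is scheduled.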
74 changes: 28 additions & 46 deletions lumigator/backend/backend/api/routes/workflows.py
@@ -1,23 +1,33 @@
from uuid import UUID
from http import HTTPStatus

from fastapi import APIRouter, BackgroundTasks, status
from lumigator_schemas.extras import ListingResponse
from lumigator_schemas.jobs import JobResponse
from lumigator_schemas.jobs import JobLogsResponse
from lumigator_schemas.workflows import (
WorkflowCreate,
WorkflowCreateRequest,
WorkflowDetailsResponse,
WorkflowResponse,
WorkflowResultDownloadResponse,
)

from backend.api.deps import WorkflowServiceDep
from backend.services.exceptions.base_exceptions import ServiceError
from backend.services.exceptions.workflow_exceptions import (
WorkflowNotFoundError,
WorkflowValidationError,
)

router = APIRouter()


def workflow_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
return {
WorkflowNotFoundError: status.HTTP_404_NOT_FOUND,
WorkflowValidationError: status.HTTP_400_BAD_REQUEST,
}


@router.post("/", status_code=status.HTTP_201_CREATED)
async def create_workflow(
service: WorkflowServiceDep, request: WorkflowCreate, background_tasks: BackgroundTasks
service: WorkflowServiceDep, request: WorkflowCreateRequest, background_tasks: BackgroundTasks
) -> WorkflowResponse:
"""A workflow is a single execution for an experiment.
A workflow is a collection of 1 or more jobs.
@@ -28,52 +38,24 @@ async def create_workflow(


@router.get("/{workflow_id}")
def get_workflow(service: WorkflowServiceDep, workflow_id: UUID) -> WorkflowResponse:
def get_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse:
"""TODO: The workflow objects are currently not saved in the database so it can't be retrieved.
In order to get all the info about a workflow,
you need to get all the jobs for an experiment and make some decisions about how to use them.
This means you can't yet easily compile a list of all workflows for an experiment.
"""
raise NotImplementedError
return WorkflowDetailsResponse.model_validate(service.get_workflow(workflow_id).model_dump())


# TODO: currently experiment_id=workflow_id, but this will change
@router.get("/{experiment_id}/jobs", include_in_schema=False)
def get_workflow_jobs(
service: WorkflowServiceDep, experiment_id: UUID
) -> ListingResponse[JobResponse]:
"""Get all jobs for a workflow.
TODO: this will likely eventually be merged with the get_workflow endpoint, once implemented
"""
# TODO right now this command expects that the workflow_id is the same as the experiment_id
return ListingResponse[JobResponse].model_validate(
service.get_workflow_jobs(experiment_id).model_dump()
)


@router.get("/{workflow_id}/details")
def get_workflow_details(
service: WorkflowServiceDep,
workflow_id: UUID,
) -> WorkflowDetailsResponse:
"""TODO:Return the results metadata for a run if available in the DB.
This should retrieve the metadata for the job or jobs that were run in the workflow and compile
them into a single response that can be used to populate the UI.
Currently this looks like taking the average results for the
inference job (tok/s, gen length, etc) and the
average results for the evaluation job (ROUGE, BLEU, etc) and
returning them in a single response.
For detailed results you would want to use the get_workflow_details endpoint.
"""
raise NotImplementedError
# get the logs
@router.get("/{workflow_id}/logs")
def get_workflow_logs(service: WorkflowServiceDep, workflow_id: str) -> JobLogsResponse:
"""Get the logs for a workflow."""
return JobLogsResponse.model_validate(service.get_workflow_logs(workflow_id).model_dump())


@router.get("/{workflow_id}/details")
def get_experiment_result_download(
service: WorkflowServiceDep,
workflow_id: UUID,
) -> WorkflowResultDownloadResponse:
"""Return experiment results file URL for downloading."""
return WorkflowResultDownloadResponse.model_validate(
service.get_workflow_result_download(workflow_id).model_dump()
)
# delete a workflow
@router.delete("/{workflow_id}")
def delete_workflow(service: WorkflowServiceDep, workflow_id: str) -> WorkflowDetailsResponse:
"""Delete a workflow by ID."""
return WorkflowDetailsResponse.model_validate(service.delete_workflow(workflow_id).model_dump())
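With the exception mappings registered, the workflow router now offers create, get, logs, and delete. A usage sketch against a locally running backend; the mount point and the WorkflowCreateRequest payload fields are assumptions, since neither appears in this diff:

import httpx

BASE = "http://localhost:8000/api/v1/workflows"  # mount point assumed

with httpx.Client() as client:
    # WorkflowCreateRequest's fields are not shown in this diff; these are guesses.
    created = client.post(f"{BASE}/", json={"name": "eval-run", "experiment_id": "1"})
    workflow_id = created.json()["id"]

    print(client.get(f"{BASE}/{workflow_id}").json())       # WorkflowDetailsResponse
    print(client.get(f"{BASE}/{workflow_id}/logs").json())  # JobLogsResponse

    # Delete returns the removed workflow's details rather than an empty body.
    client.delete(f"{BASE}/{workflow_id}")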