Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update SDK to use new workflows API #783

Merged
merged 55 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
acfb6c6
A non-working example to act as talking point
njbrake Jan 23, 2025
3ead888
Explain the new experiments endpoint
njbrake Jan 23, 2025
5c944fd
fix the tags
njbrake Jan 23, 2025
7d5825c
Fix typo
njbrake Jan 23, 2025
01ad9e7
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 24, 2025
fedd4d1
Create schemas and deps to handle the concept of a run
njbrake Jan 24, 2025
9c54621
rename runs to workflows
njbrake Jan 24, 2025
0baafb2
Reorganizing and making exception for functionality that doesn't yet …
njbrake Jan 24, 2025
2ae8375
Merge the remaining endpoints into the route
njbrake Jan 24, 2025
02adc20
keep the "new" name so it's clear what I'm doing
njbrake Jan 24, 2025
3b812b8
dont use the workflows endpoints for the old tests yet
njbrake Jan 24, 2025
2c96096
fix the unit test
njbrake Jan 24, 2025
0dd476d
Mlflow interface created
njbrake Jan 27, 2025
ea60bec
messed up naming dup
njbrake Jan 27, 2025
55bb75a
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
201e3ae
Not yet working, ckpt
njbrake Jan 28, 2025
69dc098
Address PR comments
njbrake Jan 28, 2025
55a760f
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
81e8665
Merge branch 'main' into brake/route_rename_proposal
peteski22 Jan 28, 2025
6df8ae3
Updated workflow service based on changes to job and experiment servi…
peteski22 Jan 28, 2025
cf8ecda
Merge branch 'main' into brake/route_rename_proposal
njbrake Jan 28, 2025
093c368
reinstate the experiments exception mapping
njbrake Jan 28, 2025
a134ce3
merge fixes
njbrake Jan 28, 2025
aa299a6
Merge remote-tracking branch 'origin/main' into brake/route_rename_pr…
njbrake Jan 29, 2025
81a9840
fix linting issue
njbrake Jan 29, 2025
2cdff95
Fix my bad merge
njbrake Jan 29, 2025
00a4301
Fix unit tests after merge
njbrake Jan 29, 2025
016980f
Merge remote-tracking branch 'origin/main' into brake/route_rename_pr…
njbrake Jan 29, 2025
bbf2498
Docs cleanup
njbrake Jan 29, 2025
de12c3e
Merge branch 'brake/route_rename_proposal' into 741-tracking-interface
njbrake Jan 29, 2025
19347df
Mlflow only deployed in dev mode
njbrake Jan 29, 2025
76b3ea7
Merge remote-tracking branch 'origin/main' into 741-tracking-interface
njbrake Jan 29, 2025
a3a1e03
update the terminology to be "job" instead of the mlflow specific "ru…
njbrake Jan 29, 2025
6dfa06d
Merge remote-tracking branch 'origin/741-tracking-interface' into 527…
njbrake Jan 29, 2025
99a9825
Separate job handling logic from job management logic
njbrake Jan 29, 2025
4e03d3a
Show how you can dynamically add workflows
njbrake Jan 29, 2025
166375c
Clean up the routes!
njbrake Jan 29, 2025
a0be305
Merge remote-tracking branch 'origin/main' into 741-tracking-interface
njbrake Jan 30, 2025
fec2f2b
uv lock fix
njbrake Jan 30, 2025
ea738e2
Merge remote-tracking branch 'origin/741-tracking-interface' into 527…
njbrake Jan 30, 2025
a1fe055
Fix bad merge
njbrake Jan 30, 2025
38d73f9
Delete functionality
njbrake Jan 30, 2025
9bb13bd
Support log retrieval
njbrake Jan 30, 2025
eb980a2
log retrieval working
njbrake Jan 30, 2025
caf2e6b
mlflow is now a part of the deployment
njbrake Jan 30, 2025
ed9fcd8
Merge remote-tracking branch origin/741-tracking-interface into 527-m…
njbrake Jan 30, 2025
8363487
Clean up for PR
njbrake Jan 31, 2025
9ef25db
Don't expose workflows in the schema just yet
njbrake Jan 31, 2025
0f85ac3
Merge remote-tracking branch 'origin/main' into 527-mlflow-implementa…
njbrake Jan 31, 2025
b9588f1
first pass
njbrake Jan 31, 2025
f8a2168
Updates based on PR comments
njbrake Jan 31, 2025
98eb0f3
merge complete
njbrake Jan 31, 2025
229ee19
Merge remote-tracking branch 'origin/527-mlflow-implementation' into …
njbrake Jan 31, 2025
2216e02
Fix the tests
njbrake Jan 31, 2025
a091c10
Merge remote-tracking branch 'origin/main' into brake/sdk_update
njbrake Feb 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"./lumigator/schemas",
"./lumigator/jobs"
"./lumigator/jobs",
"./lumigator/sdk"
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,103 @@ def test_full_experiment_launch(
retrieve_and_validate_workflow_logs(local_client, workflow_1_details.id)
delete_experiment_and_validate(local_client, experiment_id)

experiment = local_client.post(
"/experiments/new/",
headers=POST_HEADER,
json={
"name": "test_create_exp_workflow_check_results",
"description": "Test for an experiment with associated workflows",
},
)
assert experiment.status_code == 201
experiment_id = experiment.json()["id"]

# run a workflow for that experiment
workflow_1 = WorkflowResponse.model_validate(
local_client.post(
"/workflows/",
headers=POST_HEADER,
json={
"name": "Workflow_1",
"description": "Test workflow for inf and eval",
"model": TEST_CAUSAL_MODEL,
"dataset": str(dataset.id),
"experiment_id": experiment_id,
"max_samples": 1,
},
).json()
)

# Wait till the workflow is done
workflow_1_details = wait_for_workflow_complete(local_client, workflow_1.id)

experiment_results = GetExperimentResponse.model_validate(
local_client.get(f"/experiments/new/{experiment_id}").json()
)

assert workflow_1_details.experiment_id == experiment_results.id
assert len(experiment_results.workflows) == 1
# the presigned url can be different but everything else should be the same
assert workflow_1_details.model_dump(
exclude={"artifacts_download_url"}
) == experiment_results.workflows[0].model_dump(exclude={"artifacts_download_url"})

# add another workflow to the experiment
workflow_2 = WorkflowResponse.model_validate(
local_client.post(
"/workflows/",
headers=POST_HEADER,
json={
"name": "Workflow_2",
"description": "Test workflow for inf and eval",
"model": TEST_CAUSAL_MODEL,
"dataset": str(dataset.id),
"experiment_id": experiment_id,
"max_samples": 1,
},
).json()
)

# Wait till the workflow is done
workflow_2_details = wait_for_workflow_complete(local_client, workflow_2.id)

# now get the results of the experiment
experiment_results = GetExperimentResponse.model_validate(
local_client.get(f"/experiments/new/{experiment_id}").json()
)
# make sure it has the info for both workflows
assert len(experiment_results.workflows) == 2
# make sure both workflows are in the experiment, excluding that presigned url again
assert workflow_1_details.model_dump(exclude={"artifacts_download_url"}) in [
w.model_dump(exclude={"artifacts_download_url"}) for w in experiment_results.workflows
]
assert workflow_2_details.model_dump(exclude={"artifacts_download_url"}) in [
w.model_dump(exclude={"artifacts_download_url"}) for w in experiment_results.workflows
]

# get the logs
logs_job_response = local_client.get(f"/workflows/{workflow_1_details.id}/logs")
logs = JobLogsResponse.model_validate(logs_job_response.json())
assert logs.logs is not None
# Very naive way to check whether both of the logs we expect are in here
# This will need to be updated as we improve the log retrieval structure.
assert "Inference results stored at" in logs.logs
assert "Storing evaluation results into" in logs.logs
# assert that inference comes before eval
assert logs.logs.index("Inference results stored at") < logs.logs.index(
"Storing evaluation results into"
)

# delete the experiment
local_client.delete(f"/experiments/new/{experiment_id}")
response = local_client.get(f"/experiments/new/{experiment_id}")
assert response.status_code == 404
# make sure the workflow results also were deleted
response = local_client.get(f"/workflows/{workflow_1_details.id}")
assert response.status_code == 404
response = local_client.get(f"/workflows/{workflow_2_details.id}")
assert response.status_code == 404


def test_experiment_non_existing(local_client: TestClient, dependency_overrides_services):
non_existing_id = "71aaf905-4bea-4d19-ad06-214202165812"
Expand Down
47 changes: 16 additions & 31 deletions lumigator/sdk/lumigator_sdk/experiments.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,50 @@
from http import HTTPMethod
from json import dumps
from uuid import UUID

from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentIdResponse,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
GetExperimentResponse,
)
from lumigator_schemas.extras import ListingResponse

from lumigator_sdk.client import ApiClient
from lumigator_sdk.strict_schemas import ExperimentCreate as ExperimentCreateStrict
from lumigator_sdk.strict_schemas import ExperimentIdCreate as ExperimentIdCreateStrict


class Experiments:
EXPERIMENTS_ROUTE = "experiments"
EXPERIMENTS_ROUTE = "experiments/new"

def __init__(self, c: ApiClient):
self.__client = c

def create_experiment(self, experiment: ExperimentCreate) -> ExperimentResponse:
def create_experiment(self, experiment: ExperimentIdCreate) -> ExperimentIdResponse:
"""Creates a new experiment."""
ExperimentCreateStrict.model_validate(ExperimentCreate.model_dump(experiment))
ExperimentIdCreateStrict.model_validate(ExperimentIdCreate.model_dump(experiment))
response = self.__client.get_response(
self.EXPERIMENTS_ROUTE, HTTPMethod.POST, dumps(experiment)
self.EXPERIMENTS_ROUTE, HTTPMethod.POST, experiment.model_dump_json()
)

data = response.json()
return ExperimentResponse(**data)
return ExperimentIdResponse(**data)

def get_experiment(self, experiment_id: UUID) -> ExperimentResponse | None:
def get_experiment(self, experiment_id: str) -> GetExperimentResponse | None:
"""Returns information on the experiment for the specified ID."""
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}")

data = response.json()
return ExperimentResponse(**data)
return GetExperimentResponse(**data)

def get_experiments(
self, skip: int = 0, limit: int = 100
) -> ListingResponse[ExperimentResponse]:
"""Returns information on all experiments."""
response = self.__client.get_response(self.EXPERIMENTS_ROUTE)
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/all")

data = response.json()
return ListingResponse[ExperimentResponse](**data)

def get_experiment_result(self, experiment_id: UUID) -> ExperimentResultResponse | None:
"""Returns the result of the experiment for the specified ID."""
response = self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}/result")

data = response.json()
return ExperimentResultResponse(**data)

def get_experiment_result_download(
self, experiment_id: UUID
) -> ExperimentResultDownloadResponse | None:
"""Returns the result of the experiment for the specified ID."""
response = self.__client.get_response(
f"{self.EXPERIMENTS_ROUTE}/{experiment_id}/result/download"
)

data = response.json()
return ExperimentResultDownloadResponse(**data)
def delete_experiment(self, experiment_id: str) -> None:
"""Deletes the experiment for the specified ID."""
self.__client.get_response(f"{self.EXPERIMENTS_ROUTE}/{experiment_id}", HTTPMethod.DELETE)
return None
4 changes: 4 additions & 0 deletions lumigator/sdk/lumigator_sdk/lumigator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from lumigator_sdk.client import ApiClient
from lumigator_sdk.completions import Completions
from lumigator_sdk.experiments import Experiments
from lumigator_sdk.health import Health
from lumigator_sdk.jobs import Jobs
from lumigator_sdk.lm_datasets import Datasets
from lumigator_sdk.models import Models
from lumigator_sdk.workflows import Workflows

# Only retries initial connections
# No HTTP errors are retried
Expand Down Expand Up @@ -54,3 +56,5 @@ def __init__(
self.jobs = Jobs(self.client)
self.datasets = Datasets(self.client)
self.models = Models(self.client)
self.workflows = Workflows(self.client)
self.experiments = Experiments(self.client)
19 changes: 7 additions & 12 deletions lumigator/sdk/lumigator_sdk/strict_schemas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from lumigator_schemas.completions import CompletionResponse
from lumigator_schemas.datasets import DatasetDownloadResponse, DatasetResponse
from lumigator_schemas.experiments import (
ExperimentCreate,
ExperimentIdCreate,
ExperimentResponse,
ExperimentResultDownloadResponse,
ExperimentResultResponse,
)
from lumigator_schemas.extras import HealthResponse, ListingResponse
from lumigator_schemas.jobs import (
Expand All @@ -20,6 +18,7 @@
JobResultResponse,
JobSubmissionResponse,
)
from lumigator_schemas.workflows import WorkflowCreateRequest
from pydantic import ConfigDict


Expand All @@ -35,22 +34,14 @@ class DatasetResponse(DatasetResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentCreate(ExperimentCreate):
class ExperimentIdCreate(ExperimentIdCreate):
model_config = ConfigDict(extra="forbid")


class ExperimentResponse(ExperimentResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentResultResponse(ExperimentResultResponse, from_attributes=True):
model_config = ConfigDict(extra="forbid")


class ExperimentResultDownloadResponse(ExperimentResultDownloadResponse):
model_config = ConfigDict(extra="forbid")


class HealthResponse(HealthResponse):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -101,3 +92,7 @@ class JobResultResponse(JobResultResponse, from_attributes=True):

class JobResultDownloadResponse(JobResultDownloadResponse):
model_config = ConfigDict(extra="forbid")


class WorkflowCreateRequest(WorkflowCreateRequest):
    """Strict variant of the workflow creation schema.

    Configured with ``extra="forbid"`` so that validation fails on fields
    the base schema does not declare.
    """

    model_config = ConfigDict(extra="forbid")
47 changes: 47 additions & 0 deletions lumigator/sdk/lumigator_sdk/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from http import HTTPMethod

from lumigator_schemas.jobs import JobLogsResponse
from lumigator_schemas.workflows import (
WorkflowCreateRequest,
WorkflowDetailsResponse,
WorkflowResponse,
)

from lumigator_sdk.client import ApiClient
from lumigator_sdk.strict_schemas import WorkflowCreateRequest as WorkflowCreateRequestStrict


class Workflows:
    """SDK wrapper around the ``workflows`` API route."""

    WORKFLOWS_ROUTE = "workflows"

    def __init__(self, c: ApiClient):
        """Keep a reference to the API client used for all workflow requests.

        Args:
            c: Client whose ``get_response`` method issues the HTTP calls.
        """
        self.__client = c

    def create_workflow(self, workflow: WorkflowCreateRequest) -> WorkflowResponse:
        """Creates a new workflow.

        The request is first validated against the strict schema, so a
        malformed request is rejected before anything is sent to the API.

        Args:
            workflow: Description of the workflow to create.

        Returns:
            The workflow record built from the API's JSON response.
        """
        WorkflowCreateRequestStrict.model_validate(workflow.model_dump())
        response = self.__client.get_response(
            self.WORKFLOWS_ROUTE, HTTPMethod.POST, workflow.model_dump_json()
        )

        data = response.json()
        return WorkflowResponse(**data)

    def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse | None:
        """Returns information on the workflow for the specified ID.

        Args:
            workflow_id: Identifier of the workflow to look up.

        Returns:
            The workflow details built from the API's JSON response.
        """
        response = self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}")

        data = response.json()
        return WorkflowDetailsResponse(**data)

    def get_workflow_logs(self, workflow_id: str) -> JobLogsResponse | None:
        """Returns the logs of the workflow for the specified ID.

        Args:
            workflow_id: Identifier of the workflow whose logs are requested.

        Returns:
            The logs built from the API's JSON response.
        """
        response = self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}/logs")

        data = response.json()
        return JobLogsResponse(**data)

    def delete_workflow(self, workflow_id: str) -> None:
        """Deletes the workflow for the specified ID.

        Args:
            workflow_id: Identifier of the workflow to delete.
        """
        # The response body is not used; only the DELETE request matters here.
        self.__client.get_response(f"{self.WORKFLOWS_ROUTE}/{workflow_id}", HTTPMethod.DELETE)
Loading