From c02406dc0292ee332489f3089a90514bc0dac23e Mon Sep 17 00:00:00 2001
From: Nathan Brake <33383515+njbrake@users.noreply.github.com>
Date: Wed, 19 Feb 2025 12:38:56 -0500
Subject: [PATCH] LiteLLM Support (Also DeepSeek in the UI and Llamafile in the UI) (#911)

* start litellm support, not yet working
* Remove all the completions route code
* update walkthrough docs
* Working and added Deepseek support!
* Deepseek comparison
* Explanation of pre-reqs
* Apply suggestions from code review

Co-authored-by: Hareesh
Signed-off-by: Nathan Brake <33383515+njbrake@users.noreply.github.com>

* merged
* First attempt at a parametrized JobCreate
* Replace templates with pydantic models
* Adapt SDK and SDK tests
* Fix sdk unit tests
* Fix notebook tests
* Fix tests
* Fix job definition in workflows
* Fix job unit test
* Start a default workflow for experiments
* Rebase to main
* delete the evaluator
* Rebase to main
* First attempt at a parametrized JobCreate
* Replace templates with pydantic models
* Adapt SDK and SDK tests
* Fix sdk unit tests
* Fix notebook tests
* Fix tests
* Fix job definition in workflows
* Fix job unit test
* Start a default workflow for experiments
* Rebase to main
* Align with routes in main
* Move to experiments new endpoint
* Streamline new experiments api
* Add dataset/samples to experiment, test background tasks
* Factor out experiment formatting
* Support Llamafile
* cleanup for PR
* Clean up for PR
* Apply suggestions from code review

Co-authored-by: Dimitris Poulopoulos
Signed-off-by: Nathan Brake <33383515+njbrake@users.noreply.github.com>

* Updates from PR, remove dup models.json file
* fix enable_tqdm schema
* Update naming
* Fix a bug with status
* Exclude unset and none
* job unit test and jupy notebook patch
* job unit test and jupy notebook patch
* final pass through the code
* clean notebook output

---------

Signed-off-by: Nathan Brake <33383515+njbrake@users.noreply.github.com>
Co-authored-by: Hareesh
Co-authored-by: Javier Torres
Co-authored-by: Dimitris Poulopoulos
---
 README.md | 7 +-
 docs/source/conceptual-guides/new-endpoint.md | 369 ---
 docs/source/get-started/installation.md | 2 +-
 docs/source/get-started/quickstart.md | 9 +-
 docs/source/get-started/suggested-models.md | 3 +-
 docs/source/user-guides/inference.md | 16 +-
 lumigator/backend/backend/api/routes/jobs.py | 3 +-
 lumigator/backend/backend/models.yaml | 57 +-
 lumigator/backend/backend/services/jobs.py | 142 +-
 .../backend/backend/services/workflows.py | 3 +-
 lumigator/backend/backend/tests/conftest.py | 18 +-
 .../tests/data/health_job_metadata.json | 2 +-
 .../tests/data/health_job_metadata_ray.json | 2 +-
 .../backend/backend/tests/data/models.json | 103 -
 .../api/routes/test_api_workflows.py | 10 +-
 .../tests/unit/api/routes/test_models.py | 19 +-
 .../tests/unit/services/test_job_service.py | 33 +-
 lumigator/backend/backend/tracking/mlflow.py | 8 +-
 lumigator/backend/pyproject.toml | 2 -
 lumigator/backend/uv.lock | 80 -
 .../experiments/LExperimentResults.vue | 4 +-
 .../components/experiments/LModelCards.vue | 10 +-
 .../src/helpers/retrieveEntrypoint.ts | 12 +-
 .../frontend/src/sdk/experimentsService.ts | 2 +-
 .../frontend/src/stores/experimentsStore.ts | 4 +-
 lumigator/frontend/src/types/Model.ts | 6 +-
 lumigator/frontend/src/types/Workflow.ts | 3 +-
 .../evaluator/tests/data/config_full_hf.json | 5 +-
 lumigator/jobs/inference/inference.py | 35 +-
 lumigator/jobs/inference/inference_config.py | 15 +-
 lumigator/jobs/inference/model_clients.py | 181 +-
 lumigator/jobs/inference/paths.py | 5 -
lumigator/jobs/inference/requirements.txt | 3 +- lumigator/jobs/inference/requirements_cpu.txt | 3 +- lumigator/jobs/inference/schemas.py | 9 +- .../inference/tests/data/config_full_api.json | 3 +- .../inference/tests/data/config_full_hf.json | 2 +- .../jobs/inference/tests/test_configs.py | 2 +- lumigator/jobs/inference/utils.py | 17 - lumigator/schemas/lumigator_schemas/jobs.py | 3 +- lumigator/schemas/lumigator_schemas/models.py | 26 +- .../schemas/lumigator_schemas/workflows.py | 3 +- lumigator/sdk/lumigator_sdk/client.py | 2 +- lumigator/sdk/tests/conftest.py | 30 +- .../sdk/tests/data/experiment-post-all.json | 3 +- .../tests/data/experiment-post-simple.json | 3 +- .../sdk/tests/data/job-all-inference.json | 1 + .../sdk/tests/data/job-extra-inference.json | 1 + .../sdk/tests/data/job-minimal-inference.json | 3 +- lumigator/sdk/tests/data/job-submit-resp.json | 2 +- lumigator/sdk/tests/data/job.json | 2 +- lumigator/sdk/tests/data/jobs-submit.json | 4 +- lumigator/sdk/tests/data/models.json | 109 - .../sdk/tests/integration/test_scenarios.py | 9 +- lumigator/sdk/tests/unit/test_models.py | 10 +- lumigator/sdk/uv.lock | 4 +- notebooks/README.md | 2 +- notebooks/assets/model_info.csv | 18 +- notebooks/pyproject.toml | 12 + notebooks/walkthrough.ipynb | 21 +- pyproject.toml | 4 + uv.lock | 2712 ++++++++++++++++- 62 files changed, 2941 insertions(+), 1252 deletions(-) delete mode 100644 docs/source/conceptual-guides/new-endpoint.md delete mode 100644 lumigator/backend/backend/tests/data/models.json delete mode 100644 lumigator/sdk/tests/data/models.json create mode 100644 notebooks/pyproject.toml diff --git a/README.md b/README.md index 06c41fb14..b03538f51 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,8 @@ services networked together to make up all the components of the Lumigator appli > uses SQLite for this purpose. > [!NOTE] -If you want to evaluate against LLM APIs like OpenAI and Mistral, you need to set the appropriate -environment variables: `OPENAI_API_KEY` or `MISTRAL_API_KEY`. Refer to the +If you want to evaluate against LLM APIs like OpenAI/Mistral/Deepseek, you need to set the appropriate +environment variables: `OPENAI_API_KEY` or `MISTRAL_API_KEY` or `DEEPSEEK_API_KEY`. Refer to the [troubleshooting section](https://mozilla-ai.github.io/lumigator/get-started/troubleshooting.html#tokens-api-keys-not-set) in our documentation for more details. @@ -87,10 +87,11 @@ To start Lumigator locally, follow these steps: **Important: Continue the next steps in this same terminal.** -1. If you intend to use Mistral API or OpenAI API, use that same terminal and run: +1. If you intend to use Mistral API or OpenAI API or Deepseek API, use that same terminal and run: ```bash export MISTRAL_API_KEY=your_mistral_api_key export OPENAI_API_KEY=your_openai_api_key + export DEEPSEEK_API_KEY=your_deepseek_api_key ``` **Important: Continue the next steps in this same terminal.** diff --git a/docs/source/conceptual-guides/new-endpoint.md b/docs/source/conceptual-guides/new-endpoint.md deleted file mode 100644 index f3eff19bc..000000000 --- a/docs/source/conceptual-guides/new-endpoint.md +++ /dev/null @@ -1,369 +0,0 @@ -# Creating a New Endpoint - -The examples in the [Understanding Lumigator Endpoints](https://mozilla-ai.github.io/lumigator/conceptual-guides/endpoints.html) -guide show the main pieces of code involved in writing a new endpoint. Let us take a toy example, -which is creating an endpoint which given a task in the form of a string (e.g. 
"summarization") will -return a list of model names (string URIs) to be evaluated for that task. As you do not have a table -with this information in our database yet, you will also create a method to actually *store* this -list in a table. - -## Step 1: Create a static endpoint - -As a first step, you'll create a static endpoint which, regardless of the input, will always return -the same list of models. This is a good way to start as it allows you to wire the endpoint, see it -in the docs, and make sure its schema is correct, without the need to access the database yet. - -### 1.1. Write the router code - -The following code implements a barebone version of our endpoint, one which does not require any -connection to the DB nor a specific schema definition. It defines one method `get_model_list` that -always returns a `ListingResponse[str]` object, basically a dictionary containing a list of `items` -of a predefined type (strings in our case) and a `total` field with the number of items. Note that -`task_id` is not even used, but this will change as soon as we want to get a list which is -task-specific. - -```python -from fastapi import APIRouter -from schemas.extras import ListingResponse - - -router = APIRouter() - -@router.get("/{task_id}/models") -def get_model_list(task_id: str) -> ListingResponse[str]: - """Get list of models for a given task.""" - return_data = { - "total": 3, - "items": [ - "hf://facebook/bart-large-cnn", - "mistral://open-mistral-7b", - "oai://gpt-4-turbo", - ], - } - return ListingResponse[str].model_validate(return_data) -``` - -Being this the code for a new route, you will save it in `backend/api/routes/tasks.py`. - -### 1.2. Add the route to `router.py` with the appropriate tags - -The following step is adding the new route to -{{ '[`backend/api/router.py`](https://github.com/mozilla-ai/lumigator/blob/{}/lumigator/backend/backend/api/router.py)'.format(commit_id) }}. -The code below shows the updated file with comments next to the two lines marked below as **NEW**: - -```python -from fastapi import APIRouter - - -from backend.api.routes import ( - datasets, - experiments, - health, - tasks, # NEW -) -from backend.api.tags import Tags - -API_V1_PREFIX = "/api/v1" - -api_router = APIRouter(prefix=API_V1_PREFIX) -api_router.include_router(health.router, prefix="/health", tags=[Tags.HEALTH]) -api_router.include_router(datasets.router, prefix="/datasets", tags=[Tags.DATASETS]) -api_router.include_router(experiments.router, prefix="/experiments", tags=[Tags.EXPERIMENTS]) -api_router.include_router(tasks.router, prefix="/tasks", tags=[Tags.TASKS]) # NEW -``` - -Also note that we are specifying some `Tags.TASKS` which have not been defined yet! Open -{{ '[`backend/api/tags.py`](https://github.com/mozilla-ai/lumigator/blob/{}/lumigator/backend/backend/api/tags.py)'.format(commit_id) }} -and add the sections marked below as **NEW**: - -```python -from enum import Enum - - -class Tags(str, Enum): - HEALTH = "health" - DATASETS = "datasets" - EXPERIMENTS = "experiments" - TASKS = "tasks" ### NEW - - -TAGS_METADATA = [ - { - "name": Tags.HEALTH, - "description": "Health check for the application.", - }, - { - "name": Tags.DATASETS, - "description": "Upload and download datasets.", - }, - { - "name": Tags.EXPERIMENTS, - "description": "Create and manage evaluation experiments.", - }, - # NEW TAGS BELOW - { - "name": Tags.TASKS, - "description": "Mapping model lists to tasks.", - }, -] -``` - -### 1.3. 
Test - -If you're running Lumigator locally, connect to [`http://localhost:8000/docs`](http://localhost:8000/docs) -and you should see the following: - -![Docs API description](../../assets/tasks_api_descr.png) - -If you click on `Try it out`, add any value for `task_id` and then click on `Execute` you should get -the following response: - -![Docs API response](../../assets/tasks_response.png) - -## Step 2: Wire the endpoint to the DB - -You have a new running endpoint! Too bad it always return the same identical data... Let's fix this -by connecting it to the database. - -### 2.2. Define schema - -To have a more useful endpoint, let us first update the schema with a few additional fields. The -following goes into`schemas/tasks.py`: - -```python -import datetime -from uuid import UUID - -from pydantic import BaseModel - - -class TaskCreate(BaseModel): - name: str - description: str = "" - models: list[str] - -class TaskResponse(BaseModel, from_attributes=True): - id: UUID - name: str - description: str - created_at: datetime.datetime - models: list[str] -``` - -If you look at `TaskResponse`, you can see that we added an `id` to uniquely identify the task, a -`name` that we can show e.g. in a list for users to choose from and a `description`. The field -`created_at` will be automatically filled when creating a new record. The list of `models` still -appears, but we removed the count as we'll likely provide short lists and we can expect to be able -to get their length programmatically. - -The `TaskCreate` class, instead, will be used to define the input to the `create_task` method in the -API (the fields `id` and `created_at` are not necessary, as they'll be created automatically by the -database). - -### 2.2. Define repositories and records - -The code for a new repository (to be stored in `backend/repositories/tasks.py`) is quite standard: - -```python -from sqlalchemy.orm import Session - -from backend.records.tasks import TaskRecord -from backend.repositories.base import BaseRepository - - -class TaskRepository(BaseRepository[TaskRecord]): - def __init__(self, session: Session): - super().__init__(TaskRecord, session) -``` - -This does not usually change much as long as you are fine with the base methods provided by the -{{ '[`BaseRepository`](https://github.com/mozilla-ai/lumigator/blob/{}/lumigator/backend/backend/repositories/base.py)'.format(commit_id) }} -class. - -The `TaskRepository` is a repository that allows to run the set of methods defined in the -`BaseRepository` on the table defined by `TaskRecord`. You can define a `TaskRecord` in -`backend/records/tasks.py` as follows: - -```python -from sqlalchemy.orm import Mapped, mapped_column - -from backend.records.base import BaseRecord -from backend.records.mixins import CreatedAtMixin, NameDescriptionMixin - - -class TaskRecord(BaseRecord, NameDescriptionMixin, CreatedAtMixin): - __tablename__ = "tasks" - models: Mapped[list[str]] = mapped_column(nullable=False) -``` - -Similarly to what you saw before for `DatasetRecord`, `TaskRecord` inherits from -{{ '[`BaseRecord`](https://github.com/mozilla-ai/lumigator/blob/{}/lumigator/backend/backend/records/base.py)'.format(commit_id) }} -the property of having an `id` primary key. In addition to that, it inherits `name` and -`description` from `NameDescriptionMixin` and `created_at` from `CreatedAtMixin`. The only field -that we need to specify manually is `models`, a non-null column holding a list of strings. 
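For reference, here is a minimal sketch of what `NameDescriptionMixin` and `CreatedAtMixin` could look like; the column names follow the description above, but the exact types and defaults are assumptions rather than the actual Lumigator implementation:

```python
import datetime

from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column


class NameDescriptionMixin:
    # Assumed shape: a required name plus an optional free-text description.
    name: Mapped[str] = mapped_column(nullable=False)
    description: Mapped[str] = mapped_column(nullable=False, default="")


class CreatedAtMixin:
    # Assumed shape: a creation timestamp filled in by the database on insert.
    created_at: Mapped[datetime.datetime] = mapped_column(server_default=func.now())
```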
- -As SQLAlchemy does not have a built-in mapping from `list[str]`, you also need to update -`BaseRecord` to provide one explicitly by changing the following: - -```python -type_annotation_map = {dict[str, Any]} -``` - -to - -```python -type_annotation_map = {dict[str, Any]: JSON, list[str]: JSON} -``` - -### 2.3. Save DB-accessing methods into a TasksService - -Now that you have an abstraction for the `tasks` table, you can use it to implement the different -methods needed to manage tasks. You'll do it inside `backend/services/tasks.py`: - -```python -from uuid import UUID - -from fastapi import HTTPException, status - -from backend.records.tasks import TaskRecord -from backend.repositories.tasks import TaskRepository -from schemas.extras import ListingResponse -from schemas.tasks import TaskCreate, TaskResponse - - -class TaskService: - def __init__(self, tasks_repo: TaskRepository): - self.tasks_repo = tasks_repo - - def _raise_not_found(self, task_id: UUID) -> None: - raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task '{task_id}' not found.") - - def _get_task_record(self, task_id: UUID) -> TaskRecord: - record = self.tasks_repo.get(task_id) - - if record is None: - self._raise_not_found(task_id) - return record - - def get_task(self, task_id: UUID) -> TaskResponse: - record = self._get_task_record(task_id) - return TaskResponse.model_validate(record) - - def create_task(self, request: TaskCreate) -> TaskResponse: - # Create DB record - record = self.tasks_repo.create( - name=request.name, description=request.description, models=request.models - ) - return TaskResponse.model_validate(record) - - def delete_task(self, task_id: UUID) -> None: - record = self._get_task_record(task_id) - # Delete DB record - self.tasks_repo.delete(record.id) - - def list_tasks(self, skip: int = 0, limit: int = 100) -> ListingResponse[TaskResponse]: - total = self.tasks_repo.count() - records = self.tasks_repo.list(skip, limit) - return ListingResponse( - total=total, - items=[TaskResponse.model_validate(x) for x in records], - ) -``` - -The main methods implemented here are: - -* `create_task`: uses the `create` method in the task repository (a method inherited by - `BaseRepository`) to save a new task record. The `request` input parameter is defined in the - `TaskCreate` schema and the output is a `TaskResponse`. - -* `delete_task`: given a task `UUID`, deletes the corresponding record from the table. Note that, as - all the other methods that rely on `_get_task_record`, an HTTP 404 exception is thrown if a - matching record is not found in the table. - -* `get_task`: given a task `UUID`, return the corresponding record (as a `TaskResponse`) - -* `list_tasks`: returns a `ListingResponse` of `TaskResponse` elements (i.e. a list of tasks stored - in the table). - -Note how similar this is to some of the other services (e.g. `DatasetService`): you can expect this -from services which only deal with the DB as the main operations you'll do are those who operate on -a table (create, delete, get, list, etc). You will likely see a different behavior in more advanced -endpoints (e.g. those which involve running ray jobs), but we'll discuss that in another tutorial. - -### 2.4. Inject dependencies into a TaskService - -As `TaskService` depends on the existence of a database, we should inject a dependency on a DB session. 
-To do this, add the following code to -{{ '[`backend/api/deps.py`](https://github.com/mozilla-ai/lumigator/blob/{}/lumigator/backend/backend/api/deps.py)'.format(commit_id) }}: - -```python -def get_task_service(session: DBSessionDep) -> TaskService: - task_repo = TaskRepository(session) - return TaskService(task_repo) - - -TaskServiceDep = Annotated[TaskService, Depends(get_task_service)] -``` - -### 2.5. Update routes - -You are almost there! The last thing you need to do is update the `backend/api/routes/tasks.py`file -with new code which will map API requests to `TaskService` methods: - -```python -from uuid import UUID - -from fastapi import APIRouter, status - -from backend.api.deps import TaskServiceDep -from schemas.extras import ListingResponse -from schemas.tasks import TaskCreate, TaskResponse - -router = APIRouter() - - -@router.post("/", status_code=status.HTTP_201_CREATED) -def create_task(service: TaskServiceDep, request: TaskCreate) -> TaskResponse: - return service.create_task(request) - - -@router.get("/{task_id}") -def get_task(service: TaskServiceDep, task_id: UUID) -> TaskResponse: - return service.get_task(task_id) - - -@router.delete("/{task_id}", status_code=status.HTTP_204_NO_CONTENT) -def delete_task(service: TaskServiceDep, task_id: UUID) -> None: - service.delete_task(task_id) - - -@router.get("/") -def list_tasks( - service: TaskServiceDep, - skip: int = 0, - limit: int = 100, -) -> ListingResponse[TaskResponse]: - return service.list_tasks(skip, limit) -``` - -### 2.6. Test - -To test your new endpoint, connect to [http://localhost:8000/docs](http://localhost/docs). You -should see a new section like the following one: - -![Docs API description](../../assets/tasks_api_methods.png) - -Here is an example for the creation of a new task: - -![Docs API create](../../assets/tasks_api_new.png) - -And here is the format of a task list (to which we have also added a `summarization_eval` task): - -![Docs API task list](../../assets/tasks_api_list.png) - -## Final checks and reference code - -Creating a new endpoint in Lumigator requires some prior knowledge about the system and due to how -the code is structured you will have to edit many different files (see the picture below). Once you -get the gist of it, though, the task can be quite straightforward (albeit a bit repetitive), -especially if the endpoint mostly has to deal with database operations. diff --git a/docs/source/get-started/installation.md b/docs/source/get-started/installation.md index e7675d10b..df77c1135 100644 --- a/docs/source/get-started/installation.md +++ b/docs/source/get-started/installation.md @@ -39,7 +39,7 @@ Lumigator. ```{note} If you want to evaluate against LLM APIs like OpenAI and Mistral, you need to set the appropriate -environment variables: `OPENAI_API_KEY` or `MISTRAL_API_KEY`. Refer to the +environment variables: `OPENAI_API_KEY` or `MISTRAL_API_KEY` or `DEEPSEEK_API_KEY`. Refer to the [troubleshooting section](../get-started/troubleshooting.md) for more details. 
``` diff --git a/docs/source/get-started/quickstart.md b/docs/source/get-started/quickstart.md index 3475efd0c..777efd0e0 100644 --- a/docs/source/get-started/quickstart.md +++ b/docs/source/get-started/quickstart.md @@ -116,7 +116,8 @@ Set the following variables: ```console user@host:~/lumigator$ export EVAL_NAME="test_run_hugging_face" \ EVAL_DESC="Test run for Huggingface model" \ - EVAL_MODEL="hf://facebook/bart-large-cnn" \ + EVAL_MODEL="facebook/bart-large-cnn" \ + EVAL_MODEL_PROVIDER="hf" \ EVAL_DATASET="$(curl -s http://localhost:8000/api/v1/datasets/ | jq -r '.items | .[0].id')" \ EVAL_MAX_SAMPLES="10" ``` @@ -127,9 +128,10 @@ user@host:~/lumigator$ export JSON_STRING=$(jq -n \ --arg name "$EVAL_NAME" \ --arg desc "$EVAL_DESC" \ --arg model "$EVAL_MODEL" \ + --arg provider "$EVAL_PROVIDER" \ --arg dataset_id "$EVAL_DATASET" \ --arg max_samples "$EVAL_MAX_SAMPLES" \ - '{name: $name, description: $desc, model: $model, dataset: $dataset_id, max_samples: $max_samples}' + '{name: $name, description: $desc, model: $model, provider: $provider, dataset: $dataset_id, max_samples: $max_samples}' ) ``` @@ -158,7 +160,7 @@ from lumigator_schemas.jobs import JobType, JobEvalCreate dataset_id = datasets.items[-1].id -models = ['hf://facebook/bart-large-cnn',] +models = ['facebook/bart-large-cnn',] # set this value to limit the evaluation to the first max_samples items (0=all) max_samples = 10 @@ -171,6 +173,7 @@ for model in models: name=team_name, description="Test", model=model, + provider="hf", dataset=str(dataset_id), max_samples=max_samples ) diff --git a/docs/source/get-started/suggested-models.md b/docs/source/get-started/suggested-models.md index 583a8b693..9a5ffc6e4 100644 --- a/docs/source/get-started/suggested-models.md +++ b/docs/source/get-started/suggested-models.md @@ -22,7 +22,8 @@ user@host:~/lumigator$ curl -s http://localhost:8000/api/v1/models/summarization "items": [ { "name": "facebook/bart-large-cnn", - "uri": "hf://facebook/bart-large-cnn", + "model": "facebook/bart-large-cnn", + "provider": "hf", "description": "BART is a large-sized model fine-tuned on the CNN Daily Mail dataset.", "info": { "parameter_count": "406M", diff --git a/docs/source/user-guides/inference.md b/docs/source/user-guides/inference.md index e475390b9..6121c18be 100644 --- a/docs/source/user-guides/inference.md +++ b/docs/source/user-guides/inference.md @@ -56,12 +56,14 @@ Refer to the [troubleshooting section](../get-started/troubleshooting.md) for mo # Create and submit an inference job name = "bart-summarization-run" - model = "hf://facebook/bart-large-cnn" + model = "facebook/bart-large-cnn" + provider = "hf" task = "summarization" job_args = jobs.JobInferenceCreate( name=name, model=model, + provider=provider, dataset=dataset.id, task=task, ) @@ -95,12 +97,12 @@ Refer to the [troubleshooting section](../get-started/troubleshooting.md) for mo Different models can be chosen for summarization. The information about those models can be retrieved via the `http://:8000/api/v1/models/summarization` endpoint. It contains the following information for each model: -* `name`: an identification name for the model -* `uri`: a URI specifying how to use the model. 
The following protocols are supported: - * `hf://`: direct model usage in an [HF pipeline](https://huggingface.co/docs/transformers/en/main_classes/pipelines) - * `llamafile://`: model set up with [`llamafile`](https://github.com/Mozilla-Ocho/llamafile) on its default host and port - * `oai://`: OpenAI or compatible external API; needs a value for the environment variable OPENAI_API_KEY with a valid key - * `mistral://`: Mistral or compatible external API; needs a value for the environment variable MISTRAL_API_KEY with a valid key +* `display_name`: an identification name for the model +* `model`: The model to use, e.g. `facebook/bart-large-cnn` +* `provider`: a URI specifying how and where to use the model. The following protocols are supported: + * `hf`: direct model usage in an [HF pipeline](https://huggingface.co/docs/transformers/en/main_classes/pipelines) + * Any protocol supported by [LiteLLM](https://docs.litellm.ai/docs/providers). For example, `openai/`, `mistral/`, `deepseek/`, etc. You will need to have set the correct API keys for them, e.g. OPENAI_API_KEY, or MISTRAL_API_KEY, or DEEPSEEK_API_KEY +* `base_url`: this field can be filled out if running a custom model that uses the openai protocol. For example, llamafile is generally hosted on your computer at `http://localhost:8080/v1`. * `website_url`: a link to a web page with more information about the model * `description`: a short description about the model * `info`: a map containing information about the model like parameter count or model size diff --git a/lumigator/backend/backend/api/routes/jobs.py b/lumigator/backend/backend/api/routes/jobs.py index d1e921cfe..f4858d716 100644 --- a/lumigator/backend/backend/api/routes/jobs.py +++ b/lumigator/backend/backend/api/routes/jobs.py @@ -76,7 +76,8 @@ def create_annotation_job( See more: https://blog.mozilla.ai/lets-build-an-app-for-evaluating-llms/ """ inference_job_create_config_dict = job_create_request.job_config.dict() - inference_job_create_config_dict["model"] = "hf://facebook/bart-large-cnn" + inference_job_create_config_dict["model"] = "facebook/bart-large-cnn" + inference_job_create_config_dict["provider"] = "hf" inference_job_create_config_dict["output_field"] = "ground_truth" inference_job_create_config_dict["store_to_dataset"] = True inference_job_create_config_dict["job_type"] = JobType.INFERENCE diff --git a/lumigator/backend/backend/models.yaml b/lumigator/backend/backend/models.yaml index 06742f36e..68c1e4467 100644 --- a/lumigator/backend/backend/models.yaml +++ b/lumigator/backend/backend/models.yaml @@ -1,5 +1,6 @@ -- name: facebook/bart-large-cnn - uri: hf://facebook/bart-large-cnn +- display_name: facebook/bart-large-cnn + model: facebook/bart-large-cnn + provider: hf website_url: https://huggingface.co/facebook/bart-large-cnn description: BART is a large-sized model fine-tuned on the CNN Daily Mail dataset. info: @@ -15,8 +16,9 @@ no_repeat_ngram_size: 3 num_beams: 4 -- name: Falconsai/text_summarization - uri: hf://Falconsai/text_summarization +- display_name: Falconsai/text_summarization + model: Falconsai/text_summarization + provider: hf website_url: https://huggingface.co/Falconsai/text_summarization description: A fine-tuned variant of the T5 transformer model, designed for the task of text summarization. 
info: @@ -32,8 +34,9 @@ no_repeat_ngram_size: 3 num_beams: 4 -- name: gpt-4o-mini - uri: oai://gpt-4o-mini +- display_name: gpt-4o-mini + model: gpt-4o-mini + provider: openai website_url: https://platform.openai.com/docs/models#gpt-4o-mini description: OpenAI's GPT-4o-mini model. requirements: @@ -41,8 +44,9 @@ tasks: - summarization: -- name: gpt-4o - uri: oai://gpt-4o +- display_name: gpt-4o + model: gpt-4o + provider: openai website_url: https://platform.openai.com/docs/models#gpt-4o description: OpenAI's GPT-4o model. requirements: @@ -50,19 +54,42 @@ tasks: - summarization: -- name: open-mistral-7b - uri: mistral://open-mistral-7b - website_url: https://mistral.ai/technology/#models - description: Mistral's 7B model. +- display_name: deepseek-R1 + model: deepseek-reasoner + provider: deepseek + website_url: https://deepseek.ai/ + description: DeepSeek's R1 model, hosted by deepseek + requirements: + - api_key + tasks: + - summarization: + +- display_name: deepseek-V3 + model: deepseek-chat + provider: deepseek + website_url: https://deepseek.ai/ + description: DeepSeek's V3 model, hosted by deepseek requirements: - api_key tasks: - summarization: -- name: mistralai/Mistral-7B-Instruct-v0.2 - uri: llamafile://mistralai/Mistral-7B-Instruct-v0.2 +- display_name: open-mistral-7b + model: open-mistral-7b + provider: mistral website_url: https://mistral.ai/technology/#models - description: A llamafile package of Mistral's 7B Instruct model. + description: Mistral's 7B model, hosted by mistral + requirements: + - api_key + tasks: + - summarization: + +- display_name: Llamafile/Mistral-7B-Instruct-v0.2 + model: mistralai/Mistral-7B-Instruct-v0.2 + provider: openai + base_url: http://localhost:8080/v1 + website_url: https://huggingface.co/Mozilla/Mistral-7B-Instruct-v0.2-llamafile + description: A llamafile package of Mistral's 7B Instruct model. Assumes that llamafile is running on the same system where the Ray cluster is located. info: parameter_count: 7.24B tensor_type: BF16 diff --git a/lumigator/backend/backend/services/jobs.py b/lumigator/backend/backend/services/jobs.py index f441d5cb7..5963b77e2 100644 --- a/lumigator/backend/backend/services/jobs.py +++ b/lumigator/backend/backend/services/jobs.py @@ -338,32 +338,6 @@ def add_background_task(self, background_tasks: BackgroundTasks, task: callable, """Adds a background task to the background tasks queue.""" background_tasks.add_task(task, *args) - def _set_model_type(self, request: JobCreate) -> str: - """Sets model URL based on protocol address""" - if request.job_config.model.startswith("oai://"): - model_url = settings.OAI_API_URL - elif request.job_config.model.startswith("mistral://"): - model_url = settings.MISTRAL_API_URL - elif request.job_config.model.startswith("ds://"): - model_url = settings.DEEPSEEK_API_URL - else: - model_url = request.job_config.model_url - - return model_url - - def _validate_config(self, job_type: str, config_template: str, config_params: dict): - if job_type == JobType.INFERENCE: - InferenceJobConfig.model_validate_json(config_template.format(**config_params)) - elif job_type == JobType.EVALUATION: - EvalJobConfig.model_validate_json(config_template.format(**config_params)) - else: - loguru.logger.info(f"Validation for job type {job_type} not yet supported.") - - # The end result should be that InferenceJobConfig is actually JobInferenceConfig - # (resp. Eval) - # For the moment, something will convert one into the other, and we'll decide where - # to put this. 
The jobs should ideally have no dependency towards the backend. - def generate_inference_job_config(self, request: JobCreate, record_id: UUID, dataset_path: str, storage_path: str): # TODO Move to a custom validator in the schema if request.job_config.task == "text-generation" and not request.job_config.system_prompt: @@ -378,86 +352,34 @@ def generate_inference_job_config(self, request: JobCreate, record_id: UUID, dat output_field=request.job_config.output_field or "predictions", ), ) - # Maybe use just the protocol to decide? - match request.job_config.model: - case "oai://gpt-4o-mini" | "oai://gpt-4o": - job_config.inference_server = InferenceServerConfig( - base_url=self._set_model_type(request), - engine=request.job_config.model, - # FIXME Inferences may not always be summarizations! - system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, - max_retries=3, - ) - job_config.params = SamplingParameters( - max_tokens=request.job_config.max_tokens, - frequency_penalty=request.job_config.frequency_penalty, - temperature=request.job_config.temperature, - top_p=request.job_config.top_p, - ) - case "mistral://open-mistral-7b": - job_config.inference_server = InferenceServerConfig( - base_url=self._set_model_type(request), - engine=request.job_config.model, - system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, - max_retries=3, - ) - job_config.params = SamplingParameters( - max_tokens=request.job_config.max_tokens, - frequency_penalty=request.job_config.frequency_penalty, - temperature=request.job_config.temperature, - top_p=request.job_config.top_p, - ) - case "ds://deepseek-chat" | "ds://deepseek-reasoner": - job_config.inference_server = InferenceServerConfig( - base_url=self._set_model_type(request), - engine=request.job_config.model, - system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, - max_retries=3, - ) - job_config.params = SamplingParameters( - max_tokens=request.job_config.max_tokens, - frequency_penalty=request.job_config.frequency_penalty, - temperature=request.job_config.temperature, - top_p=request.job_config.top_p, - ) - case "llamafile://mistralai/mistral-7b-instruct-v0.2": - job_config.inference_server = InferenceServerConfig( - base_url=self._set_model_type(request), - engine=request.job_config.model, - system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, - max_retries=3, - ) - job_config.params = SamplingParameters( - max_tokens=request.job_config.max_tokens, - frequency_penalty=request.job_config.frequency_penalty, - temperature=request.job_config.temperature, - top_p=request.job_config.top_p, - ) - case _: - if request.job_config.model_url and request.job_config.model_url.startswith("http://"): - job_config.inference_server = InferenceServerConfig( - base_url=self._set_model_type(request), - engine=request.job_config.model, - system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, - max_retries=3, - ) - job_config.params = SamplingParameters( - max_tokens=request.job_config.max_tokens, - frequency_penalty=request.job_config.frequency_penalty, - temperature=request.job_config.temperature, - top_p=request.job_config.top_p, - ) - else: - job_config.hf_pipeline = HfPipelineConfig( - model_uri=request.job_config.model, - task=request.job_config.task, - accelerator=request.job_config.accelerator, - revision=request.job_config.revision, - use_fast=request.job_config.use_fast, - 
trust_remote_code=request.job_config.trust_remote_code, - torch_dtype=request.job_config.torch_dtype, - max_new_tokens=500, - ) + if request.job_config.provider == "hf": + # Custom logic: if provider is hf, we run the hf model inside the ray job + job_config.hf_pipeline = HfPipelineConfig( + model_name_or_path=request.job_config.model, + task=request.job_config.task, + accelerator=request.job_config.accelerator, + revision=request.job_config.revision, + use_fast=request.job_config.use_fast, + trust_remote_code=request.job_config.trust_remote_code, + torch_dtype=request.job_config.torch_dtype, + max_new_tokens=500, + ) + else: + # It will be a pass through to LiteLLM + job_config.inference_server = InferenceServerConfig( + base_url=request.job_config.base_url if request.job_config.base_url else None, + model=request.job_config.model, + provider=request.job_config.provider, + system_prompt=request.job_config.system_prompt or settings.DEFAULT_SUMMARIZER_PROMPT, + max_retries=3, + ) + job_config.params = SamplingParameters( + max_tokens=request.job_config.max_tokens, + frequency_penalty=request.job_config.frequency_penalty, + temperature=request.job_config.temperature, + top_p=request.job_config.top_p, + ) + return job_config def generate_evaluation_job_config(self, request: JobCreate, record_id: UUID, dataset_path: str, storage_path: str): @@ -486,11 +408,7 @@ def create_job( # Create a db record for the job # To find the experiment that a job belongs to, # we'd use https://mlflow.org/docs/latest/python_api/mlflow.client.html#mlflow.client.MlflowClient.search_runs - record = self.job_repo.create( - name=request.name, - description=request.description, - job_type=job_type, - ) + record = self.job_repo.create(name=request.name, description=request.description, job_type=job_type) # TODO defer to specific job if job_type == JobType.INFERENCE and not request.job_config.output_field: @@ -532,7 +450,7 @@ def create_job( settings.inherit_ray_env(runtime_env_vars) # set num_gpus per worker (zero if we are just hitting a service) - if job_type == JobType.INFERENCE and not request.job_config.model.startswith("hf://"): + if job_type == JobType.INFERENCE and not request.job_config.provider == "hf": worker_gpus = job_settings["ray_worker_gpus_fraction"] else: worker_gpus = job_settings["ray_worker_gpus"] diff --git a/lumigator/backend/backend/services/workflows.py b/lumigator/backend/backend/services/workflows.py index 0f21d8095..d80615f54 100644 --- a/lumigator/backend/backend/services/workflows.py +++ b/lumigator/backend/backend/services/workflows.py @@ -65,7 +65,8 @@ async def _run_inference_eval_pipeline( # JobInferenceCreate and one JobEvalCreate job_infer_config = JobInferenceConfig( model=request.model, - model_url=request.model_url, + provider=request.provider, + base_url=request.base_url, output_field=request.inference_output_field, system_prompt=request.system_prompt, # we store the dataset explicitly below, so it gets queued before eval diff --git a/lumigator/backend/backend/tests/conftest.py b/lumigator/backend/backend/tests/conftest.py index 4850050c9..0bf7a17f5 100644 --- a/lumigator/backend/backend/tests/conftest.py +++ b/lumigator/backend/backend/tests/conftest.py @@ -39,9 +39,7 @@ from backend.settings import BackendSettings, settings from backend.tests.fakes.fake_s3 import FakeS3Client -TEST_CAUSAL_MODEL = "hf://hf-internal-testing/tiny-random-LlamaForCausalLM" -TEST_SUMMARY_MODEL = "hf://hf-internal-testing/tiny-random-T5ForConditionalGeneration" -TEST_INFER_MODEL = 
"hf://hf-internal-testing/tiny-random-t5" +TEST_CAUSAL_MODEL = "hf-internal-testing/tiny-random-LlamaForCausalLM" # Maximum amount of polls done to check if a job has finished # (status FAILED or SUCCEEDED) in fucntion tests. @@ -344,11 +342,6 @@ def resources_dir() -> Path: return Path(__file__).parent / "data" -@pytest.fixture(scope="session") -def json_data_models(resources_dir) -> Path: - return resources_dir / "models.json" - - @pytest.fixture(scope="session") def json_ray_version(resources_dir) -> Path: return resources_dir / "ray_version.json" @@ -406,10 +399,11 @@ def create_job_config() -> JobConfig: conf_args = { "name": "test_run_hugging_face", "description": "Test run for Huggingface model", - "model": "hf://facebook/bart-large-cnn", + "model": "facebook/bart-large-cnn", + "provider": "hf", "dataset": "016c1f72-4604-48a1-b1b1-394239297e29", "max_samples": 10, - "model_url": "hf://facebook/bart-large-cnn", + "base_url": None, "system_prompt": "Hello Lumigator", "config_template": str, } @@ -428,7 +422,7 @@ def create_job_config() -> JobConfig: def simple_eval_template(): return """{{ "name": "{job_name}/{job_id}", - "model": {{ "path": "{model_uri}" }}, + "model": {{ "path": "{model_name_or_path}" }}, "dataset": {{ "path": "{dataset_path}" }}, "evaluation": {{ "metrics": ["meteor", "rouge"], @@ -446,7 +440,7 @@ def simple_infer_template(): "name": "{job_name}/{job_id}", "dataset": {{ "path": "{dataset_path}" }}, "hf_pipeline": {{ - "model_uri": "{model_uri}", + "model_name_or_path": "{model_name_or_path}", "task": "{task}", "accelerator": "{accelerator}", "revision": "{revision}", diff --git a/lumigator/backend/backend/tests/data/health_job_metadata.json b/lumigator/backend/backend/tests/data/health_job_metadata.json index e62140a9f..30e34814b 100644 --- a/lumigator/backend/backend/tests/data/health_job_metadata.json +++ b/lumigator/backend/backend/tests/data/health_job_metadata.json @@ -4,7 +4,7 @@ "submission_id": "e899341d-bada-4f3c-ae32-b87bf730f897", "driver_info": null, "status": "RUNNING", - "entrypoint": "python inference.py --config '{\n \"name\": \"test_run_hugging_face/e899341d-bada-4f3c-ae32-b87bf730f897\",\n \"model\": { \"path\": \"hf://facebook/bart-large-cnn\" },\n \"dataset\": { \"path\": \"s3://lumigator-storage/datasets/c404aa33-4c4c-4a59-845e-01e10ad22226/thunderbird_gt_bart.csv\" },\n \"evaluation\": {\n \"metrics\": [\"rouge\", \"meteor\", \"bertscore\"],\n \"use_pipeline\": true,\n \"max_samples\": 10,\n \"return_input_data\": true,\n \"return_predictions\": true,\n \"storage_path\": \"s3://lumigator-storage/jobs/results/\"\n }\n}'", + "entrypoint": "python inference.py --config '{\n \"name\": \"test_run_hugging_face/e899341d-bada-4f3c-ae32-b87bf730f897\",\n \"model\": { \"path\": \"facebook/bart-large-cnn\" },\n \"provider\": { \"path\": \"hf\" },\n \"dataset\": { \"path\": \"s3://lumigator-storage/datasets/c404aa33-4c4c-4a59-845e-01e10ad22226/thunderbird_gt_bart.csv\" },\n \"evaluation\": {\n \"metrics\": [\"rouge\", \"meteor\", \"bertscore\"],\n \"use_pipeline\": true,\n \"max_samples\": 10,\n \"return_input_data\": true,\n \"return_predictions\": true,\n \"storage_path\": \"s3://lumigator-storage/jobs/results/\"\n }\n}'", "message": "Job is currently running.", "error_type": null, "start_time": "2024-11-07T17:04:28.650000Z", diff --git a/lumigator/backend/backend/tests/data/health_job_metadata_ray.json b/lumigator/backend/backend/tests/data/health_job_metadata_ray.json index 2b1f2fbdc..be2b600e1 100644 --- 
a/lumigator/backend/backend/tests/data/health_job_metadata_ray.json +++ b/lumigator/backend/backend/tests/data/health_job_metadata_ray.json @@ -4,7 +4,7 @@ "submission_id": "e899341d-bada-4f3c-ae32-b87bf730f897", "driver_info": null, "status": "RUNNING", - "entrypoint": "python inference.py --config '{\n \"name\": \"test_run_hugging_face/e899341d-bada-4f3c-ae32-b87bf730f897\",\n \"model\": { \"path\": \"hf://facebook/bart-large-cnn\" },\n \"dataset\": { \"path\": \"s3://lumigator-storage/datasets/c404aa33-4c4c-4a59-845e-01e10ad22226/thunderbird_gt_bart.csv\" },\n \"evaluation\": {\n \"metrics\": [\"rouge\", \"meteor\", \"bertscore\"],\n \"use_pipeline\": true,\n \"max_samples\": 10,\n \"return_input_data\": true,\n \"return_predictions\": true,\n \"storage_path\": \"s3://lumigator-storage/jobs/results/\"\n }\n}'", + "entrypoint": "python inference.py --config '{\n \"name\": \"test_run_hugging_face/e899341d-bada-4f3c-ae32-b87bf730f897\",\n \"model\": { \"path\": \"facebook/bart-large-cnn\" },\n \"provider\": { \"path\": \"hf\" },\n \"dataset\": { \"path\": \"s3://lumigator-storage/datasets/c404aa33-4c4c-4a59-845e-01e10ad22226/thunderbird_gt_bart.csv\" },\n \"evaluation\": {\n \"metrics\": [\"rouge\", \"meteor\", \"bertscore\"],\n \"use_pipeline\": true,\n \"max_samples\": 10,\n \"return_input_data\": true,\n \"return_predictions\": true,\n \"storage_path\": \"s3://lumigator-storage/jobs/results/\"\n }\n}'", "message": "Job is currently running.", "error_type": null, "start_time": 1730999068650, diff --git a/lumigator/backend/backend/tests/data/models.json b/lumigator/backend/backend/tests/data/models.json deleted file mode 100644 index 0949d7c27..000000000 --- a/lumigator/backend/backend/tests/data/models.json +++ /dev/null @@ -1,103 +0,0 @@ -{ - "total": 6, - "items": [ - { - "name": "facebook/bart-large-cnn", - "uri": "hf://facebook/bart-large-cnn", - "description": "BART is a large-sized model fine-tuned on the CNN Daily Mail dataset.", - "requirements": [], - "info": { - "parameter_count": "406M", - "tensor_type": "F32", - "model_size": "1.63GB" - }, - "tasks": [ - { - "summarization": { - "max_length": 142, - "min_length": 56, - "length_penalty": 2, - "early_stopping": true, - "no_repeat_ngram_size": 3, - "num_beams": 4 - } - } - ] - }, - { - "name": "Falconsai/text_summarization", - "uri": "hf://Falconsai/text_summarization", - "description": "A fine-tuned variant of the T5 transformer model, designed for the task of text summarization.", - "requirements": [], - "info": { - "parameter_count": "60.5M", - "tensor_type": "F32", - "model_size": "242MB" - }, - "tasks": [ - { - "summarization": { - "max_length": 200, - "min_length": 30, - "length_penalty": 2, - "early_stopping": true, - "no_repeat_ngram_size": 3, - "num_beams": 4 - } - } - ] - }, - { - "name": "gpt-4o-mini", - "uri": "oai://gpt-4o-mini", - "description": "OpenAI's GPT-4o-mini model.", - "requirements": [ "api_key" ], - "info": null, - "tasks": [ - { - "summarization": null - } - ] - }, - { - "name": "gpt-4o", - "uri": "oai://gpt-4o", - "description": "OpenAI's GPT-4o model.", - "requirements": [ "api_key" ], - "info": null, - "tasks": [ - { - "summarization": null - } - ] - }, - { - "name": "open-mistral-7b", - "uri": "mistral://open-mistral-7b", - "description": "Mistral's 7B model.", - "requirements": [ "api_key" ], - "info": null, - "tasks": [ - { - "summarization": null - } - ] - }, - { - "name": "mistralai/Mistral-7B-Instruct-v0.2", - "uri": "llamafile://mistralai/Mistral-7B-Instruct-v0.2", - "description": "A 
llamafile package of Mistral's 7B Instruct model.", - "requirements": [ "llamafile" ], - "info": { - "parameter_count": "7.24B", - "tensor_type": "BF16", - "model_size": "14.5GB" - }, - "tasks": [ - { - "summarization": null - } - ] - } - ] - } diff --git a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py index 6e5a52dbd..0c9b58fb3 100644 --- a/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py +++ b/lumigator/backend/backend/tests/integration/api/routes/test_api_workflows.py @@ -16,7 +16,7 @@ JobStatus, JobType, ) -from lumigator_schemas.workflows import WorkflowDetailsResponse, WorkflowResponse +from lumigator_schemas.workflows import WorkflowDetailsResponse, WorkflowResponse, WorkflowStatus from backend.main import app from backend.tests.conftest import ( @@ -73,6 +73,7 @@ def test_upload_data_launch_job( "job_config": { "job_type": JobType.INFERENCE, "model": TEST_CAUSAL_MODEL, + "provider": "hf", "output_field": "predictions", "store_to_dataset": True, }, @@ -106,6 +107,7 @@ def test_upload_data_launch_job( "job_type": JobType.EVALUATION, "metrics": ["rouge", "meteor"], "model": TEST_CAUSAL_MODEL, + "provider": "hf", }, } @@ -250,6 +252,7 @@ def run_workflow(local_client: TestClient, dataset_id, experiment_id, workflow_n "name": workflow_name, "description": "Test workflow for inf and eval", "model": TEST_CAUSAL_MODEL, + "provider": "hf", "dataset": str(dataset_id), "experiment_id": experiment_id, "max_samples": 1, @@ -362,15 +365,14 @@ def test_job_non_existing(local_client: TestClient, dependency_overrides_service def wait_for_workflow_complete(local_client: TestClient, workflow_id: UUID): - workflow_status = JobStatus.PENDING for _ in range(1, 300): time.sleep(1) workflow_details = WorkflowDetailsResponse.model_validate(local_client.get(f"/workflows/{workflow_id}").json()) workflow_status = workflow_details.status - if workflow_status in [JobStatus.SUCCEEDED, JobStatus.FAILED]: + if workflow_status in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]: logger.info(f"Workflow status: {workflow_status}") break - if workflow_status not in [JobStatus.SUCCEEDED, JobStatus.FAILED]: + if workflow_status not in [WorkflowStatus.SUCCEEDED, WorkflowStatus.FAILED]: raise Exception(f"Stopped, job remains in {workflow_status} status") return workflow_details diff --git a/lumigator/backend/backend/tests/unit/api/routes/test_models.py b/lumigator/backend/backend/tests/unit/api/routes/test_models.py index 43e153ee0..17370b583 100644 --- a/lumigator/backend/backend/tests/unit/api/routes/test_models.py +++ b/lumigator/backend/backend/tests/unit/api/routes/test_models.py @@ -1,31 +1,34 @@ import json from pathlib import Path +import yaml from fastapi.testclient import TestClient from lumigator_schemas.extras import ListingResponse from lumigator_schemas.models import ModelsResponse from backend.api.routes.models import _get_supported_tasks +MODELS_PATH = Path(__file__).resolve().parents[4] / "models.yaml" -def test_get_suggested_models_summarization_ok(app_client: TestClient, json_data_models: Path): + +def test_get_suggested_models_summarization_ok(app_client: TestClient): response = app_client.get("/models/summarization") assert response.status_code == 200 models = ListingResponse[ModelsResponse].model_validate(response.json()) - with Path(json_data_models).open() as file: - data = json.load(file) + with Path(MODELS_PATH).open() as file: + data = yaml.safe_load(file) - 
assert models.total == data["total"] + assert models.total == len(data) -def test_get_suggested_models_invalid_task(app_client: TestClient, json_data_models: Path): +def test_get_suggested_models_invalid_task(app_client: TestClient): response = app_client.get("/models/invalid_task") assert response.status_code == 400 - with Path(json_data_models).open() as file: - data = json.load(file) + with Path(MODELS_PATH).open() as file: + data = yaml.safe_load(file) - supported_tasks = _get_supported_tasks(data.get("items", [])) + supported_tasks = _get_supported_tasks(data) assert response.json() == {"detail": f"Unsupported task. Choose from: {supported_tasks}"} diff --git a/lumigator/backend/backend/tests/unit/services/test_job_service.py b/lumigator/backend/backend/tests/unit/services/test_job_service.py index adca68809..2218c489b 100644 --- a/lumigator/backend/backend/tests/unit/services/test_job_service.py +++ b/lumigator/backend/backend/tests/unit/services/test_job_service.py @@ -16,7 +16,7 @@ def test_set_null_inference_job_params(job_record, job_service): request = JobCreate( name="test_run_hugging_face", description="Test run for Huggingface model", - job_config=JobInferenceConfig(job_type=JobType.INFERENCE, model="hf://facebook/bart-large-cnn"), + job_config=JobInferenceConfig(job_type=JobType.INFERENCE, model="facebook/bart-large-cnn", provider="hf"), dataset="cced289c-f869-4af1-9195-1d58e32d1cc1", ) @@ -37,7 +37,7 @@ def test_set_explicit_inference_job_params(job_record, job_service): name="test_run_hugging_face", description="Test run for Huggingface model", max_samples=10, - job_config=JobInferenceConfig(job_type=JobType.INFERENCE, model="hf://facebook/bart-large-cnn"), + job_config=JobInferenceConfig(job_type=JobType.INFERENCE, model="facebook/bart-large-cnn", provider="hf"), dataset="cced289c-f869-4af1-9195-1d58e32d1cc1", ) @@ -54,43 +54,46 @@ def test_set_explicit_inference_job_params(job_record, job_service): @pytest.mark.parametrize( - ["model", "input_model_url", "returned_model_url"], + ["model", "provider", "input_base_url", "returned_base_url"], [ # generic HF model loaded locally - ("hf://facebook/bart-large-cnn", None, None), - # vLLM served model (with HF model name specified to be passed as "engine") + ("facebook/bart-large-cnn", "hf", None, None), + # vLLM served model (with HF model name specified to be passed as "model") ( - "hf://mistralai/Mistral-7B-Instruct-v0.3", + "mistralai/Mistral-7B-Instruct-v0.3", + "hf", "http://localhost:8000/v1/chat/completions", "http://localhost:8000/v1/chat/completions", ), # llamafile served model (with custom model name) ( - "llamafile://mistralai/Mistral-7B-Instruct-v0.2", + "mistralai/Mistral-7B-Instruct-v0.2", + "openai", "http://localhost:8000/v1/chat/completions", "http://localhost:8000/v1/chat/completions", ), # openai model (from API) - ("oai://gpt-4-turbo", None, settings.OAI_API_URL), + ("gpt-4-turbo", "openai", "https://api.openai.com/v1", settings.OAI_API_URL), # mistral model (from API) - ("mistral://open-mistral-7b", None, settings.MISTRAL_API_URL), + ("open-mistral-7b", "mistral", "https://api.mistral.ai/v1", settings.MISTRAL_API_URL), # deepseek model (from API) - ("ds://deepseek-chat", None, settings.DEEPSEEK_API_URL), + ("deepseek-chat", "deepseek", "https://api.deepseek.com/v1", settings.DEEPSEEK_API_URL), ], ) -def test_set_model(job_service, model, input_model_url, returned_model_url): +def test_set_model(job_service, model, provider, input_base_url, returned_base_url): request = JobCreate( name="test_run", 
description="Test run to verify how model URL is set", job_config=JobInferenceConfig( job_type=JobType.INFERENCE, model=model, - model_url=input_model_url, + provider=provider, + base_url=input_base_url, ), dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d", ) - model_url = job_service._set_model_type(request) - assert model_url == returned_model_url + base_url = request.job_config.base_url + assert base_url == returned_base_url def test_invalid_text_generation(job_service): @@ -98,7 +101,7 @@ def test_invalid_text_generation(job_service): name="test_text_generation_run", description="Test run to verify that system prompt is set.", job_config=JobInferenceConfig( - job_type=JobType.INFERENCE, model="hf://microsoft/Phi-3.5-mini-instruct", task="text-generation" + job_type=JobType.INFERENCE, model="microsoft/Phi-3.5-mini-instruct", provider="hf", task="text-generation" ), dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d", ) diff --git a/lumigator/backend/backend/tracking/mlflow.py b/lumigator/backend/backend/tracking/mlflow.py index bb87e84d5..e8b951629 100644 --- a/lumigator/backend/backend/tracking/mlflow.py +++ b/lumigator/backend/backend/tracking/mlflow.py @@ -195,7 +195,7 @@ def create_workflow(self, experiment_id: str, description: str, name: str, model experiment_id=experiment_id, tags={ "mlflow.runName": name, - "status": WorkflowStatus.CREATED, + "status": WorkflowStatus.CREATED.value, "description": description, "model": model, }, @@ -232,7 +232,7 @@ def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse: description=workflow.data.tags.get("description"), name=workflow.data.tags.get("mlflow.runName"), model=workflow.data.tags.get("model"), - status=WorkflowStatus(WorkflowStatus[workflow.data.tags.get("status").split(".")[1]]), + status=WorkflowStatus(workflow.data.tags.get("status")), created_at=datetime.fromtimestamp(workflow.info.start_time / 1000), jobs=[self.get_job(job_id) for job_id in all_job_ids], metrics=self._compile_metrics(all_job_ids), @@ -284,7 +284,7 @@ def get_workflow(self, workflow_id: str) -> WorkflowDetailsResponse: def update_workflow_status(self, workflow_id: str, status: WorkflowStatus) -> None: """Update the status of a workflow.""" - self._client.set_tag(workflow_id, "status", status) + self._client.set_tag(workflow_id, "status", status.value) def _get_ray_job_logs(self, ray_job_id: str): """Get the logs for a Ray job.""" @@ -352,7 +352,7 @@ def delete_workflow(self, workflow_id: str) -> WorkflowResponse: name=workflow.data.tags.get("mlflow.runName"), description=workflow.data.tags.get("description"), model=workflow.data.tags.get("model"), - status=WorkflowStatus(WorkflowStatus[workflow.data.tags.get("status").split(".")[1]]), + status=WorkflowStatus(workflow.data.tags.get("status")), created_at=datetime.fromtimestamp(workflow.info.start_time / 1000), ) diff --git a/lumigator/backend/pyproject.toml b/lumigator/backend/pyproject.toml index 7a925f4d4..7c0bf4fc8 100644 --- a/lumigator/backend/pyproject.toml +++ b/lumigator/backend/pyproject.toml @@ -9,9 +9,7 @@ dependencies = [ "boto3==1.34.105", "boto3-stubs[essential,s3]==1.34.105", "loguru==0.7.2", - "mistralai==0.4.2", "mypy-boto3==1.34.105", - "openai==1.38.0", "pydantic>=2.10.0", "pydantic-settings==2.2.1", "requests>=2,<3", diff --git a/lumigator/backend/uv.lock b/lumigator/backend/uv.lock index 52aa475c5..156c3d79f 100644 --- a/lumigator/backend/uv.lock +++ b/lumigator/backend/uv.lock @@ -172,10 +172,8 @@ dependencies = [ { name = "fastapi", extra = ["standard"] }, { name = "loguru" }, { name = 
"lumigator-schemas" }, - { name = "mistralai" }, { name = "mlflow" }, { name = "mypy-boto3" }, - { name = "openai" }, { name = "psycopg2-binary" }, { name = "pydantic" }, { name = "pydantic-settings" }, @@ -204,10 +202,8 @@ requires-dist = [ { name = "fastapi", extras = ["standard"], specifier = ">=0.115.0" }, { name = "loguru", specifier = "==0.7.2" }, { name = "lumigator-schemas", editable = "../schemas" }, - { name = "mistralai", specifier = "==0.4.2" }, { name = "mlflow", specifier = ">=2.20.0" }, { name = "mypy-boto3", specifier = "==1.34.105" }, - { name = "openai", specifier = "==1.38.0" }, { name = "psycopg2-binary", specifier = "==2.9.9" }, { name = "pydantic", specifier = ">=2.10.0" }, { name = "pydantic-settings", specifier = "==2.2.1" }, @@ -621,15 +617,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 }, ] -[[package]] -name = "distro" -version = "1.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fc/f8/98eea607f65de6527f8a2e8885fc8015d3e6f5775df186e443e0964a11c3/distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed", size = 60722 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, -] - [[package]] name = "dnspython" version = "2.7.0" @@ -1401,20 +1388,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979 }, ] -[[package]] -name = "mistralai" -version = "0.4.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "httpx" }, - { name = "orjson" }, - { name = "pydantic" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/fa/20/4204f461588310b3a7ffbbbb7fa573493dc1c8185d376ee72516c04575bf/mistralai-0.4.2.tar.gz", hash = "sha256:5eb656710517168ae053f9847b0bb7f617eda07f1f93f946ad6c91a4d407fd93", size = 14234 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/fe/79dad76b8d94b62d9e2aab8446183190e1dc384c617d06c3c93307850e11/mistralai-0.4.2-py3-none-any.whl", hash = "sha256:63c98eea139585f0a3b2c4c6c09c453738bac3958055e6f2362d3866e96b0168", size = 20334 }, -] - [[package]] name = "mlflow" version = "2.20.1" @@ -1748,24 +1721,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/6f/129e3c17e3befe7fefdeaa6890f4c4df3f3cf0831aa053802c3862da67aa/numpy-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ef444c57d664d35cac4e18c298c47d7b504c66b17c2ea91312e979fcfbdfb08a", size = 14066202 }, ] -[[package]] -name = "openai" -version = "1.38.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "tqdm" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/04/c2/66d2fc6e7dd0c8d1d95a2256090c51b0c4b93b0e15dd0205fdecba02281c/openai-1.38.0.tar.gz", hash = 
"sha256:30fb324bf452ecb1194ca7dbc64566a4d7aa054c6a5da857937ede7d517a220b", size = 257665 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f5/25/8ef12125750ffb3a6e8390a718f5fe307c70ca35c96ad9b715ac7e41089a/openai-1.38.0-py3-none-any.whl", hash = "sha256:a19ef052f1676320f52183ae6f9775da6d888fbe3aec57886117163c095d9f7c", size = 335921 }, -] - [[package]] name = "opentelemetry-api" version = "1.29.0" @@ -1806,41 +1761,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/fb/dc15fad105450a015e913cfa4f5c27b6a5f1bea8fb649f8cae11e699c8af/opentelemetry_semantic_conventions-0.50b0-py3-none-any.whl", hash = "sha256:e87efba8fdb67fb38113efea6a349531e75ed7ffc01562f65b802fcecb5e115e", size = 166602 }, ] -[[package]] -name = "orjson" -version = "3.10.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/80/44/d36e86b33fc84f224b5f2cdf525adf3b8f9f475753e721c402b1ddef731e/orjson-3.10.10.tar.gz", hash = "sha256:37949383c4df7b4337ce82ee35b6d7471e55195efa7dcb45ab8226ceadb0fe3b", size = 5404170 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/bc/2a0eb0029729f1e466d5a595261446e5c5b6ed9213759ee56b6202f99417/orjson-3.10.10-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:879e99486c0fbb256266c7c6a67ff84f46035e4f8749ac6317cc83dacd7f993a", size = 270717 }, - { url = "https://files.pythonhosted.org/packages/3d/2b/5af226f183ce264bf64f15afe58647b09263dc1bde06aaadae6bbeca17f1/orjson-3.10.10-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:019481fa9ea5ff13b5d5d95e6fd5ab25ded0810c80b150c2c7b1cc8660b662a7", size = 153294 }, - { url = "https://files.pythonhosted.org/packages/1d/95/d6a68ab51ed76e3794669dabb51bf7fa6ec2f4745f66e4af4518aeab4b73/orjson-3.10.10-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:0dd57eff09894938b4c86d4b871a479260f9e156fa7f12f8cad4b39ea8028bb5", size = 168628 }, - { url = "https://files.pythonhosted.org/packages/c0/c9/1bbe5262f5e9df3e1aeec44ca8cc86846c7afb2746fa76bf668a7d0979e9/orjson-3.10.10-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dbde6d70cd95ab4d11ea8ac5e738e30764e510fc54d777336eec09bb93b8576c", size = 155845 }, - { url = "https://files.pythonhosted.org/packages/bf/22/e17b14ff74646e6c080dccb2859686a820bc6468f6b62ea3fe29a8bd3b05/orjson-3.10.10-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b2625cb37b8fb42e2147404e5ff7ef08712099197a9cd38895006d7053e69d6", size = 166406 }, - { url = "https://files.pythonhosted.org/packages/8a/1e/b3abbe352f648f96a418acd1e602b1c77ffcc60cf801a57033da990b2c49/orjson-3.10.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dbf3c20c6a7db69df58672a0d5815647ecf78c8e62a4d9bd284e8621c1fe5ccb", size = 144518 }, - { url = "https://files.pythonhosted.org/packages/0e/5e/28f521ee0950d279489db1522e7a2460d0596df7c5ca452e242ff1509cfe/orjson-3.10.10-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:75c38f5647e02d423807d252ce4528bf6a95bd776af999cb1fb48867ed01d1f6", size = 172187 }, - { url = "https://files.pythonhosted.org/packages/04/b4/538bf6f42eb0fd5a485abbe61e488d401a23fd6d6a758daefcf7811b6807/orjson-3.10.10-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:23458d31fa50ec18e0ec4b0b4343730928296b11111df5f547c75913714116b2", size = 170152 }, - { url = 
"https://files.pythonhosted.org/packages/94/5c/a1a326a58452f9261972ad326ae3bb46d7945681239b7062a1b85d8811e2/orjson-3.10.10-cp311-none-win32.whl", hash = "sha256:2787cd9dedc591c989f3facd7e3e86508eafdc9536a26ec277699c0aa63c685b", size = 145116 }, - { url = "https://files.pythonhosted.org/packages/df/12/a02965df75f5a247091306d6cf40a77d20bf6c0490d0a5cb8719551ee815/orjson-3.10.10-cp311-none-win_amd64.whl", hash = "sha256:6514449d2c202a75183f807bc755167713297c69f1db57a89a1ef4a0170ee269", size = 139307 }, - { url = "https://files.pythonhosted.org/packages/21/c6/f1d2ec3ffe9d6a23a62af0477cd11dd2926762e0186a1fad8658a4f48117/orjson-3.10.10-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:8564f48f3620861f5ef1e080ce7cd122ee89d7d6dacf25fcae675ff63b4d6e05", size = 270801 }, - { url = "https://files.pythonhosted.org/packages/52/01/eba0226efaa4d4be8e44d9685750428503a3803648878fa5607100a74f81/orjson-3.10.10-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5bf161a32b479034098c5b81f2608f09167ad2fa1c06abd4e527ea6bf4837a9", size = 153221 }, - { url = "https://files.pythonhosted.org/packages/da/4b/a705f9d3ae4786955ee0ac840b20960add357e612f1b0a54883d1811fe1a/orjson-3.10.10-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:68b65c93617bcafa7f04b74ae8bc2cc214bd5cb45168a953256ff83015c6747d", size = 168590 }, - { url = "https://files.pythonhosted.org/packages/de/6c/eb405252e7d9ae9905a12bad582cfe37ef8ef18fdfee941549cb5834c7b2/orjson-3.10.10-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8e28406f97fc2ea0c6150f4c1b6e8261453318930b334abc419214c82314f85", size = 156052 }, - { url = "https://files.pythonhosted.org/packages/9f/e7/65a0461574078a38f204575153524876350f0865162faa6e6e300ecaa199/orjson-3.10.10-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4d0d9fe174cc7a5bdce2e6c378bcdb4c49b2bf522a8f996aa586020e1b96cee", size = 166562 }, - { url = "https://files.pythonhosted.org/packages/dd/99/85780be173e7014428859ba0211e6f2a8f8038ea6ebabe344b42d5daa277/orjson-3.10.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3be81c42f1242cbed03cbb3973501fcaa2675a0af638f8be494eaf37143d999", size = 144892 }, - { url = "https://files.pythonhosted.org/packages/ed/c0/c7c42a2daeb262da417f70064746b700786ee0811b9a5821d9d37543b29d/orjson-3.10.10-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:65f9886d3bae65be026219c0a5f32dbbe91a9e6272f56d092ab22561ad0ea33b", size = 172093 }, - { url = "https://files.pythonhosted.org/packages/ad/9b/be8b3d3aec42aa47f6058482ace0d2ca3023477a46643d766e96281d5d31/orjson-3.10.10-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:730ed5350147db7beb23ddaf072f490329e90a1d059711d364b49fe352ec987b", size = 170424 }, - { url = "https://files.pythonhosted.org/packages/1b/15/a4cc61e23c39b9dec4620cb95817c83c84078be1771d602f6d03f0e5c696/orjson-3.10.10-cp312-none-win32.whl", hash = "sha256:a8f4bf5f1c85bea2170800020d53a8877812892697f9c2de73d576c9307a8a5f", size = 145132 }, - { url = "https://files.pythonhosted.org/packages/9f/8a/ce7c28e4ea337f6d95261345d7c61322f8561c52f57b263a3ad7025984f4/orjson-3.10.10-cp312-none-win_amd64.whl", hash = "sha256:384cd13579a1b4cd689d218e329f459eb9ddc504fa48c5a83ef4889db7fd7a4f", size = 139389 }, - { url = "https://files.pythonhosted.org/packages/0c/69/f1c4382cd44bdaf10006c4e82cb85d2bcae735369f84031e203c4e5d87de/orjson-3.10.10-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", 
hash = "sha256:44bffae68c291f94ff5a9b4149fe9d1bdd4cd0ff0fb575bcea8351d48db629a1", size = 270695 }, - { url = "https://files.pythonhosted.org/packages/61/29/aeb5153271d4953872b06ed239eb54993a5f344353727c42d3aabb2046f6/orjson-3.10.10-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e27b4c6437315df3024f0835887127dac2a0a3ff643500ec27088d2588fa5ae1", size = 141632 }, - { url = "https://files.pythonhosted.org/packages/bc/a2/c8ac38d8fb461a9b717c766fbe1f7d3acf9bde2f12488eb13194960782e4/orjson-3.10.10-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bca84df16d6b49325a4084fd8b2fe2229cb415e15c46c529f868c3387bb1339d", size = 144854 }, - { url = "https://files.pythonhosted.org/packages/79/51/e7698fdb28bdec633888cc667edc29fd5376fce9ade0a5b3e22f5ebe0343/orjson-3.10.10-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c14ce70e8f39bd71f9f80423801b5d10bf93d1dceffdecd04df0f64d2c69bc01", size = 172023 }, - { url = "https://files.pythonhosted.org/packages/02/2d/0d99c20878658c7e33b90e6a4bb75cf2924d6ff29c2365262cff3c26589a/orjson-3.10.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:24ac62336da9bda1bd93c0491eff0613003b48d3cb5d01470842e7b52a40d5b4", size = 170429 }, - { url = "https://files.pythonhosted.org/packages/cd/45/6a4a446f4fb29bb4703c3537d5c6a2bf7fed768cb4d7b7dce9d71b72fc93/orjson-3.10.10-cp313-none-win32.whl", hash = "sha256:eb0a42831372ec2b05acc9ee45af77bcaccbd91257345f93780a8e654efc75db", size = 145099 }, - { url = "https://files.pythonhosted.org/packages/72/6e/4631fe219a4203aa111e9bb763ad2e2e0cdd1a03805029e4da124d96863f/orjson-3.10.10-cp313-none-win_amd64.whl", hash = "sha256:f0c4f37f8bf3f1075c6cc8dd8a9f843689a4b618628f8812d0a71e6968b95ffd", size = 139176 }, -] - [[package]] name = "packaging" version = "24.1" diff --git a/lumigator/frontend/src/components/experiments/LExperimentResults.vue b/lumigator/frontend/src/components/experiments/LExperimentResults.vue index 6e6f2bec9..3011a61a9 100644 --- a/lumigator/frontend/src/components/experiments/LExperimentResults.vue +++ b/lumigator/frontend/src/components/experiments/LExperimentResults.vue @@ -14,7 +14,7 @@ Model @@ -189,7 +189,7 @@ const tooltips = ref({ const tableData = computed(() => { const data = selectedExperimentResults.value.map((results) => ({ ...results, - model: models.value.find((model: Model) => model.name === results.model), + model: models.value.find((model: Model) => model.display_name === results.model), })) return data diff --git a/lumigator/frontend/src/components/experiments/LModelCards.vue b/lumigator/frontend/src/components/experiments/LModelCards.vue index 80561e1c9..9046025c4 100644 --- a/lumigator/frontend/src/components/experiments/LModelCards.vue +++ b/lumigator/frontend/src/components/experiments/LModelCards.vue @@ -10,7 +10,7 @@
@@ -21,7 +21,7 @@ name="model" @click.stop /> - +
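
Taken together, the test changes above show the new request shape: instead of encoding the backend in a prefixed `model_url` (e.g. `hf://...`), callers now pass a bare `model` plus a `provider`, with an optional `base_url` for self-hosted endpoints. A minimal sketch of both variants follows; the `JobCreate` wrapper, the import path, and the local endpoint values are illustrative assumptions, not taken verbatim from this patch.

```python
# Sketch only: field names follow the test diff above; the import path,
# the local model name, and the Llamafile URL/port are assumptions.
from lumigator_schemas.jobs import JobCreate, JobInferenceConfig, JobType

# Hub-hosted model: the provider replaces the old "hf://..." style prefix on model_url.
hf_request = JobCreate(
    name="phi_inference",
    description="Inference with a Hugging Face model",
    job_config=JobInferenceConfig(
        job_type=JobType.INFERENCE,
        model="microsoft/Phi-3.5-mini-instruct",  # bare model name, no scheme prefix
        provider="hf",
    ),
    dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
)

# Self-hosted model (e.g. a Llamafile server): base_url points at the OpenAI-compatible endpoint.
local_request = JobCreate(
    name="llamafile_inference",
    description="Inference against a local endpoint",
    job_config=JobInferenceConfig(
        job_type=JobType.INFERENCE,
        model="mistralai/Mistral-7B-Instruct-v0.2",   # assumed model name
        provider="openai",                            # assumed: OpenAI-compatible protocol
        base_url="http://localhost:8080/v1",          # assumed default Llamafile port
    ),
    dataset="d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
)
```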
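
The `mlflow.py` hunks above fix how workflow status survives the MLflow tag round-trip: the tag now stores the enum's value, so reading it back is a plain lookup by value instead of the old `split(".")` parse of a stringified member. A small sketch, using a stand-in enum (the real `WorkflowStatus` lives in `lumigator_schemas`; its member values here are assumptions):

```python
from enum import Enum


class WorkflowStatus(Enum):
    # Stand-in for the real enum; member values are assumptions.
    CREATED = "created"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"


# New approach: store the value, read it back by value -- a lossless round-trip.
tag = WorkflowStatus.CREATED.value                      # "created"
assert WorkflowStatus(tag) is WorkflowStatus.CREATED

# Old approach stored a stringified member (e.g. "WorkflowStatus.CREATED"),
# which is why the removed code needed split(".")[1] plus a name-based lookup.
legacy_tag = "WorkflowStatus.CREATED"
assert WorkflowStatus[legacy_tag.split(".")[1]] is WorkflowStatus.CREATED
```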