Skip to content

Commit

Permalink
Add keys infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
javiermtorres committed Feb 20, 2025
1 parent 50a9fc7 commit 1dcc0f8
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Add keys table
Revision ID: 54e483b63d8a
Revises: 9b5d04b45a85
Create Date: 2025-02-20 17:03:30.394254
"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '54e483b63d8a' # pragma: allowlist secret
down_revision: Union[str, None] = '9b5d04b45a85' # pragma: allowlist secret
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
op.create_table(
"keys",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("key_name", sa.String(), nullable=False, unique=True),
sa.Column("key_value", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
)


def downgrade() -> None:
op.drop_table("keys")
12 changes: 12 additions & 0 deletions lumigator/backend/backend/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from backend.db import session_manager
from backend.repositories.datasets import DatasetRepository
from backend.repositories.jobs import JobRepository, JobResultRepository
from backend.repositories.keys import KeyRepository
from backend.services.datasets import DatasetService
from backend.services.experiments import ExperimentService
from backend.services.jobs import JobService
from backend.services.workflows import WorkflowService
from backend.services.keys import KeyService
from backend.settings import settings
from backend.tracking import TrackingClientManager, tracking_client_manager

Expand Down Expand Up @@ -96,3 +98,13 @@ def get_workflow_service(


WorkflowServiceDep = Annotated[WorkflowService, Depends(get_workflow_service)]


def get_key_service(
session: DBSessionDep,
) -> KeyService:
key_repo = KeyRepository(session)
return KeyService(key_repo)


KeyServiceDep = Annotated[KeyService, Depends(get_key_service)]
2 changes: 2 additions & 0 deletions lumigator/backend/backend/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from backend.api.routes import datasets, experiments, health, jobs, models, workflows
from backend.api.tags import Tags
from backend.api.routes import keys

API_V1_PREFIX = "/api/v1"

Expand All @@ -12,3 +13,4 @@
api_router.include_router(experiments.router, prefix="/experiments", tags=[Tags.EXPERIMENTS])
api_router.include_router(models.router, prefix="/models", tags=[Tags.MODELS])
api_router.include_router(workflows.router, prefix="/workflows", tags=[Tags.WORKFLOWS])
api_router.include_router(keys.router, prefix="/keys", tags=[Tags.KEYS])
43 changes: 43 additions & 0 deletions lumigator/backend/backend/api/routes/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from http import HTTPStatus

from fastapi import APIRouter, status
from lumigator_schemas.keys import Key
from starlette.requests import Request
from starlette.responses import Response

from loguru import logger

from backend.api.deps import KeyServiceDep
from backend.api.http_headers import HttpHeaders
from backend.services.exceptions.base_exceptions import ServiceError
from backend.settings import settings

router = APIRouter()

def key_exception_mappings() -> dict[type[ServiceError], HTTPStatus]:
return {
}



@router.put(
"/{key_name}",
status_code=status.HTTP_200_OK,
responses={
status.HTTP_201_CREATED: {"description": "Dataset successfully uploaded"},
},
)
def upload_key(
service: KeyServiceDep,
key: Key,
key_name: str,
request: Request,
response: Response,
) -> None:
"""Uploads a key for use in Lumigator.
Lumigator uses different keys for purposes such as external API calls.
The user can upload new values for these keys, but they cannot retrieve
them.
"""
service.upload_key(key, key_name)
5 changes: 5 additions & 0 deletions lumigator/backend/backend/api/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class Tags(str, Enum):
EXPERIMENTS = "experiments"
WORKFLOWS = "workflows"
MODELS = "models"
KEYS= "keys"


TAGS_METADATA = [
Expand Down Expand Up @@ -35,6 +36,10 @@ class Tags(str, Enum):
"name": Tags.MODELS,
"description": "Return a list of suggested models for a given task.",
},
{
"name": Tags.KEYS,
"description": "Put keys for different services.",
},
]
"""Metadata to associate with route tags in the OpenAPI documentation.
Expand Down
2 changes: 2 additions & 0 deletions lumigator/backend/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from backend.api.routes.experiments import experiment_exception_mappings
from backend.api.routes.jobs import job_exception_mappings
from backend.api.routes.workflows import workflow_exception_mappings
from backend.api.routes.keys import key_exception_mappings
from backend.api.tags import TAGS_METADATA
from backend.services.exceptions.base_exceptions import ServiceError
from backend.settings import settings
Expand Down Expand Up @@ -97,6 +98,7 @@ def create_app() -> FastAPI:
experiment_exception_mappings(),
job_exception_mappings(),
workflow_exception_mappings(),
key_exception_mappings(),
]

# Add a handler for each error -> status mapping.
Expand Down
11 changes: 11 additions & 0 deletions lumigator/backend/backend/records/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sqlalchemy.orm import Mapped, mapped_column

from backend.records.base import BaseRecord
from backend.records.mixins import CreatedAtMixin


class KeyRecord(BaseRecord, CreatedAtMixin):
__tablename__ = "keys"
key_name: Mapped[str] = mapped_column(unique=True)
key_value: Mapped[str]
description: Mapped[str]
16 changes: 16 additions & 0 deletions lumigator/backend/backend/repositories/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from uuid import UUID

from sqlalchemy.orm import Session

from backend.records.keys import KeyRecord
from backend.repositories.base import BaseRepository


class KeyRepository(BaseRepository[KeyRecord]):
def __init__(self, session: Session):
super().__init__(KeyRecord, session)

def get_by_key_name(self, key_name: str) -> KeyRecord | None:
return self.session.query(KeyRecord).where(KeyRecord.key_name == key_name).all().first()


30 changes: 30 additions & 0 deletions lumigator/backend/backend/services/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import csv
from pathlib import Path
from uuid import UUID

from datasets import load_dataset
from fastapi import UploadFile
from loguru import logger
from mypy_boto3_s3.client import S3Client
from pydantic import ByteSize
from s3fs import S3FileSystem

from backend.repositories.keys import KeyRepository
from backend.settings import settings
from lumigator_schemas.keys import Key

class KeyService:
def __init__(
self, key_repo: KeyRepository
):
self._key_repo = key_repo

def upload_key(
self,
key: Key,
key_name: str
) -> None:
"""Uploads a key under a certain name
"""
self._key_repo.create(**key.model_dump(), key_name=key_name)

12 changes: 12 additions & 0 deletions lumigator/backend/backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from backend.records.jobs import JobRecord
from backend.repositories.datasets import DatasetRepository
from backend.repositories.jobs import JobRepository, JobResultRepository
from backend.repositories.keys import KeyRepository
from backend.services.datasets import DatasetService
from backend.services.jobs import JobService
from backend.services.keys import KeyService
from backend.settings import BackendSettings, settings
from backend.tests.fakes.fake_s3 import FakeS3Client

Expand Down Expand Up @@ -379,6 +381,16 @@ def dataset_service(db_session, fake_s3_client, fake_s3fs):
return DatasetService(dataset_repo=dataset_repo, s3_client=fake_s3_client, s3_filesystem=fake_s3fs)


@pytest.fixture(scope="function")
def key_repository(db_session):
return KeyRepository(db_session)


@pytest.fixture(scope="function")
def key_service(db_session, key_repository):
return KeyService(key_repo=key_repository)


@pytest.fixture(scope="function")
def job_record(db_session):
return JobRecord
Expand Down
21 changes: 21 additions & 0 deletions lumigator/backend/backend/tests/unit/api/routes/test_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import json
from pathlib import Path

from fastapi import status
from fastapi.testclient import TestClient
from lumigator_schemas.keys import Key

from backend.repositories.keys import KeyRepository
from backend.services.keys import KeyService

def test_put_key(app_client: TestClient, key_service: KeyService, key_repository: KeyRepository, dependency_overrides_fakes):
new_key = Key(key_value="123456", description="test key")
new_key_name = 'TEST_AI_CLIENT_KEY'
assert key_repository.list() == []
response = app_client.put(f'/keys/{new_key_name}', json=new_key.model_dump())
assert response.status_code == status.HTTP_200_OK
assert key_repository.list() != []
db_key = key_repository.list()[0]
assert db_key.key_name == new_key_name
assert db_key.key_value == new_key.key_value
assert db_key.description == new_key.description
5 changes: 5 additions & 0 deletions lumigator/schemas/lumigator_schemas/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel

class Key(BaseModel):
key_value: str
description: str = ""

0 comments on commit 1dcc0f8

Please sign in to comment.