Skip to content

Commit

Permalink
feat(api): add update model features API (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jul 26, 2024
1 parent a19cd20 commit 80875c2
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 6 deletions.
11 changes: 10 additions & 1 deletion api/app/db/dao/model_dao.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import re
from typing import List, Optional
from typing import Dict, List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
Expand Down Expand Up @@ -32,6 +32,15 @@ def get_by_uuid(self, uuid: UUID) -> Optional[Model]:
.one_or_none()
)

def update_features(self, uuid: UUID, model_features: List[Dict]):
with self.db.begin_session() as session:
query = (
sqlalchemy.update(Model)
.where(Model.uuid == uuid)
.values(features=model_features)
)
return session.execute(query).rowcount

def delete(self, uuid: UUID) -> int:
with self.db.begin_session() as session:
deleted_at = datetime.datetime.now(tz=datetime.UTC)
Expand Down
8 changes: 8 additions & 0 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ def to_dict(self):
return self.model_dump()


class ModelFeatures(BaseModel):
features: List[ColumnDefinition]

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)


class ModelIn(BaseModel, validate_assignment=True):
name: str
description: Optional[str] = None
Expand Down
12 changes: 10 additions & 2 deletions api/app/routes/model_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from typing import Annotated, List, Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import APIRouter, Response
from fastapi.params import Query
from fastapi_pagination import Page, Params

from app.core import get_config
from app.models.model_dto import ModelIn, ModelOut
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut
from app.models.model_order import OrderType
from app.services.model_service import ModelService

Expand Down Expand Up @@ -51,4 +51,12 @@ def delete_model(model_uuid: UUID):
logger.info('Model %s with name %s deleted.', model.uuid, model.name)
return model

@router.post('/{model_uuid}', status_code=200)
def update_model_features_by_uuid(
model_uuid: UUID, model_features: ModelFeatures
):
if model_service.update_model_features_by_uuid(model_uuid, model_features):
return Response(status_code=200)
return Response(status_code=404)

return router
21 changes: 19 additions & 2 deletions api/app/services/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from app.db.tables.current_dataset_table import CurrentDataset
from app.db.tables.model_table import Model
from app.db.tables.reference_dataset_table import ReferenceDataset
from app.models.exceptions import ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelIn, ModelOut
from app.models.exceptions import ModelError, ModelInternalError, ModelNotFoundError
from app.models.model_dto import ModelFeatures, ModelIn, ModelOut
from app.models.model_order import OrderType


Expand Down Expand Up @@ -46,6 +46,23 @@ def get_model_by_uuid(self, model_uuid: UUID) -> Optional[ModelOut]:
latest_current_dataset=latest_current_dataset,
)

def update_model_features_by_uuid(
self, model_uuid: UUID, model_features: ModelFeatures
) -> bool:
last_reference = self.rd_dao.get_latest_reference_dataset_by_model_uuid(
model_uuid
)
if last_reference is not None:
raise ModelError(
'Model already has a reference dataset, could not be updated', 400
) from None
return (
self.model_dao.update_features(
model_uuid, [feature.to_dict() for feature in model_features.features]
)
> 0
)

def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]:
model = self.check_and_get_model(model_uuid)
self.model_dao.delete(model_uuid)
Expand Down
13 changes: 13 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DataType,
FieldType,
Granularity,
ModelFeatures,
ModelIn,
ModelType,
OutputType,
Expand Down Expand Up @@ -72,6 +73,18 @@ def get_sample_model(
)


def get_sample_model_features(
features: List[ColumnDefinition] = [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
):
return ModelFeatures(features=features)


def get_sample_model_in(
name: str = 'model_name',
description: Optional[str] = None,
Expand Down
12 changes: 12 additions & 0 deletions api/tests/dao/model_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ def test_get_by_uuid_empty(self):
retrieved = self.model_dao.get_by_uuid(uuid.uuid4())
assert retrieved is None

def test_update(self):
model = db_mock.get_sample_model()
self.model_dao.insert(model)
new_features = [
{'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'},
{'name': 'feature2', 'type': 'int', 'fieldType': 'numerical'},
]
updated_rows = self.model_dao.update_features(model.uuid, new_features)
retrieved = self.model_dao.get_by_uuid(model.uuid)
assert updated_rows == 1
assert retrieved.features == new_features

def test_delete(self):
model = db_mock.get_sample_model()
self.model_dao.insert(model)
Expand Down
28 changes: 28 additions & 0 deletions api/tests/routes/model_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,34 @@ def test_create_model(self):
assert jsonable_encoder(model_out) == res.json()
self.model_service.create_model.assert_called_once_with(model_in)

def test_update_model_ok(self):
model_features = db_mock.get_sample_model_features()
self.model_service.update_model_features_by_uuid = MagicMock(return_value=True)

res = self.client.post(
f'{self.prefix}/{db_mock.MODEL_UUID}',
json=jsonable_encoder(model_features),
)

assert res.status_code == 200
self.model_service.update_model_features_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_features
)

def test_update_model_ko(self):
model_features = db_mock.get_sample_model_features()
self.model_service.update_model_features_by_uuid = MagicMock(return_value=False)

res = self.client.post(
f'{self.prefix}/{db_mock.MODEL_UUID}',
json=jsonable_encoder(model_features),
)

assert res.status_code == 404
self.model_service.update_model_features_by_uuid.assert_called_once_with(
db_mock.MODEL_UUID, model_features
)

def test_get_model_by_uuid(self):
model = db_mock.get_sample_model()
model_out = ModelOut.from_model(model)
Expand Down
52 changes: 51 additions & 1 deletion api/tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.db.dao.current_dataset_dao import CurrentDatasetDAO
from app.db.dao.model_dao import ModelDAO
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.models.exceptions import ModelNotFoundError
from app.models.exceptions import ModelError, ModelNotFoundError
from app.models.model_dto import ModelOut
from app.models.model_order import OrderType
from app.services.model_service import ModelService
Expand Down Expand Up @@ -66,6 +66,56 @@ def test_get_model_by_uuid_not_found(self):
)
self.model_dao.get_by_uuid.assert_called_once()

def test_update_model_ok(self):
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.model_dao.update_features = MagicMock(return_value=1)
res = self.model_service.update_model_features_by_uuid(
model_uuid, model_features
)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict)

assert res is True

def test_update_model_ko(self):
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=None
)
self.model_dao.update_features = MagicMock(return_value=0)
res = self.model_service.update_model_features_by_uuid(
model_uuid, model_features
)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(model_uuid, feature_dict)

assert res is False

def test_update_model_freezed(self):
model_features = db_mock.get_sample_model_features()
self.rd_dao.get_latest_reference_dataset_by_model_uuid = MagicMock(
return_value=db_mock.get_sample_reference_dataset()
)
self.model_dao.update_features = MagicMock(return_value=0)
with pytest.raises(ModelError):
self.model_service.update_model_features_by_uuid(model_uuid, model_features)
feature_dict = [feature.to_dict() for feature in model_features.features]
self.rd_dao.get_latest_reference_dataset_by_model_uuid.assert_called_once_with(
model_uuid
)
self.model_dao.update_features.assert_called_once_with(
model_uuid, feature_dict
)

def test_delete_model_ok(self):
model = db_mock.get_sample_model()
self.model_dao.get_by_uuid = MagicMock(return_value=model)
Expand Down

0 comments on commit 80875c2

Please sign in to comment.