Skip to content

Commit

Permalink
feat: get all models not paginated (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jun 25, 2024
1 parent d81d86b commit 06bd2ef
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 75 deletions.
8 changes: 7 additions & 1 deletion api/app/db/dao/model_dao.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import re
from typing import Optional
from typing import List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
Expand Down Expand Up @@ -44,6 +44,12 @@ def delete(self, uuid: UUID) -> int:

def get_all(
self,
) -> List[Model]:
with self.db.begin_session() as session:
return session.query(Model).where(Model.deleted.is_(False))

def get_all_paginated(
self,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
Expand Down
28 changes: 17 additions & 11 deletions api/app/routes/model_route.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Optional
from typing import Annotated, List, Optional
from uuid import UUID

from fastapi import APIRouter
Expand All @@ -19,6 +19,22 @@ class ModelRoute:
def get_router(model_service: ModelService) -> APIRouter:
router = APIRouter(tags=['model_api'])

@router.get('', status_code=200, response_model=Page[ModelOut])
def get_all_models_paginated(
_page: Annotated[int, Query()] = 1,
_limit: Annotated[int, Query()] = 50,
_order: Annotated[OrderType, Query()] = OrderType.ASC,
_sort: Annotated[Optional[str], Query()] = None,
):
params = Params(page=_page, size=_limit)
return model_service.get_all_models_paginated(
params=params, order=_order, sort=_sort
)

@router.get('/all', status_code=200, response_model=List[ModelOut])
def get_all_models():
return model_service.get_all_models()

@router.post('', status_code=201, response_model=ModelOut)
def create_model(model_in: ModelIn):
model = model_service.create_model(model_in)
Expand All @@ -35,14 +51,4 @@ def delete_model(model_uuid: UUID):
logger.info('Model %s with name %s deleted.', model.uuid, model.name)
return model

@router.get('', status_code=200, response_model=Page[ModelOut])
def get_all_models(
_page: Annotated[int, Query()] = 1,
_limit: Annotated[int, Query()] = 50,
_order: Annotated[OrderType, Query()] = OrderType.ASC,
_sort: Annotated[Optional[str], Query()] = None,
):
params = Params(page=_page, size=_limit)
return model_service.get_all_models(params=params, order=_order, sort=_sort)

return router
10 changes: 8 additions & 2 deletions api/app/services/model_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
Expand Down Expand Up @@ -35,11 +35,17 @@ def delete_model(self, model_uuid: UUID) -> Optional[ModelOut]:

def get_all_models(
self,
) -> List[ModelOut]:
models = self.model_dao.get_all()
return [ModelOut.from_model(model) for model in models]

def get_all_models_paginated(
self,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
) -> Page[ModelOut]:
models: Page[Model] = self.model_dao.get_all(
models: Page[Model] = self.model_dao.get_all_paginated(
params=params, order=order, sort=sort
)
_items = [ModelOut.from_model(model) for model in models.items]
Expand Down
8 changes: 4 additions & 4 deletions api/tests/dao/model_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,24 @@ def test_delete(self):
assert rows == 1
assert retrieved is None

def test_get_all(self):
def test_get_all_paginated(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='model1')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='model2')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3')
self.model_dao.insert(model1)
self.model_dao.insert(model2)
self.model_dao.insert(model3)
models = self.model_dao.get_all()
models = self.model_dao.get_all_paginated()
assert models.items[0].uuid == model1.uuid
assert len(models.items) == 3

def test_get_all_ordered(self):
def test_get_all_paginated_ordered(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='first_model')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='second_model')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='third_model')
self.model_dao.insert(model1)
self.model_dao.insert(model2)
self.model_dao.insert(model3)
models = self.model_dao.get_all(order=OrderType.DESC, sort='name')
models = self.model_dao.get_all_paginated(order=OrderType.DESC, sort='name')
assert models.items[0].name == model3.name
assert len(models.items) == 3
6 changes: 3 additions & 3 deletions api/tests/routes/model_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_delete_model(self):
assert jsonable_encoder(model_out) == res.json()
self.model_service.delete_model.assert_called_once_with(model.uuid)

def test_get_all_models(self):
def test_get_all_models_paginated(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='model1')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='model2')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3')
Expand All @@ -83,12 +83,12 @@ def test_get_all_models(self):
page = Page.create(
items=sample_models_out, total=len(sample_models_out), params=Params()
)
self.model_service.get_all_models = MagicMock(return_value=page)
self.model_service.get_all_models_paginated = MagicMock(return_value=page)

res = self.client.get(f'{self.prefix}')
assert res.status_code == 200
assert jsonable_encoder(page) == res.json()
self.model_service.get_all_models.assert_called_once_with(
self.model_service.get_all_models_paginated.assert_called_once_with(
params=Params(page=1, size=50), order=OrderType.ASC, sort=None
)

Expand Down
8 changes: 4 additions & 4 deletions api/tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_delete_model_ok(self):

assert res == ModelOut.from_model(model)

def test_get_all_models_ok(self):
def test_get_all_models_paginated_ok(self):
model1 = db_mock.get_sample_model(id=1, uuid=uuid.uuid4(), name='model1')
model2 = db_mock.get_sample_model(id=2, uuid=uuid.uuid4(), name='model2')
model3 = db_mock.get_sample_model(id=3, uuid=uuid.uuid4(), name='model3')
Expand All @@ -65,13 +65,13 @@ def test_get_all_models_ok(self):
order=OrderType.ASC,
sort=None,
)
self.model_dao.get_all = MagicMock(return_value=page)
self.model_dao.get_all_paginated = MagicMock(return_value=page)

result = self.model_service.get_all_models(
result = self.model_service.get_all_models_paginated(
params=Params(page=1, size=10), order=OrderType.ASC, sort=None
)

self.model_dao.get_all.assert_called_once_with(
self.model_dao.get_all_paginated.assert_called_once_with(
params=Params(page=1, size=10), order=OrderType.ASC, sort=None
)

Expand Down
19 changes: 6 additions & 13 deletions sdk/radicalbit_platform_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from typing import List
from uuid import UUID

from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
import requests

from radicalbit_platform_sdk.apis import Model
from radicalbit_platform_sdk.commons import invoke
from radicalbit_platform_sdk.errors import ClientError
from radicalbit_platform_sdk.models import (
CreateModel,
ModelDefinition,
PaginatedModelDefinitions,
)
from radicalbit_platform_sdk.models import CreateModel, ModelDefinition


class Client:
Expand Down Expand Up @@ -52,18 +48,15 @@ def __callback(response: requests.Response) -> Model:
def search_models(self) -> List[Model]:
def __callback(response: requests.Response) -> List[Model]:
try:
paginated_response = PaginatedModelDefinitions.model_validate(
response.json()
)
return [
Model(self.__base_url, model) for model in paginated_response.items
]
adapter = TypeAdapter(List[ModelDefinition])
model_definitions = adapter.validate_python(response.json())
return [Model(self.__base_url, model) for model in model_definitions]
except ValidationError as e:
raise ClientError(f'Unable to parse response: {response.text}') from e

return invoke(
method='GET',
url=f'{self.__base_url}/api/models',
url=f'{self.__base_url}/api/models/all',
valid_response_code=200,
func=__callback,
)
2 changes: 0 additions & 2 deletions sdk/radicalbit_platform_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
Granularity,
ModelDefinition,
OutputType,
PaginatedModelDefinitions,
)
from .model_type import ModelType

Expand Down Expand Up @@ -77,7 +76,6 @@
'BinaryClassDrift',
'MultiClassDrift',
'RegressionDrift',
'PaginatedModelDefinitions',
'ReferenceFileUpload',
'CurrentFileUpload',
'FileReference',
Expand Down
6 changes: 0 additions & 6 deletions sdk/radicalbit_platform_sdk/models/model_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,3 @@ class ModelDefinition(BaseModelDefinition):
updated_at: str = Field(alias='updatedAt')

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class PaginatedModelDefinitions(BaseModel):
items: List[ModelDefinition]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
53 changes: 24 additions & 29 deletions sdk/tests/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ModelDefinition,
ModelType,
OutputType,
PaginatedModelDefinitions,
)


Expand Down Expand Up @@ -180,45 +179,41 @@ def test_create_model(self):
def test_search_models(self):
base_url = 'http://api:9000'

paginated_response = PaginatedModelDefinitions(
items=[
ModelDefinition(
name='My Model',
model_type=ModelType.BINARY,
data_type=DataType.TABULAR,
granularity=Granularity.DAY,
features=[ColumnDefinition(name='feature_column', type='string')],
outputs=OutputType(
prediction=ColumnDefinition(name='result_column', type='int'),
output=[ColumnDefinition(name='result_column', type='int')],
),
target=ColumnDefinition(name='target_column', type='string'),
timestamp=ColumnDefinition(name='tst_column', type='string'),
created_at=str(time.time()),
updated_at=str(time.time()),
)
]
model_definition = ModelDefinition(
name='My Model',
model_type=ModelType.BINARY,
data_type=DataType.TABULAR,
granularity=Granularity.DAY,
features=[ColumnDefinition(name='feature_column', type='string')],
outputs=OutputType(
prediction=ColumnDefinition(name='result_column', type='int'),
output=[ColumnDefinition(name='result_column', type='int')],
),
target=ColumnDefinition(name='target_column', type='string'),
timestamp=ColumnDefinition(name='tst_column', type='string'),
created_at=str(time.time()),
updated_at=str(time.time()),
)

responses.add(
method=responses.GET,
url=f'{base_url}/api/models',
body=paginated_response.model_dump_json(),
url=f'{base_url}/api/models/all',
body=f'[{model_definition.model_dump_json()}]',
status=200,
content_type='application/json',
)

client = Client(base_url)
models = client.search_models()
assert len(models) == 1
assert models[0].name() == paginated_response.items[0].name
assert models[0].model_type() == paginated_response.items[0].model_type
assert models[0].data_type() == paginated_response.items[0].data_type
assert models[0].granularity() == paginated_response.items[0].granularity
assert models[0].features() == paginated_response.items[0].features
assert models[0].outputs() == paginated_response.items[0].outputs
assert models[0].target() == paginated_response.items[0].target
assert models[0].timestamp() == paginated_response.items[0].timestamp
assert models[0].name() == model_definition.name
assert models[0].model_type() == model_definition.model_type
assert models[0].data_type() == model_definition.data_type
assert models[0].granularity() == model_definition.granularity
assert models[0].features() == model_definition.features
assert models[0].outputs() == model_definition.outputs
assert models[0].target() == model_definition.target
assert models[0].timestamp() == model_definition.timestamp
assert models[0].description() is None
assert models[0].algorithm() is None
assert models[0].frameworks() is None

0 comments on commit 06bd2ef

Please sign in to comment.