Skip to content

Commit

Permalink
feat(api): get all reference/current datasets not paginated (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
maocorte authored Jun 25, 2024
1 parent 06bd2ef commit 6967044
Show file tree
Hide file tree
Showing 8 changed files with 197 additions and 31 deletions.
13 changes: 12 additions & 1 deletion api/app/db/dao/current_dataset_dao.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Optional
from typing import List, Optional
from uuid import UUID

from fastapi_pagination import Page, Params
Expand Down Expand Up @@ -50,6 +50,17 @@ def get_latest_current_dataset_by_model_uuid(
def get_all_current_datasets_by_model_uuid(
self,
model_uuid: UUID,
) -> List[CurrentDataset]:
with self.db.begin_session() as session:
return (
session.query(CurrentDataset)
.order_by(desc(CurrentDataset.date))
.where(CurrentDataset.model_uuid == model_uuid)
)

def get_all_current_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
Expand Down
11 changes: 11 additions & 0 deletions api/app/db/dao/reference_dataset_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ def get_reference_dataset_by_model_uuid(
def get_all_reference_datasets_by_model_uuid(
self,
model_uuid: UUID,
) -> Page[ReferenceDataset]:
with self.db.begin_session() as session:
return (
session.query(ReferenceDataset)
.order_by(desc(ReferenceDataset.date))
.where(ReferenceDataset.model_uuid == model_uuid)
)

def get_all_reference_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
Expand Down
30 changes: 25 additions & 5 deletions api/app/routes/upload_dataset_route.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Annotated, Optional
from typing import Annotated, List, Optional
from uuid import UUID

from fastapi import APIRouter, File, Form, UploadFile, status
Expand Down Expand Up @@ -69,33 +69,53 @@ def bind_current_file(
status_code=200,
response_model=Page[ReferenceDatasetDTO],
)
def get_all_reference_datasets_by_model_uuid(
def get_all_reference_datasets_by_model_uuid_paginated(
model_uuid: UUID,
_page: Annotated[int, Query()] = 1,
_limit: Annotated[int, Query()] = 50,
_order: Annotated[OrderType, Query()] = OrderType.ASC,
_sort: Annotated[Optional[str], Query()] = None,
):
params = Params(page=_page, size=_limit)
return file_service.get_all_reference_datasets_by_model_uuid(
return file_service.get_all_reference_datasets_by_model_uuid_paginated(
model_uuid, params=params, order=_order, sort=_sort
)

@router.get(
'/{model_uuid}/reference/all',
status_code=200,
response_model=List[ReferenceDatasetDTO],
)
def get_all_reference_datasets_by_model_uuid(
model_uuid: UUID,
):
return file_service.get_all_reference_datasets_by_model_uuid(model_uuid)

@router.get(
'/{model_uuid}/current',
status_code=200,
response_model=Page[CurrentDatasetDTO],
)
def get_all_current_datasets_by_model_uuid(
def get_all_current_datasets_by_model_uuid_paginated(
model_uuid: UUID,
_page: Annotated[int, Query()] = 1,
_limit: Annotated[int, Query()] = 50,
_order: Annotated[OrderType, Query()] = OrderType.ASC,
_sort: Annotated[Optional[str], Query()] = None,
):
params = Params(page=_page, size=_limit)
return file_service.get_all_current_datasets_by_model_uuid(
return file_service.get_all_current_datasets_by_model_uuid_paginated(
model_uuid, params=params, order=_order, sort=_sort
)

@router.get(
'/{model_uuid}/current/all',
status_code=200,
response_model=List[CurrentDatasetDTO],
)
def get_all_current_datasets_by_model_uuid(
model_uuid: UUID,
):
return file_service.get_all_current_datasets_by_model_uuid(model_uuid)

return router
29 changes: 25 additions & 4 deletions api/app/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,15 +343,15 @@ def bind_current_file(
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

def get_all_reference_datasets_by_model_uuid(
def get_all_reference_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
) -> Page[ReferenceDatasetDTO]:
results: Page[ReferenceDatasetDTO] = (
self.rd_dao.get_all_reference_datasets_by_model_uuid(
self.rd_dao.get_all_reference_datasets_by_model_uuid_paginated(
model_uuid, params=params, order=order, sort=sort
)
)
Expand All @@ -363,15 +363,26 @@ def get_all_reference_datasets_by_model_uuid(

return Page.create(items=_items, params=params, total=results.total)

def get_all_current_datasets_by_model_uuid(
def get_all_reference_datasets_by_model_uuid(
self,
model_uuid: UUID,
) -> List[ReferenceDatasetDTO]:
references = self.rd_dao.get_all_reference_datasets_by_model_uuid(model_uuid)

return [
ReferenceDatasetDTO.from_reference_dataset(reference)
for reference in references
]

def get_all_current_datasets_by_model_uuid_paginated(
self,
model_uuid: UUID,
params: Params = Params(),
order: OrderType = OrderType.ASC,
sort: Optional[str] = None,
) -> Page[CurrentDatasetDTO]:
results: Page[CurrentDatasetDTO] = (
self.cd_dao.get_all_current_datasets_by_model_uuid(
self.cd_dao.get_all_current_datasets_by_model_uuid_paginated(
model_uuid, params=params, order=order, sort=sort
)
)
Expand All @@ -383,6 +394,16 @@ def get_all_current_datasets_by_model_uuid(

return Page.create(items=_items, params=params, total=results.total)

def get_all_current_datasets_by_model_uuid(
self,
model_uuid: UUID,
) -> List[CurrentDatasetDTO]:
currents = self.cd_dao.get_all_current_datasets_by_model_uuid(model_uuid)
return [
CurrentDatasetDTO.from_current_dataset(current_dataset)
for current_dataset in currents
]

@staticmethod
def infer_schema(csv_file: UploadFile, sep: str = ',') -> InferredSchemaDTO:
FileService.validate_file(csv_file, sep)
Expand Down
8 changes: 5 additions & 3 deletions api/tests/dao/current_dataset_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_get_latest_current_dataset_by_model_uuid(self):
assert inserted_two.model_uuid == retrieved.model_uuid
assert inserted_two.path == retrieved.path

def test_get_all_current_datasets_by_model_uuid(self):
def test_get_all_current_datasets_by_model_uuid_paginated(self):
model = self.model_dao.insert(db_mock.get_sample_model())
current_upload_1 = CurrentDataset(
uuid=uuid4(),
Expand All @@ -104,8 +104,10 @@ def test_get_all_current_datasets_by_model_uuid(self):
inserted_2 = self.current_dataset_dao.insert_current_dataset(current_upload_2)
inserted_3 = self.current_dataset_dao.insert_current_dataset(current_upload_3)

retrieved = self.current_dataset_dao.get_all_current_datasets_by_model_uuid(
model.uuid, Params(page=1, size=10)
retrieved = (
self.current_dataset_dao.get_all_current_datasets_by_model_uuid_paginated(
model.uuid, Params(page=1, size=10)
)
)

assert inserted_1.uuid == retrieved.items[0].uuid
Expand Down
9 changes: 4 additions & 5 deletions api/tests/dao/reference_dataset_dao_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_get_reference_dataset_by_model_uuid(self):
assert inserted.model_uuid == retrieved.model_uuid
assert inserted.path == retrieved.path

def test_get_all_reference_datasets_by_model_uuid(self):
def test_get_all_reference_datasets_by_model_uuid_paginated(self):
model = self.model_dao.insert(db_mock.get_sample_model())
reference_upload_1 = ReferenceDataset(
uuid=uuid4(),
Expand Down Expand Up @@ -76,10 +76,8 @@ def test_get_all_reference_datasets_by_model_uuid(self):
reference_upload_3
)

retrieved = (
self.f_reference_dataset_dao.get_all_reference_datasets_by_model_uuid(
model.uuid, Params(page=1, size=10)
)
retrieved = self.f_reference_dataset_dao.get_all_reference_datasets_by_model_uuid_paginated(
model.uuid, Params(page=1, size=10)
)

assert inserted_1.uuid == retrieved.items[0].uuid
Expand All @@ -95,3 +93,4 @@ def test_get_all_reference_datasets_by_model_uuid(self):
assert inserted_3.path == retrieved.items[2].path

assert len(retrieved.items) == 3

70 changes: 63 additions & 7 deletions api/tests/routes/upload_dataset_route_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_bind_reference(self):
assert res.status_code == 200
assert jsonable_encoder(upload_file_result) == res.json()

def test_get_all_reference_datasets_by_model_uuid(self):
def test_get_all_reference_datasets_by_model_uuid_paginated(self):
test_model_uuid = uuid.uuid4()
reference_upload_1 = db_mock.get_sample_reference_dataset(
model_uuid=test_model_uuid, path='reference/test_1.csv'
Expand All @@ -93,21 +93,21 @@ def test_get_all_reference_datasets_by_model_uuid(self):
page = Page.create(
items=sample_results, total=len(sample_results), params=Params()
)
self.file_service.get_all_reference_datasets_by_model_uuid = MagicMock(
return_value=page
self.file_service.get_all_reference_datasets_by_model_uuid_paginated = (
MagicMock(return_value=page)
)

res = self.client.get(f'{self.prefix}/{test_model_uuid}/reference')
assert res.status_code == 200
assert jsonable_encoder(page) == res.json()
self.file_service.get_all_reference_datasets_by_model_uuid.assert_called_once_with(
self.file_service.get_all_reference_datasets_by_model_uuid_paginated.assert_called_once_with(
test_model_uuid,
params=Params(page=1, size=50),
order=OrderType.ASC,
sort=None,
)

def test_get_all_current_datasets_by_model_uuid(self):
def test_get_all_current_datasets_by_model_uuid_paginated(self):
test_model_uuid = uuid.uuid4()
current_upload_1 = db_mock.get_sample_current_dataset(
model_uuid=test_model_uuid, path='reference/test_1.csv'
Expand All @@ -127,16 +127,72 @@ def test_get_all_current_datasets_by_model_uuid(self):
page = Page.create(
items=sample_results, total=len(sample_results), params=Params()
)
self.file_service.get_all_current_datasets_by_model_uuid = MagicMock(
self.file_service.get_all_current_datasets_by_model_uuid_paginated = MagicMock(
return_value=page
)

res = self.client.get(f'{self.prefix}/{test_model_uuid}/current')
assert res.status_code == 200
assert jsonable_encoder(page) == res.json()
self.file_service.get_all_current_datasets_by_model_uuid.assert_called_once_with(
self.file_service.get_all_current_datasets_by_model_uuid_paginated.assert_called_once_with(
test_model_uuid,
params=Params(page=1, size=50),
order=OrderType.ASC,
sort=None,
)

def test_get_all_reference_datasets_by_model_uuid(self):
test_model_uuid = uuid.uuid4()
reference_upload_1 = db_mock.get_sample_reference_dataset(
model_uuid=test_model_uuid, path='reference/test_1.csv'
)
reference_upload_2 = db_mock.get_sample_reference_dataset(
model_uuid=test_model_uuid, path='reference/test_2.csv'
)
reference_upload_3 = db_mock.get_sample_reference_dataset(
model_uuid=test_model_uuid, path='reference/test_3.csv'
)

sample_results = [
ReferenceDatasetDTO.from_reference_dataset(reference_upload_1),
ReferenceDatasetDTO.from_reference_dataset(reference_upload_2),
ReferenceDatasetDTO.from_reference_dataset(reference_upload_3),
]
self.file_service.get_all_reference_datasets_by_model_uuid = MagicMock(
return_value=sample_results
)

res = self.client.get(f'{self.prefix}/{test_model_uuid}/reference/all')
assert res.status_code == 200
assert jsonable_encoder(sample_results) == res.json()
self.file_service.get_all_reference_datasets_by_model_uuid.assert_called_once_with(
test_model_uuid,
)

def test_get_all_current_datasets_by_model_uuid(self):
test_model_uuid = uuid.uuid4()
current_upload_1 = db_mock.get_sample_current_dataset(
model_uuid=test_model_uuid, path='reference/test_1.csv'
)
current_upload_2 = db_mock.get_sample_current_dataset(
model_uuid=test_model_uuid, path='reference/test_2.csv'
)
current_upload_3 = db_mock.get_sample_current_dataset(
model_uuid=test_model_uuid, path='reference/test_3.csv'
)

sample_results = [
CurrentDatasetDTO.from_current_dataset(current_upload_1),
CurrentDatasetDTO.from_current_dataset(current_upload_2),
CurrentDatasetDTO.from_current_dataset(current_upload_3),
]
self.file_service.get_all_current_datasets_by_model_uuid = MagicMock(
return_value=sample_results
)

res = self.client.get(f'{self.prefix}/{test_model_uuid}/current/all')
assert res.status_code == 200
assert jsonable_encoder(sample_results) == res.json()
self.file_service.get_all_current_datasets_by_model_uuid.assert_called_once_with(
test_model_uuid,
)
Loading

0 comments on commit 6967044

Please sign in to comment.