diff --git a/.gitignore b/.gitignore index f15b188..ef36a6e 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,7 @@ cython_debug/ #.idea/ # Development databases -*.sqlite3 \ No newline at end of file +*.sqlite3 + +# Development storage +storage/ \ No newline at end of file diff --git a/nad_ch/application_context.py b/nad_ch/application_context.py index ede8691..7da70f5 100644 --- a/nad_ch/application_context.py +++ b/nad_ch/application_context.py @@ -1,40 +1,67 @@ import os import logging +from nad_ch.config import STORAGE_PATH from nad_ch.infrastructure.database import ( session_scope, - SqlAlchemyDataProviderRepository + SqlAlchemyDataProviderRepository, + SqlAlchemyDataSubmissionRepository ) from nad_ch.infrastructure.logger import Logger -from tests.mocks import MockDataProviderRepository +from nad_ch.infrastructure.storage import LocalStorage +from tests.fakes import ( + FakeDataProviderRepository, + FakeDataSubmissionRepository, + FakeStorage +) class ApplicationContext: def __init__(self): self._providers = SqlAlchemyDataProviderRepository(session_scope) + self._submissions = SqlAlchemyDataSubmissionRepository(session_scope) self._logger = Logger(__name__) + self._storage = LocalStorage(STORAGE_PATH) @property def providers(self): return self._providers + @property + def submissions(self): + return self._submissions + @property def logger(self): return self._logger + @property + def storage(self): + return self._storage + class TestApplicationContext(ApplicationContext): def __init__(self): - self._providers = MockDataProviderRepository() + self._providers = FakeDataProviderRepository() + self._submissions = FakeDataSubmissionRepository() self._logger = Logger(__name__, logging.DEBUG) + self._storage = FakeStorage() @property def providers(self): return self._providers + @property + def submissions(self): + return self._submissions + @property def logger(self): return self._logger + @property + def storage(self): + return self._storage + def create_app_context(): if os.environ.get('APP_ENV') == 'test': diff --git a/nad_ch/config.py b/nad_ch/config.py index c61f8f5..47a2b43 100644 --- a/nad_ch/config.py +++ b/nad_ch/config.py @@ -7,3 +7,4 @@ APP_ENV = os.getenv('APP_ENV') DATABASE_URL = os.getenv('DATABASE_URL') +STORAGE_PATH = os.getenv('STORAGE_PATH') diff --git a/nad_ch/controllers/cli.py b/nad_ch/controllers/cli.py index 00398de..4450c11 100644 --- a/nad_ch/controllers/cli.py +++ b/nad_ch/controllers/cli.py @@ -3,6 +3,7 @@ add_data_provider, list_data_providers, ingest_data_submission, + list_data_submissions_by_provider ) @@ -29,8 +30,16 @@ def list_providers(ctx): @cli.command() @click.pass_context -@click.argument('filepath') +@click.argument('file_path') @click.argument('provider') def ingest(ctx, file_path, provider): context = ctx.obj ingest_data_submission(context, file_path, provider) + + +@cli.command() +@click.pass_context +@click.argument('provider') +def list_submissions_by_provider(ctx, provider): + context = ctx.obj + list_data_submissions_by_provider(context, provider) diff --git a/nad_ch/domain/entities.py b/nad_ch/domain/entities.py index 1f3d65e..4b9c903 100644 --- a/nad_ch/domain/entities.py +++ b/nad_ch/domain/entities.py @@ -1,11 +1,40 @@ -class DataProvider: - def __init__(self, name: str, id: int = None): +class Entity: + def __init__(self, id: int = None): self.id = id + self.created_at = None + self.updated_at = None + + def set_created_at(self, created_at): + self.created_at = created_at + + def set_updated_at(self, updated_at): + self.updated_at = updated_at + + +class DataProvider(Entity): + def __init__(self, name: str, id: int = None): + super().__init__(id) self.name = name + def __repr__(self): + return f'DataProvider {self.id}, {self.name} \ + (created: {self.created_at}; updated: {self.updated_at})' -class DataSubmission: - def __init__(self, file_path: str, provider: DataProvider, id: int = None): - self.id = id - self.file_path = file_path + +class DataSubmission(Entity): + def __init__( + self, + file_name: str, + url: str, + provider: DataProvider, + id: int = None, + ): + super().__init__(id) + self.file_name = file_name + self.url = url self.provider = provider + + def __repr__(self): + return f'DataSubmission \ + {self.id}, {self.file_name}, {self.url}, {self.provider} \ + (created: {self.created_at}; updated: {self.updated_at})' diff --git a/nad_ch/domain/repositories.py b/nad_ch/domain/repositories.py index cf9c1fa..d463fcd 100644 --- a/nad_ch/domain/repositories.py +++ b/nad_ch/domain/repositories.py @@ -1,10 +1,10 @@ from typing import Protocol from collections.abc import Iterable -from nad_ch.domain.entities import DataProvider +from nad_ch.domain.entities import DataProvider, DataSubmission class DataProviderRepository(Protocol): - def add(self, provider: DataProvider) -> None: + def add(self, provider: DataProvider) -> DataProvider: ... def get_by_name(self, name: str) -> DataProvider: @@ -12,3 +12,11 @@ def get_by_name(self, name: str) -> DataProvider: def get_all(self) -> Iterable[DataProvider]: ... + + +class DataSubmissionRepository(Protocol): + def add(self, submission: DataSubmission) -> DataSubmission: + ... + + def get_by_provider(self, provider: DataProvider) -> Iterable[DataSubmission]: + ... diff --git a/nad_ch/infrastructure/database.py b/nad_ch/infrastructure/database.py index 2f34234..39148d4 100644 --- a/nad_ch/infrastructure/database.py +++ b/nad_ch/infrastructure/database.py @@ -1,10 +1,11 @@ from typing import List, Optional -from sqlalchemy import Column, Integer, String, create_engine -from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy import Column, Integer, String, create_engine, ForeignKey, DateTime +from sqlalchemy.orm import sessionmaker, declarative_base, relationship +from sqlalchemy.sql import func import contextlib from nad_ch.config import DATABASE_URL -from nad_ch.domain.entities import DataProvider -from nad_ch.domain.repositories import DataProviderRepository +from nad_ch.domain.entities import DataProvider, DataSubmission +from nad_ch.domain.repositories import DataProviderRepository, DataSubmissionRepository engine = create_engine(DATABASE_URL) @@ -27,28 +28,89 @@ def session_scope(): ModelBase = declarative_base() -class DataProviderModel(ModelBase): - __tablename__ = 'data_providers' +class CommonBase(ModelBase): + __abstract__ = True id = Column(Integer, primary_key=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + +class DataProviderModel(CommonBase): + __tablename__ = 'data_providers' + name = Column(String) + data_submissions = relationship( + 'DataSubmissionModel', + back_populates='data_provider' + ) + @staticmethod def from_entity(provider): - return DataProviderModel(id=provider.id, name=provider.name) + model = DataProviderModel(id=provider.id, name=provider.name) + return model def to_entity(self): - return DataProvider(id=self.id, name=self.name) + entity = DataProvider(id=self.id, name=self.name) + + if self.created_at is not None: + entity.set_created_at(self.created_at) + + if self.updated_at is not None: + entity.set_updated_at(self.updated_at) + + return entity + + +class DataSubmissionModel(CommonBase): + __tablename__ = 'data_submissions' + + file_name = Column(String) + url = Column(String) + data_provider_id = Column(Integer, ForeignKey('data_providers.id')) + + data_provider = relationship('DataProviderModel', back_populates='data_submissions') + + @staticmethod + def from_entity(submission): + model = DataSubmissionModel( + id=submission.id, + file_name=submission.file_name, + url=submission.url, + data_provider_id=submission.provider.id + ) + return model + + def to_entity(self, provider: DataProvider): + entity = DataSubmission( + id=self.id, + file_name=self.file_name, + url=self.url, + provider=provider + ) + + if self.created_at is not None: + entity.set_created_at(self.created_at) + + if self.updated_at is not None: + entity.set_updated_at(self.updated_at) + + return entity class SqlAlchemyDataProviderRepository(DataProviderRepository): def __init__(self, session_factory): self.session_factory = session_factory - def add(self, provider: DataProvider): + def add(self, provider: DataProvider) -> DataProvider: with self.session_factory() as session: provider_model = DataProviderModel.from_entity(provider) session.add(provider_model) + session.commit() + session.refresh(provider_model) return provider_model.to_entity() def get_by_name(self, name: str) -> Optional[DataProvider]: @@ -65,3 +127,51 @@ def get_all(self) -> List[DataProvider]: provider_models = session.query(DataProviderModel).all() providers_entities = [provider.to_entity() for provider in provider_models] return providers_entities + + +class SqlAlchemyDataSubmissionRepository(DataSubmissionRepository): + def __init__(self, session_factory): + self.session_factory = session_factory + + def add(self, submission: DataSubmission) -> DataSubmission: + with self.session_factory() as session: + submission_model = DataSubmissionModel.from_entity(submission) + session.add(submission_model) + session.commit() + session.refresh(submission_model) + provider_model = ( + session.query(DataProviderModel) + .filter(DataProviderModel.id == submission_model.data_provider_id) + .first() + ) + return submission_model.to_entity(provider_model.to_entity()) + + def get_by_name(self, file_name: str) -> Optional[DataSubmission]: + with self.session_factory() as session: + result = ( + session.query(DataSubmissionModel, DataProviderModel) + .join( + DataProviderModel, DataProviderModel.id == + DataSubmissionModel.data_provider_id + ) + .filter(DataSubmissionModel.file_name == file_name) + .first() + ) + + if result: + submission_model, provider_model = result + return submission_model.to_entity(provider_model.to_entity()) + else: + return None + + def get_by_provider(self, provider: DataProvider) -> List[DataSubmission]: + with self.session_factory() as session: + submission_models = ( + session.query(DataSubmissionModel) + .filter(DataSubmissionModel.data_provider_id == provider.id) + .all() + ) + submission_entities = ( + [submission.to_entity(provider) for submission in submission_models] + ) + return submission_entities diff --git a/nad_ch/infrastructure/storage.py b/nad_ch/infrastructure/storage.py index e69de29..4b83e68 100644 --- a/nad_ch/infrastructure/storage.py +++ b/nad_ch/infrastructure/storage.py @@ -0,0 +1,21 @@ +import os +import shutil + + +class LocalStorage: + def __init__(self, base_path: str): + self.base_path = base_path + + def _full_path(self, path: str) -> str: + return os.path.join(self.base_path, path) + + def upload(self, source: str, destination: str) -> None: + shutil.copy(source, self._full_path(destination)) + + def delete(self, file_path: str) -> None: + full_file_path = self._full_path(file_path) + if os.path.exists(full_file_path): + os.remove(full_file_path) + + def get_file_url(self, file_name: str) -> str: + return file_name diff --git a/nad_ch/use_cases.py b/nad_ch/use_cases.py index cbf764b..acc0ed5 100644 --- a/nad_ch/use_cases.py +++ b/nad_ch/use_cases.py @@ -1,6 +1,6 @@ from typing import List from nad_ch.application_context import ApplicationContext -from nad_ch.domain.entities import DataProvider +from nad_ch.domain.entities import DataProvider, DataSubmission def add_data_provider( @@ -32,4 +32,38 @@ def list_data_providers(ctx: ApplicationContext) -> List[DataProvider]: def ingest_data_submission( ctx: ApplicationContext, file_path: str, provider_name: str ) -> None: - pass + if not file_path: + ctx.logger.error('File path required') + return + + provider = ctx.providers.get_by_name(provider_name) + if not provider: + ctx.logger.error('Provider with that name does not exist') + return + + try: + ctx.storage.upload(file_path, f'{provider.name}_{file_path}') + url = ctx.storage.get_file_url(file_path) + + submission = DataSubmission(file_path, url, provider) + ctx.submissions.add(submission) + ctx.logger.info('Submission added') + except Exception as e: + ctx.storage.delete(file_path) + ctx.logger.error(f'Failed to process submission: {e}') + + +def list_data_submissions_by_provider( + ctx: ApplicationContext, provider_name: str +) -> List[DataSubmission]: + provider = ctx.providers.get_by_name(provider_name) + if not provider: + ctx.logger.error('Provider with that name does not exist') + return + + submissions = ctx.submissions.get_by_provider(provider) + ctx.logger.info(f'Data submissions for {provider.name}') + for s in submissions: + ctx.logger.info(f'{s.provider.name}: {s.file_name}') + + return submissions diff --git a/tests/fakes.py b/tests/fakes.py new file mode 100644 index 0000000..c1531ea --- /dev/null +++ b/tests/fakes.py @@ -0,0 +1,52 @@ +from typing import Optional +from nad_ch.domain.entities import DataProvider, DataSubmission +from nad_ch.domain.repositories import DataProviderRepository, DataSubmissionRepository + + +class FakeDataProviderRepository(DataProviderRepository): + def __init__(self) -> None: + self._providers = set() + self._next_id = 1 + + def add(self, provider: DataProvider) -> DataProvider: + provider.id = self._next_id + self._providers.add(provider) + self._next_id += 1 + return provider + + def get_by_name(self, name: str) -> Optional[DataProvider]: + return next((p for p in self._providers if p.name == name), None) + + def get_all(self): + return sorted(list(self._providers), key=lambda provider: provider.id) + + +class FakeDataSubmissionRepository(DataSubmissionRepository): + def __init__(self) -> None: + self._submissions = set() + self._next_id = 1 + + def add(self, submission: DataSubmission) -> DataSubmission: + submission.id = self._next_id + self._submissions.add(submission) + self._next_id += 1 + return submission + + def get_by_name(self, file_name: str) -> Optional[DataSubmission]: + return next( + (s for s in self._submissions if s.file_name == file_name), None + ) + + def get_by_provider(self, provider: DataProvider) -> Optional[DataSubmission]: + return [s for s in self._submissions if s.provider.name == provider.name] + + +class FakeStorage(): + def __init__(self): + self._files = set() + + def upload(self, source: str, destination: str) -> None: + self._files.add(destination) + + def get_file_url(self, file_name: str) -> str: + return file_name diff --git a/tests/infrastructure/test_database.py b/tests/infrastructure/test_database.py index 953167f..b01304a 100644 --- a/tests/infrastructure/test_database.py +++ b/tests/infrastructure/test_database.py @@ -3,8 +3,12 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from nad_ch.config import DATABASE_URL -from nad_ch.domain.entities import DataProvider -from nad_ch.infrastructure.database import ModelBase, SqlAlchemyDataProviderRepository +from nad_ch.domain.entities import DataProvider, DataSubmission +from nad_ch.infrastructure.database import ( + ModelBase, + SqlAlchemyDataProviderRepository, + SqlAlchemyDataSubmissionRepository +) @pytest.fixture(scope='function') @@ -33,13 +37,53 @@ def providers(test_session): return SqlAlchemyDataProviderRepository(test_session) +@pytest.fixture(scope='function') +def submissions(test_session): + return SqlAlchemyDataSubmissionRepository(test_session) + + def test_add_data_provider_to_repository_and_get_by_name(providers): provider_name = 'State X' new_provider = DataProvider(provider_name) providers.add(new_provider) - retreived_provider = providers.get_by_name(provider_name) - assert retreived_provider.id == 1 - assert retreived_provider.name == provider_name - assert isinstance(retreived_provider, DataProvider) is True + retrieved_provider = providers.get_by_name(provider_name) + assert retrieved_provider.id == 1 + assert retrieved_provider.created_at is not None + assert retrieved_provider.updated_at is not None + assert retrieved_provider.name == provider_name + assert isinstance(retrieved_provider, DataProvider) is True + + +def test_add_data_provider_and_then_data_submission(providers, submissions): + provider_name = 'State X' + new_provider = DataProvider(provider_name) + saved_provider = providers.add(new_provider) + new_submission = DataSubmission( + 'some-file-name', 'some-url', saved_provider) + + result = submissions.add(new_submission) + + assert result.id == 1 + assert result.created_at is not None + assert result.updated_at is not None + assert result.provider.id == saved_provider.id + assert result.file_name == 'some-file-name' + assert result.url == 'some-url' + + +def test_retrieve_a_list_of_submissions_by_provider(providers, submissions): + provider_name = 'State X' + new_provider = DataProvider(provider_name) + saved_provider = providers.add(new_provider) + new_submission = DataSubmission( + 'some-file-name', 'some-url', saved_provider) + submissions.add(new_submission) + another_new_submission = DataSubmission( + 'some-other-file-name', 'some-other-url', saved_provider) + submissions.add(another_new_submission) + + submissions = submissions.get_by_provider(saved_provider) + + assert len(submissions) == 2 diff --git a/tests/mocks.py b/tests/mocks.py deleted file mode 100644 index 1b80e2b..0000000 --- a/tests/mocks.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Optional -from nad_ch.domain.entities import DataProvider -from nad_ch.domain.repositories import DataProviderRepository - - -class MockDataProviderRepository(DataProviderRepository): - def __init__(self) -> None: - self._providers = set() - self._next_id = 1 - - def add(self, provider: DataProvider) -> None: - provider.id = self._next_id - self._providers.add(provider) - self._next_id += 1 - - def get_by_name(self, name: str) -> Optional[DataProvider]: - return next((p for p in self._providers if p.name == name), None) - - def get_all(self): - return sorted(list(self._providers), key=lambda provider: provider.id) diff --git a/tests/test_use_cases.py b/tests/test_use_cases.py index 86589da..19b08d1 100644 --- a/tests/test_use_cases.py +++ b/tests/test_use_cases.py @@ -1,9 +1,10 @@ import pytest from nad_ch.application_context import create_app_context -from nad_ch.domain.entities import DataProvider +from nad_ch.domain.entities import DataProvider, DataSubmission from nad_ch.use_cases import ( add_data_provider, list_data_providers, + ingest_data_submission ) @@ -57,3 +58,27 @@ def test_list_multiple_data_providers(app_context): assert len(providers) == 2 assert providers[0].name == first_name assert providers[1].name == second_name + + +def test_ingest_data_submission(app_context): + provider_name = 'State X' + add_data_provider(app_context, provider_name) + + file_name = 'my_cool_file.txt' + ingest_data_submission(app_context, file_name, provider_name) + + submission = app_context.submissions.get_by_name(file_name) + assert submission.file_name == file_name + assert isinstance(submission, DataSubmission) is True + + +def test_list_data_submissions_by_provider(app_context): + provider_name = 'State X' + add_data_provider(app_context, provider_name) + + file_name = 'my_cool_file.txt' + ingest_data_submission(app_context, file_name, provider_name) + + provider = app_context.providers.get_by_name(provider_name) + submissions = app_context.submissions.get_by_provider(provider) + assert len(submissions) == 1