diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index e50cea7..e472d75 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -26,6 +26,7 @@ jobs: 3.10 3.11 3.12 + 3.13 - run: make test tox: runs-on: ubuntu-latest @@ -38,6 +39,7 @@ jobs: 3.10 3.11 3.12 + 3.13 - run: make tox build: runs-on: ubuntu-latest diff --git a/CHANGES.md b/CHANGES.md index db0926b..8f52211 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,7 @@ +## Version 4.0.0 + +* Add `Validation` decorators. + ## Version 3.0.0 * Add method `paginate()` to `ReadMixin`. diff --git a/Makefile b/Makefile index 810c904..98974b9 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ build: clean venv install: build $(PIP) install dist/db_first-*.tar.gz + $(PRE_COMMIT) install upload_to_testpypi: build $(PYTHON_VENV) -m twine upload --repository-url https://test.pypi.org/legacy/ dist/* diff --git a/README.md b/README.md index e7b3d94..e9527a0 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,9 @@ $ pip install -U db_first ### Full example ```python -from uuid import UUID - from db_first import BaseCRUD from db_first.base_model import ModelMixin +from db_first.decorators import Validation from db_first.mixins import CreateMixin from db_first.mixins import DeleteMixin from db_first.mixins import ReadMixin @@ -46,6 +45,8 @@ from db_first.mixins import UpdateMixin from marshmallow import fields from marshmallow import Schema from sqlalchemy import create_engine +from sqlalchemy import Result +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -64,19 +65,19 @@ class Items(ModelMixin, Base): Base.metadata.create_all(engine) -class InputSchemaOfCreate(Schema): - data = fields.String() +class IdSchema(Schema): + id = fields.UUID() -class InputSchemaOfUpdate(InputSchemaOfCreate): - id = fields.UUID() +class SchemaOfCreate(Schema): + data = fields.String() -class InputSchemaOfRead(Schema): - id = fields.UUID() +class SchemaOfUpdate(IdSchema, SchemaOfCreate): + """Update item schema.""" -class OutputSchema(InputSchemaOfUpdate): +class OutputSchema(SchemaOfUpdate): created_at = fields.DateTime() @@ -84,40 +85,43 @@ class ItemController(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD) class Meta: session = session model = Items - input_schema_of_create = InputSchemaOfCreate - input_schema_of_update = InputSchemaOfUpdate - output_schema_of_create = OutputSchema - input_schema_of_read = InputSchemaOfRead - output_schema_of_read = OutputSchema - output_schema_of_update = OutputSchema - schema_of_paginate = OutputSchema sortable = ['created_at'] + @Validation.input(SchemaOfCreate) + @Validation.output(OutputSchema, serialize=True) + def create(self, **data) -> Result: + return super().create_object(**data) -if __name__ == '__main__': - item = ItemController() + @Validation.input(IdSchema, keys=['id']) + @Validation.output(OutputSchema, serialize=True) + def read(self, **data) -> Result: + return super().read_object(data['id']) + + @Validation.input(SchemaOfUpdate) + @Validation.output(OutputSchema, serialize=True) + def update(self, **data) -> Result: + return super().update_object(**data) - first_new_item = item.create({'data': 'first'}, deserialize=True) - print('Item as object:', first_new_item) - second_new_item = item.create({'data': 'second'}, deserialize=True, serialize=True) - print('Item as dict:', second_new_item) + @Validation.input(IdSchema, keys=['id']) + def delete(self, **data) -> None: + super().delete_object(**data) - first_item = item.read({'id': first_new_item.id}) - print('Item as object:', first_item) - first_item = item.read({'id': first_new_item.id}) - print('Item as dict:', first_item) - updated_first_item = item.update(data={'id': first_new_item.id, 'data': 'updated_first'}) - print('Item as object:', updated_first_item) - updated_second_item = item.update( - data={'id': UUID(second_new_item['id']), 'data': 'updated_second'}, serialize=True - ) - print('Item as dict:', updated_second_item) +if __name__ == '__main__': + item_controller = ItemController() + + new_item = item_controller.create(data='first') + print('Item as dict:', new_item) - items = item.paginate(sort_created_at='desc') - print('Items as objects:', items) - items = item.paginate(sort_created_at='desc', serialize=True) - print('Items as dicts:', items) + item = item_controller.read(id=new_item['id']) + print('Item as dict:', item) + updated_item = item_controller.update(id=new_item['id'], data='updated_first') + print('Item as dict:', updated_item) + item_controller.delete(id=new_item['id']) + try: + item = item_controller.read(id=new_item['id']) + except NoResultFound: + print('Item deleted:', item) ``` diff --git a/examples/full_example.py b/examples/full_example.py index 95e3104..5c2e607 100644 --- a/examples/full_example.py +++ b/examples/full_example.py @@ -1,7 +1,6 @@ -from uuid import UUID - from db_first import BaseCRUD from db_first.base_model import ModelMixin +from db_first.decorators import Validation from db_first.mixins import CreateMixin from db_first.mixins import DeleteMixin from db_first.mixins import ReadMixin @@ -9,6 +8,8 @@ from marshmallow import fields from marshmallow import Schema from sqlalchemy import create_engine +from sqlalchemy import Result +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column @@ -27,19 +28,19 @@ class Items(ModelMixin, Base): Base.metadata.create_all(engine) -class InputSchemaOfCreate(Schema): - data = fields.String() +class IdSchema(Schema): + id = fields.UUID() -class InputSchemaOfUpdate(InputSchemaOfCreate): - id = fields.UUID() +class SchemaOfCreate(Schema): + data = fields.String() -class InputSchemaOfRead(Schema): - id = fields.UUID() +class SchemaOfUpdate(IdSchema, SchemaOfCreate): + """Update item schema.""" -class OutputSchema(InputSchemaOfUpdate): +class OutputSchema(SchemaOfUpdate): created_at = fields.DateTime() @@ -47,37 +48,42 @@ class ItemController(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD) class Meta: session = session model = Items - input_schema_of_create = InputSchemaOfCreate - input_schema_of_update = InputSchemaOfUpdate - output_schema_of_create = OutputSchema - input_schema_of_read = InputSchemaOfRead - output_schema_of_read = OutputSchema - output_schema_of_update = OutputSchema - schema_of_paginate = OutputSchema sortable = ['created_at'] + @Validation.input(SchemaOfCreate) + @Validation.output(OutputSchema, serialize=True) + def create(self, **data) -> Result: + return super().create_object(**data) + + @Validation.input(IdSchema, keys=['id']) + @Validation.output(OutputSchema, serialize=True) + def read(self, **data) -> Result: + return super().read_object(data['id']) + + @Validation.input(SchemaOfUpdate) + @Validation.output(OutputSchema, serialize=True) + def update(self, **data) -> Result: + return super().update_object(**data) + + @Validation.input(IdSchema, keys=['id']) + def delete(self, **data) -> None: + super().delete_object(**data) + if __name__ == '__main__': - item = ItemController() - - first_new_item = item.create({'data': 'first'}, deserialize=True) - print('Item as object:', first_new_item) - second_new_item = item.create({'data': 'second'}, deserialize=True, serialize=True) - print('Item as dict:', second_new_item) - - first_item = item.read({'id': first_new_item.id}) - print('Item as object:', first_item) - first_item = item.read({'id': first_new_item.id}) - print('Item as dict:', first_item) - - updated_first_item = item.update(data={'id': first_new_item.id, 'data': 'updated_first'}) - print('Item as object:', updated_first_item) - updated_second_item = item.update( - data={'id': UUID(second_new_item['id']), 'data': 'updated_second'}, serialize=True - ) - print('Item as dict:', updated_second_item) - - items = item.paginate(sort_created_at='desc') - print('Items as objects:', items) - items = item.paginate(sort_created_at='desc', serialize=True) - print('Items as dicts:', items) + item_controller = ItemController() + + new_item = item_controller.create(data='first') + print('Item as dict:', new_item) + + item = item_controller.read(id=new_item['id']) + print('Item as dict:', item) + + updated_item = item_controller.update(id=new_item['id'], data='updated_first') + print('Item as dict:', updated_item) + + item_controller.delete(id=new_item['id']) + try: + item = item_controller.read(id=new_item['id']) + except NoResultFound: + print('Item deleted:', item) diff --git a/pyproject.toml b/pyproject.toml index ee3ce75..7f7a1d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ license = {file = "LICENSE"} name = "DB-First" readme = "README.md" requires-python = ">=3.9" -version = "3.0.0" +version = "4.0.0" [project.optional-dependencies] dev = [ @@ -68,7 +68,7 @@ check-overridden = true check-property-returns = true check-protected = true check-protected-class-methods = true -disable = ["SIG101"] +disable = ["SIG101", "SIG501"] [tool.setuptools.packages.find] include = ["db_first*"] diff --git a/src/db_first/base.py b/src/db_first/base.py index f0dd0b5..0766496 100644 --- a/src/db_first/base.py +++ b/src/db_first/base.py @@ -1,8 +1,6 @@ from typing import Any from typing import Optional -from sqlalchemy.engine import Result - from .exc import MetaNotFound from .exc import OptionNotFound @@ -24,42 +22,3 @@ def _get_option_from_meta(cls, name: str, default: Optional[Any] = ...) -> Any: option = default return option - - @classmethod - def deserialize_data(cls, schema_name: str, data: dict) -> dict: - schema = cls._get_option_from_meta(schema_name) - return schema().load(data) - - @classmethod - def _clean_data(cls, data: Any) -> Any: - """Clearing hierarchical structures from empty values. - - Cleaning occurs for objects of the list and dict types, other types do not clean. - - :param data: an object for cleaning. - :return: cleaned object. - """ - - empty_values = ('', None, ..., [], {}, (), set()) - - if isinstance(data, dict): - cleaned_dict = {k: cls._clean_data(v) for k, v in data.items()} - return {k: v for k, v in cleaned_dict.items() if v not in empty_values} - - elif isinstance(data, list): - cleaned_list = [cls._clean_data(item) for item in data] - return [item for item in cleaned_list if item not in empty_values] - - else: - return data - - @classmethod - def serialize_data(cls, schema_name: str, data: Result, fields: list = None) -> dict: - output_schema = cls._get_option_from_meta(schema_name) - - if isinstance(data, list): - serialized_data = output_schema(many=True, only=fields).dump(data) - else: - serialized_data = output_schema(only=fields).dump(data) - - return cls._clean_data(serialized_data) diff --git a/src/db_first/base_model.py b/src/db_first/base_model.py index 3018908..69c2234 100644 --- a/src/db_first/base_model.py +++ b/src/db_first/base_model.py @@ -1,13 +1,12 @@ import uuid from datetime import datetime from typing import Optional -from uuid import UUID from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column -def make_uuid4() -> UUID: +def make_uuid4() -> uuid.UUID: return uuid.uuid4() diff --git a/src/db_first/decorators/__init__.py b/src/db_first/decorators/__init__.py new file mode 100644 index 0000000..4cfd5bc --- /dev/null +++ b/src/db_first/decorators/__init__.py @@ -0,0 +1,3 @@ +from .validation import Validation + +__all__ = ['Validation'] diff --git a/src/db_first/decorators/validation.py b/src/db_first/decorators/validation.py new file mode 100644 index 0000000..b59e0dd --- /dev/null +++ b/src/db_first/decorators/validation.py @@ -0,0 +1,46 @@ +from collections.abc import Callable +from collections.abc import Iterable +from functools import wraps +from typing import Any + +from marshmallow.schema import SchemaMeta + + +class Validation: + @classmethod + def input( + cls, schema: SchemaMeta, deserialize: bool = True, keys: Iterable[str] or None = None + ) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, **data) -> Any: + if deserialize: + deserialized_data = schema(only=keys).load(data) + return func(self, **deserialized_data) + else: + schema(only=keys).validate(data) + return func(self, **data) + + return wrapper + + return decorator + + @classmethod + def output( + cls, schema: SchemaMeta, serialize: bool = False, keys: Iterable[str] or None = None + ) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs) -> Any: + obj = func(self, *args, **kwargs) + + if serialize: + serialized_data = schema(only=keys).dump(obj) + return serialized_data + else: + schema(only=keys).validate(obj) + return obj + + return wrapper + + return decorator diff --git a/src/db_first/mixins/crud.py b/src/db_first/mixins/crud.py index 93254cc..bb33e16 100644 --- a/src/db_first/mixins/crud.py +++ b/src/db_first/mixins/crud.py @@ -43,19 +43,6 @@ def create_object(self, **kwargs) -> Result: session.commit() return new_obj - def create( - self, data: dict, deserialize: bool = False, serialize: bool = False - ) -> Result or dict: - if deserialize: - data = self.deserialize_data('input_schema_of_create', data) - - new_object = self.create_object(**data) - - if serialize: - return self.serialize_data('output_schema_of_create', new_object) - - return new_object - class ReadMixin: """Read objects from database. @@ -95,7 +82,7 @@ def _calculate_items_per_page( self, session: Session, statement: Select, per_page: int ) -> tuple[int, int]: total = session.execute( - statement.with_only_columns(func.count()).order_by(None) + statement.with_only_columns(func.count(self.Meta.model.id)).order_by(None) ).scalar_one() if per_page == 0: @@ -115,9 +102,7 @@ def _paginate( page: int = 1, per_page: Optional[int] = 20, max_per_page: Optional[int] = 100, - serialize: bool = False, include_metadata: bool = False, - fields: Optional[list] = None, ) -> dict: session: Session = self._get_option_from_meta('session') @@ -139,30 +124,21 @@ def _paginate( else: paginated_rows = [] - if serialize: - items['items'] = self.serialize_data('output_schema_of_read', paginated_rows, fields) - else: - items['items'] = paginated_rows + items['items'] = paginated_rows return items - def paginate( + def base_paginate( self, page: int = 1, per_page: Optional[int] = None, max_per_page: Optional[int] = None, statement: Optional[Select] = None, - deserialize: bool = False, - serialize: bool = False, include_metadata: bool = False, - fields: Optional[list] = None, **kwargs, ) -> Result or dict: model = self._get_option_from_meta('model') - if deserialize: - kwargs = self.deserialize_data('input_schema_of_read', kwargs) - filterable_fields = self._get_option_from_meta('filterable', ()) interval_filterable_fields = self._get_option_from_meta('interval_filterable', ()) searchable_fields = self._get_option_from_meta('searchable', ()) @@ -189,34 +165,18 @@ def paginate( page=page, per_page=per_page, max_per_page=max_per_page, - serialize=serialize, include_metadata=include_metadata, - fields=fields, ) return items def read_object(self, id: Any) -> Result: - session = self._get_option_from_meta('session') model = self._get_option_from_meta('model') stmt = select(model).where(model.id == id) return session.scalars(stmt).one() - def read( - self, data: dict, deserialize: bool = False, serialize: bool = False - ) -> Result or dict: - if deserialize: - data = self.deserialize_data('input_schema_of_read', data) - - object_ = self.read_object(**data) - - if serialize: - return self.serialize_data('output_schema_of_read', object_) - - return object_ - class UpdateMixin: """Update object in database. @@ -247,19 +207,6 @@ def update_object(self, id: Any, **kwargs) -> Result: obj = session.scalars(stmt).one() return obj - def update( - self, data: dict, deserialize: bool = False, serialize: bool = False - ) -> Result or dict: - if deserialize: - data = self.deserialize_data('input_schema_of_update', data) - - updated_object = self.update_object(**data) - - if serialize: - return self.serialize_data('output_schema_of_update', updated_object) - - return updated_object - class DeleteMixin: """Delete object from database.""" @@ -271,8 +218,3 @@ def delete_object(self, id: Any) -> None: model = self._get_option_from_meta('model') session.execute(delete(model).where(model.id == id)) - - def delete(self, data: dict, deserialize: bool = False) -> None: - if deserialize: - data = self.deserialize_data('input_schema_of_read', data) - self.delete_object(**data) diff --git a/src/db_first/schemas/__init__.py b/src/db_first/schemas/__init__.py new file mode 100644 index 0000000..79b6144 --- /dev/null +++ b/src/db_first/schemas/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseSchema + +__all__ = ['BaseSchema'] diff --git a/src/db_first/schemas/base.py b/src/db_first/schemas/base.py new file mode 100644 index 0000000..ef01fe4 --- /dev/null +++ b/src/db_first/schemas/base.py @@ -0,0 +1,42 @@ +from marshmallow import post_dump +from marshmallow import Schema + + +class BaseSchema(Schema): + __empty_values__ = ('', None, ..., [], {}, (), set()) + __skipped_keys__ = () + + @post_dump() + def _delete_keys_with_empty_value(self, data, many=False) -> dict or list: + """Clearing hierarchical structures from empty values. + + Cleaning occurs for objects of the list and dict types, other types do not clean. + + :param data: an object for cleaning. + :param many: Should be set to `True` if ``obj`` is a collection so that the object will + be serialized to a list. + :return: cleaned object. + """ + + if isinstance(data, dict): + pre_cleaned_dict = { + k: self._delete_keys_with_empty_value(v, many=many) for k, v in data.items() + } + + cleaned_dict = {} + for k, v in pre_cleaned_dict.items(): + if k not in self.__skipped_keys__ and v in self.__empty_values__: + continue + else: + cleaned_dict[k] = v + + return cleaned_dict + + elif isinstance(data, list): + pre_cleaned_list = [ + self._delete_keys_with_empty_value(item, many=many) for item in data + ] + return [item for item in pre_cleaned_list if item not in self.__empty_values__] + + else: + return data diff --git a/src/db_first/statement_maker.py b/src/db_first/statement_maker.py index f590730..5de702b 100644 --- a/src/db_first/statement_maker.py +++ b/src/db_first/statement_maker.py @@ -116,4 +116,5 @@ def make_statement(self) -> Select: self._add_interval_filtration() self._add_searching() self._add_sorting() + return self.stmt diff --git a/tests/conftest.py b/tests/conftest.py index fa46758..3c37169 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,6 @@ from uuid import UUID import pytest -from db_first import BaseCRUD -from db_first import ModelMixin -from db_first.mixins.crud import CreateMixin -from db_first.mixins.crud import DeleteMixin -from db_first.mixins.crud import ReadMixin -from db_first.mixins.crud import UpdateMixin -from marshmallow import fields -from marshmallow import Schema -from marshmallow import validate from sqlalchemy import create_engine from sqlalchemy import ForeignKey from sqlalchemy.engine import Result @@ -21,6 +12,19 @@ from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from .contrib.schemas import ChildSchema +from .contrib.schemas import FatherSchema +from .contrib.schemas import ParentSchema +from .contrib.schemas import ParentSchemaOfPaginate +from .contrib.schemas import ParentSchemaParametersOfPaginate +from src.db_first import BaseCRUD +from src.db_first import ModelMixin +from src.db_first.decorators import Validation +from src.db_first.mixins.crud import CreateMixin +from src.db_first.mixins.crud import DeleteMixin +from src.db_first.mixins.crud import ReadMixin +from src.db_first.mixins.crud import UpdateMixin + DATE_TIME_FORMAT = '%Y-%m-%dT%H:%M:%SZ' UNIQUE_STRING = (f'name_{number}' for number in range(1_000)) @@ -76,139 +80,77 @@ class Fathers(Base, ModelMixin): @pytest.fixture(scope='session') -def fx_father_schema_of_create() -> type[Schema]: - class SchemaOfCreate(Schema): - first = fields.String(required=True) - second = fields.String() - - return SchemaOfCreate - - -@pytest.fixture(scope='session') -def fx_father_output_schema(fx_father_schema_of_create) -> type[Schema]: - class OutputSchema(fx_father_schema_of_create): - id = fields.String(required=True) - created_at = fields.DateTime() - - return OutputSchema - - -@pytest.fixture(scope='session') -def fx_parent_schema_of_create() -> type[Schema]: - class SchemaOfCreate(Schema): - first = fields.String(required=True) - second = fields.String() - father_id = fields.UUID() - - return SchemaOfCreate - - -@pytest.fixture(scope='session') -def fx_parent_schema_of_update(fx_parent_schema_of_create) -> type[Schema]: - class SchemaOfUpdate(fx_parent_schema_of_create): - id = fields.UUID(required=True) - - return SchemaOfUpdate - - -@pytest.fixture(scope='session') -def fx_parent_schema_of_read() -> type[Schema]: - class SchemaOfRead(Schema): - id = fields.List(fields.UUID()) - page = fields.Integer(validate=validate.Range(min=0)) - max_per_page = fields.Integer(validate=validate.Range(min=0)) - per_page = fields.Integer(validate=validate.Range(min=0)) - sort_created_at = fields.String(validate=validate.OneOf(['asc', 'desc'])) - search_first = fields.String() - first = fields.String() - start_created_at = fields.DateTime() - end_created_at = fields.DateTime() - - return SchemaOfRead - - -@pytest.fixture(scope='session') -def fx_parent_output_schema(fx_father_output_schema, fx_child_output_schema) -> type[Schema]: - class OutputSchema(Schema): - id = fields.String(required=True) - first = fields.String(required=True) - second = fields.String() - created_at = fields.DateTime() - father = fields.Nested(fx_father_output_schema) - children = fields.Nested(fx_child_output_schema, many=True) - - return OutputSchema - - -@pytest.fixture(scope='session') -def fx_child_schema_of_create() -> type[Schema]: - class SchemaOfCreate(Schema): - first = fields.String(required=True) - second = fields.String() - parent_id = fields.UUID() - - return SchemaOfCreate - - -@pytest.fixture(scope='session') -def fx_child_output_schema(fx_child_schema_of_create) -> type[Schema]: - class SchemaOfCreate(fx_child_schema_of_create): - id = fields.String(required=True) - - return SchemaOfCreate - - -@pytest.fixture(scope='session') -def fx_parent_controller( - fx_db, - fx_parent_schema_of_create, - fx_parent_schema_of_read, - fx_parent_schema_of_update, - fx_parent_output_schema, -): +def fx_parent_controller(fx_db): session_db, parents_model, _, _ = fx_db class Parent(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD): class Meta: session = session_db model = parents_model - input_schema_of_create = fx_parent_schema_of_create - output_schema_of_create = fx_parent_output_schema - input_schema_of_update = fx_parent_schema_of_update filterable = ['id', 'first'] interval_filterable = ['created_at'] sortable = ['created_at'] searchable = ['first'] - input_schema_of_read = fx_parent_schema_of_read - output_schema_of_read = fx_parent_output_schema + + @Validation.input(ParentSchema, keys=('first', 'second', 'father_id')) + @Validation.output(ParentSchema) + def create(self, **data: dict) -> Result or dict: + return super().create_object(**data) + + @Validation.input(ParentSchema, keys=('id', 'first', 'second', 'father_id')) + @Validation.output(ParentSchema) + def update(self, **data: dict) -> Result or dict: + return super().update_object(**data) + + @Validation.input(ParentSchemaParametersOfPaginate) + @Validation.output(ParentSchemaOfPaginate, serialize=True) + def paginate(self, **data: dict) -> Result or dict: + if 'ids' in data: + data['id'] = data.pop('ids') + return super().base_paginate(**data) + + @Validation.input(ParentSchemaParametersOfPaginate) + @Validation.output(ParentSchemaOfPaginate, keys=('items.id',), serialize=True) + def paginate_ids(self, **data: dict) -> Result or dict: + if 'ids' in data: + data['id'] = data.pop('ids') + return super().base_paginate(**data) return Parent() @pytest.fixture(scope='session') -def fx_child_controller(fx_db, fx_child_schema_of_create): +def fx_child_controller(fx_db): session_db, _, child_model, _ = fx_db class Child(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD): class Meta: session = session_db model = child_model - input_schema_of_create = fx_child_schema_of_create filterable = ['id', 'parent_id'] sortable = ['parent_id', 'created_at'] + @Validation.input(ChildSchema, keys=('first', 'second', 'parent_id')) + @Validation.output(ChildSchema) + def create(self, **data: dict) -> Result or dict: + return super().create_object(**data) + return Child() @pytest.fixture(scope='session') -def fx_father_controller(fx_db, fx_father_schema_of_create): +def fx_father_controller(fx_db): session_db, _, _, father_model = fx_db class Father(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD): class Meta: session = session_db model = father_model - input_schema_of_create = fx_father_schema_of_create + + @Validation.input(FatherSchema, keys=('first', 'second')) + @Validation.output(FatherSchema) + def create(self, **data: dict) -> Result or dict: + return super().create_object(**data) return Father() @@ -216,15 +158,11 @@ class Meta: @pytest.fixture def fx_parents__non_deletion(fx_parent_controller, fx_child_controller, fx_father_controller): def _create_item() -> Result: - new_father = fx_father_controller.create({'first': next(UNIQUE_STRING)}) + new_father = fx_father_controller.create(first=next(UNIQUE_STRING)) new_parent = fx_parent_controller.create( - { - 'first': next(UNIQUE_STRING), - 'second': f'full {next(UNIQUE_STRING)}', - 'father_id': new_father.id, - } + first=next(UNIQUE_STRING), second=f'full {next(UNIQUE_STRING)}', father_id=new_father.id ) - fx_child_controller.create({'first': next(UNIQUE_STRING), 'parent_id': new_parent.id}) + fx_child_controller.create(first=next(UNIQUE_STRING), parent_id=new_parent.id) return new_parent return _create_item diff --git a/tests/contrib/schemas.py b/tests/contrib/schemas.py new file mode 100644 index 0000000..f3a6244 --- /dev/null +++ b/tests/contrib/schemas.py @@ -0,0 +1,62 @@ +from marshmallow import fields +from marshmallow import validate + +from src.db_first.schemas import BaseSchema + + +class IdSchema(BaseSchema): + id = fields.UUID(required=True) + + +class FatherSchema(IdSchema): + first = fields.String(required=True) + second = fields.String() + created_at = fields.DateTime() + + +class ChildSchema(IdSchema): + first = fields.String(required=True) + second = fields.String() + parent_id = fields.UUID() + + +class ParentSchema(IdSchema): + first = fields.String(required=True) + second = fields.String(allow_none=False) + father_id = fields.UUID() + created_at = fields.DateTime() + father = fields.Nested(FatherSchema) + children = fields.Nested(ChildSchema, many=True) + + +class ParentSchemaParametersOfPaginate(BaseSchema): + id = fields.UUID() + ids = fields.List(fields.UUID()) + page = fields.Integer(validate=validate.Range(min=0)) + max_per_page = fields.Integer(validate=validate.Range(min=0)) + per_page = fields.Integer(validate=validate.Range(min=0)) + sort_created_at = fields.String(validate=validate.OneOf(['asc', 'desc'])) + search_first = fields.String() + first = fields.String() + start_created_at = fields.DateTime() + end_created_at = fields.DateTime() + include_metadata = fields.Boolean(validate=validate.OneOf([True])) + fields = fields.List(fields.String()) + + +class Paginate(BaseSchema): + page = fields.Integer(allow_none=False) + per_page = fields.Integer(allow_none=False) + pages = fields.Integer(allow_none=False) + total = fields.Integer(allow_none=False) + + +class MetadataSchema(BaseSchema): + pagination = fields.Nested(Paginate) + + +class ParentSchemaOfPaginate(BaseSchema): + __skipped_keys__ = ('items',) + + _metadata = fields.Nested(MetadataSchema) + items = fields.Nested(ParentSchema, many=True) diff --git a/tests/test_crud_mixin.py b/tests/test_crud_mixin.py index ce46f85..a86ba4f 100644 --- a/tests/test_crud_mixin.py +++ b/tests/test_crud_mixin.py @@ -3,21 +3,28 @@ from uuid import uuid4 import pytest -from db_first import BaseCRUD -from db_first import ModelMixin -from db_first.exc import MetaNotFound -from db_first.exc import OptionNotFound -from db_first.mixins.crud import CreateMixin -from db_first.mixins.crud import DeleteMixin -from db_first.mixins.crud import ReadMixin -from db_first.mixins.crud import UpdateMixin from marshmallow import fields from marshmallow import Schema +from sqlalchemy import Result from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column -from .conftest import UNIQUE_STRING +from src.db_first import BaseCRUD +from src.db_first import ModelMixin +from src.db_first.decorators import Validation +from src.db_first.exc import MetaNotFound +from src.db_first.exc import OptionNotFound +from src.db_first.mixins.crud import CreateMixin +from src.db_first.mixins.crud import DeleteMixin +from src.db_first.mixins.crud import ReadMixin +from src.db_first.mixins.crud import UpdateMixin +from tests.conftest import UNIQUE_STRING + + +class TestSchema(Schema): + id = fields.UUID() + first = fields.String() def test_crud_mixin(fx_db_connection): @@ -30,64 +37,50 @@ class TestModel(Base, ModelMixin): Base.metadata.create_all(engine) - class SchemaOfCreate(Schema): - first = fields.String() - - class SchemaOfResultCreate(Schema): - id = fields.UUID() - first = fields.String() - - class SchemaOfRead(Schema): - id = fields.UUID() - - class TestCreate(CreateMixin, BaseCRUD): + class TestCRUD(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD): class Meta: model = TestModel session = db_session - input_schema_of_create = SchemaOfCreate - output_schema_of_create = SchemaOfResultCreate + filterable = ['id'] + + @Validation.input(TestSchema) + @Validation.output(TestSchema, serialize=True) + def create(self, **data) -> Result: + return super().create_object(**data) + + @Validation.input(TestSchema, keys=['id']) + @Validation.output(TestSchema, serialize=True) + def read(self, **data) -> Result: + return super().read_object(data['id']) + + @Validation.input(TestSchema) + @Validation.output(TestSchema, serialize=True) + def update(self, **data) -> Result: + return super().update_object(**data) + + @Validation.input(TestSchema, keys=['id']) + def delete(self, **data) -> None: + super().delete_object(**data) data_for_create = {'first': next(UNIQUE_STRING)} - TestCreate().create(data_for_create, serialize=True) - new_data = TestCreate().create(data_for_create, serialize=True) + TestCRUD().create(**data_for_create) + new_data = TestCRUD().create(**data_for_create) new_data_for_assert = deepcopy(new_data) assert new_data_for_assert.pop('id') assert new_data_for_assert == data_for_create - class TestRead(ReadMixin, BaseCRUD): - class Meta: - model = TestModel - session = db_session - filterable = ['id'] - input_schema_of_read = SchemaOfRead - output_schema_of_read = SchemaOfResultCreate - - data_for_read = TestRead().read({'id': UUID(new_data['id'])}, serialize=True) + data_for_read = TestCRUD().read(**{'id': UUID(new_data['id'])}) assert new_data == data_for_read - class TestUpdate(UpdateMixin, BaseCRUD): - class Meta: - model = TestModel - session = db_session - input_schema_of_update = SchemaOfResultCreate - output_schema_of_update = SchemaOfResultCreate - data_for_update = {'id': UUID(new_data['id']), 'first': next(UNIQUE_STRING)} - updated_data = TestUpdate().update(data=data_for_update, serialize=True) + updated_data = TestCRUD().update(**data_for_update) data_for_update['id'] = str(data_for_update['id']) assert updated_data == data_for_update - class TestDelete(DeleteMixin, BaseCRUD): - class Meta: - model = TestModel - session = db_session - input_schema_of_update = SchemaOfResultCreate - output_schema_of_update = SchemaOfResultCreate - - TestDelete().delete({'id': UUID(new_data['id'])}) + TestCRUD().delete(**{'id': UUID(new_data['id'])}) with pytest.raises(NoResultFound): - assert not TestRead().read({'id': UUID(new_data['id'])}) + assert not TestCRUD().read(**{'id': UUID(new_data['id'])}) def test_crud_mixin__wrong_meta(fx_db_connection): @@ -97,20 +90,20 @@ class TestController(CreateMixin, ReadMixin, UpdateMixin, DeleteMixin, BaseCRUD) data_for_create = {'first': next(UNIQUE_STRING)} with pytest.raises(MetaNotFound) as e: - TestController().create(data=data_for_create, serialize=True) + TestController().create_object(**data_for_create) assert e.value.args[0] == 'You need add class Meta with options.' with pytest.raises(MetaNotFound) as e: - TestController().read({'id': uuid4()}, serialize=True) + TestController().read_object(**{'id': uuid4()}) assert e.value.args[0] == 'You need add class Meta with options.' data_for_update = {'id': uuid4(), 'first': next(UNIQUE_STRING)} with pytest.raises(MetaNotFound) as e: - TestController().update(data=data_for_update, serialize=True) + TestController().update_object(**data_for_update) assert e.value.args[0] == 'You need add class Meta with options.' with pytest.raises(MetaNotFound) as e: - TestController().delete({'id': uuid4()}) + TestController().delete_object(**{'id': uuid4()}) assert e.value.args[0] == 'You need add class Meta with options.' @@ -122,18 +115,18 @@ class Meta: data_for_create = {'first': next(UNIQUE_STRING)} with pytest.raises(OptionNotFound) as e: - TestController().create(data=data_for_create, deserialize=True) - assert e.value.args[0] == 'Option not set in Meta class.' + TestController().create_object(**data_for_create) + assert e.value.args[0] == 'Option not set in Meta class.' with pytest.raises(OptionNotFound) as e: - TestController().read({'id': uuid4()}, serialize=True) + TestController().read_object(**{'id': uuid4()}) assert e.value.args[0] == 'Option not set in Meta class.' data_for_update = {'id': uuid4(), 'first': next(UNIQUE_STRING)} with pytest.raises(OptionNotFound) as e: - TestController().update(data=data_for_update, deserialize=True) - assert e.value.args[0] == 'Option not set in Meta class.' + TestController().update_object(**data_for_update) + assert e.value.args[0] == 'Option not set in Meta class.' with pytest.raises(OptionNotFound) as e: - TestController().delete({'id': uuid4()}) + TestController().delete_object(**{'id': uuid4()}) assert e.value.args[0] == 'Option not set in Meta class.' diff --git a/tests/test_pagination_mixin.py b/tests/test_pagination_mixin.py index 7b29cef..1c1bdd9 100644 --- a/tests/test_pagination_mixin.py +++ b/tests/test_pagination_mixin.py @@ -3,24 +3,24 @@ from math import ceil import pytest -import sqlalchemy -from db_first import BaseCRUD -from db_first.mixins import CreateMixin -from db_first.mixins import ReadMixin from sqlalchemy import select +from sqlalchemy.engine import Result from .conftest import UNIQUE_STRING +from src.db_first import BaseCRUD +from src.db_first.mixins import CreateMixin +from src.db_first.mixins import ReadMixin def test_controller__pagination_without_metadata(fx_parent_controller): - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + fx_parent_controller.create(first=next(UNIQUE_STRING)) + fx_parent_controller.create(first=next(UNIQUE_STRING)) items = fx_parent_controller.paginate() assert not items.get('_metadata') for item in items['items']: - assert item.id + assert item['id'] def test_controller__pagination(fx_parents__non_deletion, fx_parent_controller): @@ -28,23 +28,25 @@ def test_controller__pagination(fx_parents__non_deletion, fx_parent_controller): ids = [fx_parents__non_deletion().id for _ in range(total_items_number)] items = fx_parent_controller.paginate( - page=1, per_page=2, max_per_page=20, include_metadata=True, id=ids + page=1, per_page=2, max_per_page=20, include_metadata=True, ids=ids ) assert items['items'] assert len(items['items']) == 2 assert items['_metadata']['pagination']['page'] == 1 assert items['_metadata']['pagination']['per_page'] == 2 - assert items['_metadata']['pagination']['pages'] == total_items_number / 2 - assert items['_metadata']['pagination']['total'] == total_items_number + assert ( + items['_metadata']['pagination']['pages'] == items['_metadata']['pagination']['total'] / 2 + ) + assert items['_metadata']['pagination']['total'] >= total_items_number def test_controller__sorting(fx_parents__non_deletion, fx_parent_controller): _ = [fx_parents__non_deletion() for _ in range(10)] items = fx_parent_controller.paginate(sort_created_at='asc') - asc_first_item = items['items'][0].id + asc_first_item = items['items'][0]['id'] items = fx_parent_controller.paginate(sort_created_at='desc') - desc_first_item = items['items'][0].id + desc_first_item = items['items'][0]['id'] assert asc_first_item != desc_first_item @@ -54,78 +56,77 @@ def test_controller__searching(fx_parents__non_deletion, fx_parent_controller): items = fx_parent_controller.paginate(search_first=new_item.first) for item in items['items']: - assert item.first == new_item.first + assert item['first'] == new_item.first def test_controller__get_fields_of_list(fx_db, fx_parent_controller): _, _, _, Fathers = fx_db - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + fx_parent_controller.create(first=next(UNIQUE_STRING)) - fields = ['id'] - items = fx_parent_controller.paginate(fields=fields, serialize=True) + items = fx_parent_controller.paginate_ids() assert items['items'] for item in items['items']: - assert list(item) == fields + assert list(item) == ['id'] def test_controller__filtrating(fx_parents__non_deletion, fx_parent_controller): first_item = fx_parents__non_deletion() _ = [fx_parents__non_deletion() for _ in range(10)] - patched_item_payload = {'id': first_item.id, 'first': 'first for test filtrating'} - patched_first_item = fx_parent_controller.update(patched_item_payload) + patched_item_payload = {'id': str(first_item.id), 'first': 'first for test filtrating'} + patched_first_item = fx_parent_controller.update(**patched_item_payload) items = fx_parent_controller.paginate(first=patched_first_item.first) assert len(items['items']) == 1 - assert items['items'][0].id == first_item.id + assert items['items'][0]['id'] == str(first_item.id) items = fx_parent_controller.paginate(page=1, per_page=1, first=patched_first_item.first) assert items['items'] - assert items['items'][0].id == first_item.id + assert items['items'][0]['id'] == str(first_item.id) def test_controller__interval_filtration(fx_parent_controller): - item_first = fx_parent_controller.create({'first': next(UNIQUE_STRING)}) - item_second = fx_parent_controller.create({'first': next(UNIQUE_STRING)}) - item_third = fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + item_1 = fx_parent_controller.create(first=next(UNIQUE_STRING)) + item_2 = fx_parent_controller.create(first=next(UNIQUE_STRING)) + item_3 = fx_parent_controller.create(first=next(UNIQUE_STRING)) items_asc = fx_parent_controller.paginate( include_metadata=True, - id=[item_first.id, item_second.id, item_third.id], + ids=[item_1.id, item_2.id, item_3.id], sort_created_at='asc', - start_created_at=item_first.created_at, - end_created_at=item_third.created_at, + start_created_at=item_1.created_at.isoformat(), + end_created_at=item_3.created_at.isoformat(), ) assert items_asc['_metadata']['pagination']['page'] == 1 assert items_asc['_metadata']['pagination']['pages'] == 1 assert items_asc['items'] assert len(items_asc['items']) == 2 - assert items_asc['items'][0].id == item_first.id - assert items_asc['items'][1].id == item_second.id + assert items_asc['items'][0]['id'] == str(item_1.id) + assert items_asc['items'][1]['id'] == str(item_2.id) items_desc = fx_parent_controller.paginate( include_metadata=True, - id=[item_first.id, item_second.id, item_third.id], + ids=[item_1.id, item_2.id, item_3.id], sort_created_at='desc', - start_created_at=item_first.created_at, - end_created_at=item_third.created_at, + start_created_at=item_1.created_at.isoformat(), + end_created_at=item_3.created_at.isoformat(), ) assert items_desc['_metadata']['pagination']['page'] == 1 assert items_desc['_metadata']['pagination']['pages'] == 1 assert items_desc['items'] assert len(items_desc['items']) == 2 - assert items_desc['items'][0].id == item_second.id - assert items_desc['items'][1].id == item_first.id + assert items_desc['items'][0]['id'] == str(item_2.id) + assert items_desc['items'][1]['id'] == str(item_1.id) @pytest.mark.parametrize('page', [-1, 0]) @pytest.mark.parametrize('per_page', [-1, 0]) def test_controller__get_non_exist_page(fx_parent_controller, page, per_page): - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + fx_parent_controller.create(first=next(UNIQUE_STRING)) - items = fx_parent_controller.paginate(page=page, per_page=per_page, include_metadata=True) + items = fx_parent_controller.base_paginate(page=page, per_page=per_page, include_metadata=True) assert items['_metadata']['pagination']['per_page'] == 0 assert items['_metadata']['pagination']['pages'] == 0 @@ -136,20 +137,22 @@ def test_controller__get_non_exist_page(fx_parent_controller, page, per_page): @pytest.mark.parametrize('page', [1, 2]) @pytest.mark.parametrize('per_page', [1, 2]) def test_controller__get_pages(fx_parent_controller, page, per_page): - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + fx_parent_controller.create(first=next(UNIQUE_STRING)) items = fx_parent_controller.paginate(page=page, per_page=per_page, include_metadata=True) - assert 1 <= items['_metadata']['pagination']['per_page'] <= 100 - assert items['_metadata']['pagination']['pages'] == 1 - assert items['_metadata']['pagination']['total'] > 0 + assert items['_metadata']['pagination']['page'] == page + assert items['_metadata']['pagination']['per_page'] == per_page + assert items['_metadata']['pagination']['pages'] == ceil( + items['_metadata']['pagination']['total'] / per_page + ) assert items['items'] @pytest.mark.parametrize('page', [101, 1_000_000]) @pytest.mark.parametrize('per_page', [101, 1_000_000]) def test_controller__get_over_pages(fx_parent_controller, page, per_page): - fx_parent_controller.create({'first': next(UNIQUE_STRING)}) + fx_parent_controller.create(first=next(UNIQUE_STRING)) items = fx_parent_controller.paginate(page=page, per_page=per_page, include_metadata=True) @@ -162,7 +165,7 @@ def test_controller__get_over_pages(fx_parent_controller, page, per_page): assert not items['items'] -def test_controller__without_meta_pagination(fx_db, fx_parent_schema_of_create): +def test_controller__without_meta_pagination(fx_db): db_session, Parents, _, _ = fx_db class CustomController(CreateMixin, ReadMixin, BaseCRUD): @@ -170,37 +173,39 @@ class Meta: model = Parents statement = select(model) session = db_session - input_schema_of_create = fx_parent_schema_of_create + + def create(self, **data: dict) -> Result or dict: + return super().create_object(**data) + + def paginate(self, **data: dict) -> Result or dict: + if 'ids' in data: + data['id'] = data.pop('ids') + return super().base_paginate(**data) custom_controller = CustomController() - custom_controller.create({'first': next(UNIQUE_STRING)}) + custom_controller.create(first=next(UNIQUE_STRING)) items = custom_controller.paginate(fields=['id']) assert '_metadata' not in items assert items['items'] -def test_controller__statement( - fx_db, fx_parent_schema_of_create, fx_parents__non_deletion, fx_parent_controller -): +def test_controller__statement(fx_db, fx_parents__non_deletion, fx_parent_controller): session, Parents, Children, _ = fx_db total_items_number = 10 ids = [fx_parents__non_deletion().id for _ in range(total_items_number)] - statement = sqlalchemy.select(Parents).join(Children) - items = fx_parent_controller.paginate( - statement=statement, page=1, per_page=2, max_per_page=20, include_metadata=True, - id=ids, + ids=ids, sort_created_at='desc', - search='name', - start__created_at=datetime.utcnow() - timedelta(minutes=1), - end__created_at=datetime.utcnow(), + search_first='name', + start_created_at=(datetime.utcnow() - timedelta(minutes=1)).isoformat(), + end_created_at=datetime.utcnow().isoformat(), ) assert items['items'] assert len(items['items']) == 2 @@ -213,15 +218,11 @@ def test_controller__statement( def test_controller__fields_for_relations( fx_parent_controller, fx_child_controller, fx_father_controller ): - new_father = fx_father_controller.create({'first': next(UNIQUE_STRING)}) - new_parent = fx_parent_controller.create( - {'first': next(UNIQUE_STRING), 'father_id': new_father.id} - ) - new_child = fx_child_controller.create( - {'first': next(UNIQUE_STRING), 'parent_id': new_parent.id} - ) + new_father = fx_father_controller.create(first=next(UNIQUE_STRING)) + new_parent = fx_parent_controller.create(first=next(UNIQUE_STRING), father_id=new_father.id) + new_child = fx_child_controller.create(first=next(UNIQUE_STRING), parent_id=new_parent.id) - items = fx_parent_controller.paginate(id=new_parent.id, serialize=True) + items = fx_parent_controller.paginate(id=new_parent.id) assert len(items['items']) == 1 assert items['items'][0]['id'] == str(new_parent.id) assert items['items'][0]['created_at'] == new_parent.created_at.isoformat()