Skip to content

Commit

Permalink
Migrate from function factory to using metaclass
Browse files Browse the repository at this point in the history
  • Loading branch information
hugobessa committed Dec 3, 2024
1 parent c76e276 commit ff7a92c
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 79 deletions.
18 changes: 11 additions & 7 deletions example_app/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from sqlalchemy import String, BigInteger, Integer
from sqlalchemy.orm import Mapped, mapped_column, declarative_base
from sqlalchemy.orm import Mapped, mapped_column, DeclarativeBase, relationship

from vintasend_sqlalchemy.model_factory import create_notification_model
from vintasend_sqlalchemy.model_factory import GenericNotification, NotificationMeta


Base = declarative_base()
metadata = Base.metadata
class Base(DeclarativeBase):
pass


class User(Base):
Expand All @@ -14,9 +14,13 @@ class User(Base):
email: Mapped[str] = mapped_column("email", String(255), nullable=False)


BaseNotification = create_notification_model(User, "id", BigInteger)


class Notification(BaseNotification):
class Notification(
GenericNotification[User, int],
metaclass=NotificationMeta,
user_model=User,
user_primary_key_field_name="id",
user_primary_key_field_type=int
):
def get_user_email(self) -> str:
return self.get_user().email
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ pytest-cov = "^6.0.0"
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"


[tool.ruff]
select = [
# pycodestyle
Expand Down
1 change: 0 additions & 1 deletion vintasend_sqlalchemy/alembic_initial_migration_ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from vintasend_sqlalchemy.model_factory import create_notification_model
from alembic import op
import sqlalchemy as sa
import datetime
Expand Down
119 changes: 49 additions & 70 deletions vintasend_sqlalchemy/model_factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import uuid
from typing import Any, TypeVar, Generic, overload
from typing import Any, Callable, TypeVar, Generic

from sqlalchemy import JSON, UUID, DateTime, ForeignKey, Integer, BigInteger, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, declarative_base
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept

Base = declarative_base()
class Base(DeclarativeBase):
pass


class NotificationMixin(Base):
__tablename__ = "notifications"
__abstract__ = True
id: Mapped[int] = mapped_column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) # noqa: A003
notification_type: Mapped[str] = mapped_column("notification_type", String(50), nullable=False)
title: Mapped[str] = mapped_column("title", String(255), nullable=False)
Expand Down Expand Up @@ -57,76 +60,52 @@ def set_user_id(self, user_id: Any):


UserType = TypeVar('UserType', bound=DeclarativeBase)
UserPrimaryKeyType = TypeVar('UserPrimaryKeyType', type[Integer], type[BigInteger], type[String], type[UUID])
UserMappedPrimaryKeyType = TypeVar('UserMappedPrimaryKeyType')


class UserIdTypeMappingMetaClass(type):
@overload
def __getitem__(cls, user_id_type: type[Integer]) -> type[int]: ...
@overload
def __getitem__(cls, user_id_type: type[BigInteger]) -> type[int]: ...
@overload
def __getitem__(cls, user_id_type: type[String]) -> type[str]: ...
@overload
def __getitem__(cls, user_id_type: type[UUID]) -> type[uuid.UUID]: ...
def __getitem__(cls, user_id_type: type) -> type:
types_mapping = {
String: str,
Integer: int,
BigInteger: int,
UUID: uuid.UUID,
}
return types_mapping.get(user_id_type, int)
UserPrimaryKeyType = TypeVar('UserPrimaryKeyType', int, str, uuid.UUID)


class NotificationMeta(DeclarativeAttributeIntercept):
def __new__(cls, name, bases, dct, user_model, user_primary_key_field_name, user_primary_key_field_type):
if user_primary_key_field_type == int:
dct['user_id'] = mapped_column(ForeignKey(getattr(user_model, user_primary_key_field_name)))
dct['set_user_id'] = lambda self, user_id: setattr(self, 'user_id', user_id)
elif user_primary_key_field_type == str:
dct['user_id'] = mapped_column(ForeignKey(getattr(user_model, user_primary_key_field_name)))
dct['set_user_id'] = lambda self, user_id: setattr(self, 'user_id', user_id)
elif user_primary_key_field_type == uuid.UUID:
dct['user_id'] = mapped_column(ForeignKey(getattr(user_model, user_primary_key_field_name)))
dct['set_user_id'] = lambda self, user_id: setattr(self, 'user_id', user_id)

dct['user'] = relationship(user_model, backref="notifications")
dct['get_user_id'] = lambda self: self.user_id
dct['get_user'] = lambda self: self.user
dct['__tablename__'] = "notifications"
dct['__tableargs__'] = {"extend_existing": True}

return super().__new__(cls, name, bases, dct)


class GenericNotification(
NotificationMixin,
Generic[UserType, UserPrimaryKeyType],
):
__abstract__ = True

user: Mapped[UserType]
user_id: Mapped[UserPrimaryKeyType]

class UserIdTypeMapping(Generic[UserPrimaryKeyType], metaclass=UserIdTypeMappingMetaClass):
pass


class GenericNotification(NotificationMixin, Generic[UserType, UserMappedPrimaryKeyType]):
def get_user(self) -> UserType:
raise NotImplementedError

def get_user_id(self) -> UserPrimaryKeyType:
raise NotImplementedError

def set_user_id(self, user_id: UserPrimaryKeyType):
def set_user_id(self, user_id: UserPrimaryKeyType) -> None:
raise NotImplementedError

def get_user(self) -> UserType:
raise NotImplementedError

@staticmethod
def get_user_id_attr_name() -> str:
return "user_id"

def create_notification_model(
user_model: type[UserType],
user_primary_key_field_name: str,
user_primary_key_field_type: UserPrimaryKeyType
) -> type[GenericNotification[UserType, UserIdTypeMapping[UserPrimaryKeyType]]]:
class Notification(NotificationMixin):
__tablename__ = "notifications"
__table_args__ = {"extend_existing": True}
user_id: Mapped[UserIdTypeMapping[UserPrimaryKeyType]] = mapped_column(
ForeignKey(getattr(user_model, user_primary_key_field_name)),
)
user: Mapped[UserType] = relationship(user_model, back_populates="notifications")

def get_user(self) -> UserType:
return self.user

def get_user_id(self):
return self.user_id

@staticmethod
def get_user_id_attr_name() -> str:
return "user_id"

@staticmethod
def get_user_attr_name() -> str:
return "user"

def set_user_id(self, user_id: UserIdTypeMapping[UserPrimaryKeyType]):
self.user_id = user_id

user_model.notifications = relationship(
Notification,
order_by=Notification.created,
back_populates="user",
)
return Notification
@staticmethod
def get_user_attr_name() -> str:
return "user"
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class SQLAlchemyNotificationBackend(Generic[NotificationModel], BaseNotification
notification_model_cls: "type[NotificationModel]"

def __init__(self, session: sessionmaker[Session], notification_model_cls: "type[NotificationModel]") -> None:
self.backend_kwargs = {
"session": session,
"notification_model_cls": notification_model_cls,
}
self.session_manager = session
self.notification_model_cls = (
notification_model_cls if notification_model_cls else self._get_notification_model_cls()
Expand Down

0 comments on commit ff7a92c

Please sign in to comment.