|
1 | 1 | """FastAPI Users database adapter for SQLAlchemy."""
|
2 | 2 | import uuid
|
3 |
| -from typing import Any, Dict, Generic, Optional, Type |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Type |
4 | 4 |
|
5 | 5 | from fastapi_users.db.base import BaseUserDatabase
|
6 | 6 | from fastapi_users.models import ID, OAP, UP
|
7 | 7 | from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, func, select
|
8 | 8 | from sqlalchemy.ext.asyncio import AsyncSession
|
9 |
| -from sqlalchemy.ext.declarative import declared_attr |
| 9 | +from sqlalchemy.orm import declarative_mixin, declared_attr |
10 | 10 | from sqlalchemy.sql import Select
|
11 | 11 |
|
12 | 12 | from fastapi_users_db_sqlalchemy.generics import GUID
|
|
16 | 16 | UUID_ID = uuid.UUID
|
17 | 17 |
|
18 | 18 |
|
| 19 | +@declarative_mixin |
19 | 20 | class SQLAlchemyBaseUserTable(Generic[ID]):
|
20 | 21 | """Base SQLAlchemy users table definition."""
|
21 | 22 |
|
22 | 23 | __tablename__ = "user"
|
23 | 24 |
|
24 |
| - id: ID |
25 |
| - email: str = Column(String(length=320), unique=True, index=True, nullable=False) |
26 |
| - hashed_password: str = Column(String(length=1024), nullable=False) |
27 |
| - is_active: bool = Column(Boolean, default=True, nullable=False) |
28 |
| - is_superuser: bool = Column(Boolean, default=False, nullable=False) |
29 |
| - is_verified: bool = Column(Boolean, default=False, nullable=False) |
30 |
| - |
31 |
| - |
| 25 | + if TYPE_CHECKING: # pragma: no cover |
| 26 | + id: ID |
| 27 | + email: str |
| 28 | + hashed_password: str |
| 29 | + is_active: bool |
| 30 | + is_superuser: bool |
| 31 | + is_verified: bool |
| 32 | + else: |
| 33 | + email: str = Column(String(length=320), unique=True, index=True, nullable=False) |
| 34 | + hashed_password: str = Column(String(length=1024), nullable=False) |
| 35 | + is_active: bool = Column(Boolean, default=True, nullable=False) |
| 36 | + is_superuser: bool = Column(Boolean, default=False, nullable=False) |
| 37 | + is_verified: bool = Column(Boolean, default=False, nullable=False) |
| 38 | + |
| 39 | + |
| 40 | +@declarative_mixin |
32 | 41 | class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]):
|
33 |
| - id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) |
| 42 | + if TYPE_CHECKING: # pragma: no cover |
| 43 | + id: UUID_ID |
| 44 | + else: |
| 45 | + id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) |
34 | 46 |
|
35 | 47 |
|
| 48 | +@declarative_mixin |
36 | 49 | class SQLAlchemyBaseOAuthAccountTable(Generic[ID]):
|
37 | 50 | """Base SQLAlchemy OAuth account table definition."""
|
38 | 51 |
|
39 | 52 | __tablename__ = "oauth_account"
|
40 | 53 |
|
41 |
| - id: ID |
42 |
| - oauth_name: str = Column(String(length=100), index=True, nullable=False) |
43 |
| - access_token: str = Column(String(length=1024), nullable=False) |
44 |
| - expires_at: Optional[int] = Column(Integer, nullable=True) |
45 |
| - refresh_token: Optional[str] = Column(String(length=1024), nullable=True) |
46 |
| - account_id: str = Column(String(length=320), index=True, nullable=False) |
47 |
| - account_email: str = Column(String(length=320), nullable=False) |
48 |
| - |
49 |
| - |
| 54 | + if TYPE_CHECKING: # pragma: no cover |
| 55 | + id: ID |
| 56 | + oauth_name: str |
| 57 | + access_token: str |
| 58 | + expires_at: Optional[int] |
| 59 | + refresh_token: Optional[str] |
| 60 | + account_id: str |
| 61 | + account_email: str |
| 62 | + else: |
| 63 | + oauth_name: str = Column(String(length=100), index=True, nullable=False) |
| 64 | + access_token: str = Column(String(length=1024), nullable=False) |
| 65 | + expires_at: Optional[int] = Column(Integer, nullable=True) |
| 66 | + refresh_token: Optional[str] = Column(String(length=1024), nullable=True) |
| 67 | + account_id: str = Column(String(length=320), index=True, nullable=False) |
| 68 | + account_email: str = Column(String(length=320), nullable=False) |
| 69 | + |
| 70 | + |
| 71 | +@declarative_mixin |
50 | 72 | class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]):
|
51 |
| - id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) |
| 73 | + if TYPE_CHECKING: # pragma: no cover |
| 74 | + id: UUID_ID |
| 75 | + else: |
| 76 | + id: UUID_ID = Column(GUID, primary_key=True, default=uuid.uuid4) |
52 | 77 |
|
53 | 78 | @declared_attr
|
54 |
| - def user_id(cls): |
| 79 | + def user_id(cls) -> Column[GUID]: |
55 | 80 | return Column(GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False)
|
56 | 81 |
|
57 | 82 |
|
|
0 commit comments