diff --git a/pyproject.toml b/pyproject.toml index e8d4bada1..734975f69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tortoise-orm" -version = "0.21.8" +version = "0.21.7" description = "Easy async ORM for python, built with relations in mind" authors = ["Andrey Bondar ", "Nickolas Grigoriadis ", "long2ice "] license = "Apache-2.0" diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 656a9f7d8..9cdb78d94 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -518,7 +518,7 @@ async def init( cls._inited = True @classmethod - def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None): + def _init_routers(cls, routers: Optional[List[Union[str, type]]] = None) -> None: from tortoise.router import router routers = routers or [] diff --git a/tortoise/connection.py b/tortoise/connection.py index c65e16220..2ac6ddf6e 100644 --- a/tortoise/connection.py +++ b/tortoise/connection.py @@ -24,7 +24,7 @@ def __init__(self) -> None: self._db_config: Optional["DBConfigType"] = None self._create_db: bool = False - async def _init(self, db_config: "DBConfigType", create_db: bool): + async def _init(self, db_config: "DBConfigType", create_db: bool) -> None: if self._db_config is None: self._db_config = db_config else: diff --git a/tortoise/converters.py b/tortoise/converters.py index ab8e1a2c3..628992c05 100644 --- a/tortoise/converters.py +++ b/tortoise/converters.py @@ -14,7 +14,7 @@ _escape_table[ord("'")] = "\\'" -def _escape_unicode(value: str, mapping=None): +def _escape_unicode(value: str, mapping=None) -> str: """escapes *value* without adding quote. Value should be unicode diff --git a/tortoise/exceptions.py b/tortoise/exceptions.py index 4cad456f8..edfef4ebf 100644 --- a/tortoise/exceptions.py +++ b/tortoise/exceptions.py @@ -83,12 +83,12 @@ class ObjectDoesNotExistError(OperationalError, KeyError): The DoesNotExist exception is raised when an item with the passed primary key does not exist """ - def __init__(self, model: "Type[Model]", pk_name: str, pk_val: Any): + def __init__(self, model: "Type[Model]", pk_name: str, pk_val: Any) -> None: self.model: "Type[Model]" = model self.pk_name: str = pk_name self.pk_val: Any = pk_val - def __str__(self): + def __str__(self) -> str: return f"{self.model.__name__} has no object with {self.pk_name}={self.pk_val}" diff --git a/tortoise/expressions.py b/tortoise/expressions.py index e43ebdee5..f7049e4be 100644 --- a/tortoise/expressions.py +++ b/tortoise/expressions.py @@ -73,7 +73,7 @@ class Value(Expression): Wrapper for a value that should be used as a term in a query. """ - def __init__(self, value: Any): + def __init__(self, value: Any) -> None: self.value = value def resolve(self, resolve_context: ResolveContext) -> ResolveResult: @@ -90,7 +90,7 @@ class Connector(Enum): class CombinedExpression(Expression): - def __init__(self, left: Expression, connector: Connector, right: Any): + def __init__(self, left: Expression, connector: Connector, right: Any) -> None: self.left = left self.connector = connector self.right: Expression @@ -124,7 +124,7 @@ class F(Expression): :param name: The name of the field to reference. """ - def __init__(self, name: str): + def __init__(self, name: str) -> None: self.name = name def resolve(self, resolve_context: ResolveContext) -> ResolveResult: @@ -169,43 +169,43 @@ def _combine(self, other: Any, connector: Connector, right_hand: bool) -> Combin return CombinedExpression(other, connector, self) return CombinedExpression(self, connector, other) - def __neg__(self): + def __neg__(self) -> CombinedExpression: return self._combine(-1, Connector.mul, False) - def __add__(self, other): + def __add__(self, other) -> CombinedExpression: return self._combine(other, Connector.add, False) - def __sub__(self, other): + def __sub__(self, other) -> CombinedExpression: return self._combine(other, Connector.sub, False) - def __mul__(self, other): + def __mul__(self, other) -> CombinedExpression: return self._combine(other, Connector.mul, False) - def __truediv__(self, other): + def __truediv__(self, other) -> CombinedExpression: return self._combine(other, Connector.div, False) - def __mod__(self, other): + def __mod__(self, other) -> CombinedExpression: return self._combine(other, Connector.mod, False) - def __pow__(self, other): + def __pow__(self, other) -> CombinedExpression: return self._combine(other, Connector.pow, False) - def __radd__(self, other): + def __radd__(self, other) -> CombinedExpression: return self._combine(other, Connector.add, True) - def __rsub__(self, other): + def __rsub__(self, other) -> CombinedExpression: return self._combine(other, Connector.sub, True) - def __rmul__(self, other): + def __rmul__(self, other) -> CombinedExpression: return self._combine(other, Connector.mul, True) - def __rtruediv__(self, other): + def __rtruediv__(self, other) -> CombinedExpression: return self._combine(other, Connector.div, True) - def __rmod__(self, other): + def __rmod__(self, other) -> CombinedExpression: return self._combine(other, Connector.mod, True) - def __rpow__(self, other): + def __rpow__(self, other) -> CombinedExpression: return self._combine(other, Connector.pow, True) @@ -519,7 +519,7 @@ def __init__( self.field_object: "Optional[Field]" = None self.default_values = default_values - def _get_function_field(self, field: Union[Term, str], *default_values): + def _get_function_field(self, field: Union[Term, str], *default_values) -> PypikaFunction: return self.database_func(field, *default_values) # type:ignore[arg-type] def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult: diff --git a/tortoise/fields/base.py b/tortoise/fields/base.py index 67972a179..966e0a266 100644 --- a/tortoise/fields/base.py +++ b/tortoise/fields/base.py @@ -52,7 +52,7 @@ class OnDelete(StrEnum): class _FieldMeta(type): # TODO: Require functions to return field instances instead of this hack - def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict): + def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict) -> type: if len(bases) > 1 and bases[0] is Field: # Instantiate class with only the 1st base class (should be Field) cls = type.__new__(mcs, name, (bases[0],), attrs) @@ -271,7 +271,7 @@ def to_python_value(self, value: Any) -> Any: value = self.field_type(value) # pylint: disable=E1102 return value - def validate(self, value: Any): + def validate(self, value: Any) -> None: """ Validate whether given value is valid diff --git a/tortoise/filters.py b/tortoise/filters.py index 2079e66ff..d5d9ea719 100644 --- a/tortoise/filters.py +++ b/tortoise/filters.py @@ -123,12 +123,12 @@ def contains(field: Term, value: str) -> Criterion: return Like(Cast(field, SqlTypes.VARCHAR), field.wrap_constant(f"%{escape_like(value)}%")) -def search(field: Term, value: str): +def search(field: Term, value: str) -> Any: # will be override in each executor pass -def posix_regex(field: Term, value: str): +def posix_regex(field: Term, value: str) -> Any: # Will be overridden in each executor raise NotImplementedError( "The postgres_posix_regex filter operator is not supported by your database backend" diff --git a/tortoise/indexes.py b/tortoise/indexes.py index bb8bb1611..ffab65da2 100644 --- a/tortoise/indexes.py +++ b/tortoise/indexes.py @@ -38,7 +38,7 @@ def __init__( self.expressions = expressions self.extra = "" - def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool): + def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", safe: bool) -> str: if self.fields: fields = ", ".join(schema_generator.quote(f) for f in self.fields) else: @@ -54,7 +54,7 @@ def get_sql(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]", extra=self.extra, ) - def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]"): + def index_name(self, schema_generator: "BaseSchemaGenerator", model: "Type[Model]") -> str: return self.name or schema_generator._generate_index_name("idx", model, self.fields) diff --git a/tortoise/manager.py b/tortoise/manager.py index e94d369b6..bb2c948be 100644 --- a/tortoise/manager.py +++ b/tortoise/manager.py @@ -1,3 +1,5 @@ +from typing import Any + from tortoise.queryset import QuerySet @@ -14,5 +16,5 @@ def __init__(self, model=None) -> None: def get_queryset(self) -> QuerySet: return QuerySet(self._model) - def __getattr__(self, item): + def __getattr__(self, item: str) -> Any: return getattr(self.get_queryset(), item) diff --git a/tortoise/models.py b/tortoise/models.py index 7eb7ee8f1..ca5c98aa5 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -487,7 +487,7 @@ def _generate_filters(self) -> None: class ModelMeta(type): __slots__ = () - def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict): + def __new__(mcs, name: str, bases: Tuple[Type, ...], attrs: dict) -> "ModelMeta": fields_db_projection: Dict[str, str] = {} fields_map: Dict[str, Field] = {} filters: Dict[str, FilterInfoDict] = {} @@ -676,7 +676,7 @@ def __init__(self, **kwargs: Any) -> None: else: setattr(self, key, deepcopy(field_object.default)) - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: # set field value override async default function if hasattr(self, "_await_when_save"): self._await_when_save.pop(key, None) @@ -782,7 +782,7 @@ def __hash__(self) -> int: raise TypeError("Model instances without id are unhashable") return hash(self.pk) - def __iter__(self): + def __iter__(self) -> Iterable[Tuple]: for field in self._meta.db_fields: yield field, getattr(self, field) @@ -850,7 +850,7 @@ def update_from_dict(self: MODEL, data: dict) -> MODEL: return self @classmethod - def register_listener(cls, signal: Signals, listener: Callable): + def register_listener(cls, signal: Signals, listener: Callable) -> None: """ Register listener to current model class for special Signal. @@ -1020,7 +1020,7 @@ async def refresh_from_db( setattr(self, field, getattr(obj, field, None)) @classmethod - def _choose_db(cls, for_write: bool = False): + def _choose_db(cls, for_write: bool = False) -> BaseDBAsyncClient: """ Return the connection that will be used if this query is executed now. diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 8e70ad8fe..31f3e5d73 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -164,7 +164,7 @@ def _join_table_by_field( self._join_table(join) return joins[-1][0] - def _join_table(self, table_criterio_tuple: TableCriterionTuple): + def _join_table(self, table_criterio_tuple: TableCriterionTuple) -> None: if table_criterio_tuple[0] not in self._joined_tables: self.query = self.query.join(table_criterio_tuple[0], how=JoinType.left_outer).on( table_criterio_tuple[1] @@ -1491,7 +1491,7 @@ def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable: raise FieldError(f'Unknown field "{field}" for model "{model}"') - def _resolve_group_bys(self, *field_names: str): + def _resolve_group_bys(self, *field_names: str) -> List: group_bys = [] for field_name in field_names: if field_name in self._annotations: @@ -1777,7 +1777,7 @@ async def _execute(self) -> Union[List[dict], Dict]: class RawSQLQuery(AwaitableQuery): __slots__ = ("_sql", "_db") - def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str): + def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None: super().__init__(model) self._sql = sql self._db = db diff --git a/tortoise/router.py b/tortoise/router.py index 7ce5f8503..a2e4ab5d6 100644 --- a/tortoise/router.py +++ b/tortoise/router.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, List, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Type from tortoise.connection import connections from tortoise.exceptions import ConfigurationError @@ -9,12 +11,12 @@ class ConnectionRouter: def __init__(self) -> None: - self._routers: List[type] = None # type: ignore + self._routers: list[type] = None # type: ignore - def init_routers(self, routers: List[type]): + def init_routers(self, routers: list[Callable]) -> None: self._routers = [r() for r in routers] - def _router_func(self, model: Type["Model"], action: str): + def _router_func(self, model: Type["Model"], action: str) -> Any: for r in self._routers: try: method = getattr(r, action) @@ -26,16 +28,16 @@ def _router_func(self, model: Type["Model"], action: str): if chosen_db: return chosen_db - def _db_route(self, model: Type["Model"], action: str): + def _db_route(self, model: Type["Model"], action: str) -> "BaseDBAsyncClient" | None: try: return connections.get(self._router_func(model, action)) except ConfigurationError: return None - def db_for_read(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: + def db_for_read(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None: return self._db_route(model, "db_for_read") - def db_for_write(self, model: Type["Model"]) -> Optional["BaseDBAsyncClient"]: + def db_for_write(self, model: Type["Model"]) -> "BaseDBAsyncClient" | None: return self._db_route(model, "db_for_write") diff --git a/tortoise/signals.py b/tortoise/signals.py index 0b4de29bf..07c00ab79 100644 --- a/tortoise/signals.py +++ b/tortoise/signals.py @@ -1,17 +1,19 @@ from enum import Enum -from typing import Callable +from typing import Callable, TypeVar +T = TypeVar("T") +FuncType = Callable[[T], T] Signals = Enum("Signals", ["pre_save", "post_save", "pre_delete", "post_delete"]) -def post_save(*senders) -> Callable: +def post_save(*senders) -> FuncType: """ Register given models post_save signal. :param senders: Model class """ - def decorator(f): + def decorator(f: T) -> T: for sender in senders: sender.register_listener(Signals.post_save, f) return f @@ -19,14 +21,14 @@ def decorator(f): return decorator -def pre_save(*senders) -> Callable: +def pre_save(*senders) -> FuncType: """ Register given models pre_save signal. :param senders: Model class """ - def decorator(f): + def decorator(f: T) -> T: for sender in senders: sender.register_listener(Signals.pre_save, f) return f @@ -34,14 +36,14 @@ def decorator(f): return decorator -def pre_delete(*senders) -> Callable: +def pre_delete(*senders) -> FuncType: """ Register given models pre_delete signal. :param senders: Model class """ - def decorator(f): + def decorator(f: T) -> T: for sender in senders: sender.register_listener(Signals.pre_delete, f) return f @@ -49,14 +51,14 @@ def decorator(f): return decorator -def post_delete(*senders) -> Callable: +def post_delete(*senders) -> FuncType: """ Register given models post_delete signal. :param senders: Model class """ - def decorator(f): + def decorator(f: T) -> T: for sender in senders: sender.register_listener(Signals.post_delete, f) return f diff --git a/tortoise/transactions.py b/tortoise/transactions.py index 8a1d92b88..9c2180184 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast from tortoise import connections from tortoise.exceptions import ParamsError @@ -7,7 +7,8 @@ if TYPE_CHECKING: # pragma: nocoverage from tortoise.backends.base.client import BaseDBAsyncClient, TransactionContext -FuncType = Callable[..., Any] +T = TypeVar("T") +FuncType = Callable[..., T] F = TypeVar("F", bound=FuncType) @@ -52,7 +53,7 @@ def atomic(connection_name: Optional[str] = None) -> Callable[[F], F]: def wrapper(func: F) -> F: @wraps(func) - async def wrapped(*args, **kwargs): + async def wrapped(*args, **kwargs) -> T: async with in_transaction(connection_name): return await func(*args, **kwargs) diff --git a/tortoise/validators.py b/tortoise/validators.py index 4ffbd45ca..2fa83350a 100644 --- a/tortoise/validators.py +++ b/tortoise/validators.py @@ -1,15 +1,17 @@ +from __future__ import annotations + import abc import ipaddress import re from decimal import Decimal -from typing import Any, Union +from typing import Any from tortoise.exceptions import ValidationError class Validator(metaclass=abc.ABCMeta): @abc.abstractmethod - def __call__(self, value: Any): + def __call__(self, value: Any) -> None: """ All specific validators should implement this method. @@ -23,10 +25,10 @@ class RegexValidator(Validator): A validator to validate the given value whether match regex or not. """ - def __init__(self, pattern: str, flags: Union[int, re.RegexFlag]): + def __init__(self, pattern: str, flags: int | re.RegexFlag) -> None: self.regex = re.compile(pattern, flags) - def __call__(self, value: Any): + def __call__(self, value: Any) -> None: if not self.regex.match(value): raise ValidationError(f"Value '{value}' does not match regex '{self.regex.pattern}'") @@ -36,10 +38,10 @@ class MaxLengthValidator(Validator): A validator to validate the length of given value whether greater than max_length or not. """ - def __init__(self, max_length: int): + def __init__(self, max_length: int) -> None: self.max_length = max_length - def __call__(self, value: str): + def __call__(self, value: str) -> None: if value is None: raise ValidationError("Value must not be None") if len(value) > self.max_length: @@ -51,48 +53,50 @@ class MinLengthValidator(Validator): A validator to validate the length of given value whether less than min_length or not. """ - def __init__(self, min_length: int): + def __init__(self, min_length: int) -> None: self.min_length = min_length - def __call__(self, value: str): + def __call__(self, value: str) -> None: if value is None: raise ValidationError("Value must not be None") if len(value) < self.min_length: raise ValidationError(f"Length of '{value}' {len(value)} < {self.min_length}") -class MinValueValidator(Validator): +class NumericValidator(Validator): + types = (int, float, Decimal) + + def _validate_type(self, value: Any) -> None: + if not isinstance(value, self.types): + raise ValidationError("Value must be a numeric value and is required") + + +class MinValueValidator(NumericValidator): """ Min value validator for FloatField, IntField, SmallIntField, BigIntField """ - def __init__(self, min_value: Union[int, float, Decimal]): - if not isinstance(min_value, (int, float, Decimal)): - raise ValidationError("Value must be a numeric value and is required") + def __init__(self, min_value: int | float | Decimal) -> None: + self._validate_type(min_value) self.min_value = min_value - def __call__(self, value: Union[int, float, Decimal]): - if not isinstance(value, (int, float, Decimal)): - raise ValidationError("Value must be a numeric value and is required") - + def __call__(self, value: int | float | Decimal) -> None: + self._validate_type(value) if value < self.min_value: raise ValidationError(f"Value should be greater or equal to {self.min_value}") -class MaxValueValidator(Validator): +class MaxValueValidator(NumericValidator): """ Max value validator for FloatField, IntField, SmallIntField, BigIntField """ - def __init__(self, max_value: Union[int, float, Decimal]): - if not isinstance(max_value, (int, float, Decimal)): - raise ValidationError("Value must be a numeric value and is required") + def __init__(self, max_value: int | float | Decimal) -> None: + self._validate_type(max_value) self.max_value = max_value - def __call__(self, value: Union[int, float, Decimal]): - if not isinstance(value, (int, float, Decimal)): - raise ValidationError("Value must be a numeric value and is required") - + def __call__(self, value: int | float | Decimal) -> None: + self._validate_type(value) if value > self.max_value: raise ValidationError(f"Value should be less or equal to {self.max_value}") @@ -102,18 +106,18 @@ class CommaSeparatedIntegerListValidator(Validator): A validator to validate whether the given value is valid comma separated integer list or not. """ - def __init__(self, allow_negative: bool = False): + def __init__(self, allow_negative: bool = False) -> None: pattern = r"^%(neg)s\d+(?:%(sep)s%(neg)s\d+)*\Z" % { "neg": "(-)?" if allow_negative else "", "sep": re.escape(","), } self.regex = RegexValidator(pattern, re.I) - def __call__(self, value: str): + def __call__(self, value: str) -> None: self.regex(value) -def validate_ipv4_address(value: Any): +def validate_ipv4_address(value: Any) -> None: """ A validator to validate whether the given value is valid IPv4Address or not. @@ -126,7 +130,7 @@ def validate_ipv4_address(value: Any): raise ValidationError(f"'{value}' is not a valid IPv4 address.") -def validate_ipv6_address(value: Any): +def validate_ipv6_address(value: Any) -> None: """ A validator to validate whether the given value is valid IPv6Address or not. @@ -138,7 +142,7 @@ def validate_ipv6_address(value: Any): raise ValidationError(f"'{value}' is not a valid IPv6 address.") -def validate_ipv46_address(value: Any): +def validate_ipv46_address(value: Any) -> None: """ A validator to validate whether the given value is valid IPv4Address or IPv6Address or not.