-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add type annotations for peewee
#4776
Conversation
dargueta
commented
Nov 18, 2020
•
edited
Loading
edited
- AsIs
- AutoField
- BareField
- BigAutoField
- BigBitField
- BigIntegerField
- BinaryUUIDField
- BitField
- BlobField
- BooleanField
- Case
- Cast
- CharField
- Check
- chunked
- Column
- CompositeKey
- Context
- Database
- DatabaseError
- DatabaseProxy
- DataError
- DateField
- DateTimeField
- DecimalField
- DeferredForeignKey
- DeferredThroughModel
- DJANGO_MAP
- DoesNotExist
- DoubleField
- DQ
- EXCLUDED
- Field
- FixedCharField
- FloatField
- fn
- ForeignKeyField
- IdentityField
- ImproperlyConfigured
- Index
- IntegerField
- IntegrityError
- InterfaceError
- InternalError
- IPField
- JOIN
- ManyToManyField
- Model
- ModelIndex
- MySQLDatabase
- NotSupportedError
- OP
- OperationalError
- PostgresqlDatabase
- PrimaryKeyField (Deprecated)
- ProgrammingError
- Proxy
- QualifiedNames
- SchemaManager
- SmallIntegerField
- Select
- SQL
- SqliteDatabase
- Table
- TextField
- TimeField
- TimestampField
- Tuple
- UUIDField
- Value
- ValuesList
- Window
A few short hints:
|
You've probably seen #4262 already, but flagging it here just in case. Thanks for working on this! |
I haven't, actually. Thanks for pointing it out! |
Same heads up as for #4262: Next week we will reshuffle the layout of typeshed. This means that the files in this PR will also need to be moved. This shouldn't be much of a problem, though, as there should be no conflicts. |
Thanks for the warning. I'll be sure to do that! |
By the way @srittau I have a few problems with this code and was wondering if there are conventions for handling them. Most of these surround incompatible subclasses. "Overloaded function signatures 1 and 2 overlap with incompatible return types"Given code like this: def __eq__(self, other):
if isinstance(other, _HashableSource):
return other is self
return Expression(...) I've tried typing it like this: @overload
def __eq__(self, other: _HashableSource) -> bool: ...
@overload
def __eq__(self, other: object) -> Expression: ... MyPy's guidelines for overloading make it seem like this is the proper way to annotate this (specifically this bit) but I'm still getting an error from the linter. This seems like a common enough case that there should be a way to type the function accurately, right? "Return type Expression of '__eq__' incompatible with return type 'bool' in supertype 'object'"The above code also throws a second error that I'm not sure how to get around. This is an ORM so a non-boolean return value from Additional required argumentsThere are numerous cases where we have something like: class A:
def method(self, x: int, y: int) -> bool: ...
class B(A):
# Return type is the same but we have an additional required argument
def method(self, x: int, y: int, z: int) -> bool: ...
class C(A):
# Arguments are the same but the return type is incompatible
def method(self, x: int, y: int) -> OtherClass: ... The way I see it, there are two options:
Any thoughts? Sorry to dump all this on you at once. |
Both of there require a
|
@dargueta Any updates? |
I've been poking at it locally but haven't pushed anything up yet. |
Going to close this as stale for now to keep the list of open PRs manageable. If you're still interested in seeing this through, feel free to reopen or open a new PR. |
Actually, reopening since there's fairly recent activity. I'd suggest landing a minimal version first instead of trying to get everything fully annotated. |
Appreciate your work here @dargueta. One final failing check |
I'll try to get to it this weekend. |
Hej, does this patch solve the linter complaints? diff --git a/third_party/2and3/peewee.pyi b/third_party/2and3/peewee.pyi
index 0d2adef2..423c44d2 100644
--- a/third_party/2and3/peewee.pyi
+++ b/third_party/2and3/peewee.pyi
@@ -34,7 +34,7 @@ from typing import (
)
from typing_extensions import Literal, Protocol
-T = TypeVar("T")
+_T = TypeVar("_T")
_TModel = TypeVar("_TModel", bound="Model")
_TConvFunc = Callable[[Any], Any]
_TFunc = TypeVar("_TFunc", bound=Callable)
@@ -213,16 +213,16 @@ class _DynamicColumn:
@overload
def __get__(self, instance: None, instance_type: type) -> _DynamicColumn: ...
@overload
- def __get__(self, instance: T, instance_type: Type[T]) -> ColumnFactory: ...
+ def __get__(self, instance: _T, instance_type: Type[_T]) -> ColumnFactory: ...
class _ExplicitColumn:
@overload
def __get__(self, instance: None, instance_type: type) -> _ExplicitColumn: ...
@overload
- def __get__(self, instance: T, instance_type: Type[T]) -> NoReturn: ...
+ def __get__(self, instance: _T, instance_type: Type[_T]) -> NoReturn: ...
class _SupportsAlias(Protocol):
- def alias(self: T, name: str) -> T: ...
+ def alias(self: _T, name: str) -> _T: ...
class Source(_SupportsAlias, Node):
c: ClassVar[_DynamicColumn]
@@ -420,7 +420,7 @@ class _DynamicEntity:
@overload
def __get__(self, instance: None, instance_type: type) -> _DynamicEntity: ...
@overload
- def __get__(self, instance: T, instance_type: Type[T]) -> EntityFactory: ...
+ def __get__(self, instance: _T, instance_type: Type[_T]) -> EntityFactory: ...
class Alias(WrappedNode):
c: ClassVar[_DynamicEntity]
@@ -587,7 +587,7 @@ class ForUpdate(Node):
def __init__(
self,
expr: Union[Literal[True], str],
- of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...],]] = ...,
+ of: Optional[Union[_TModelOrTable, List[_TModelOrTable], Set[_TModelOrTable], Tuple[_TModelOrTable, ...], ]] = ...,
nowait: Optional[bool] = ...,
): ...
def __sql__(self, ctx: Context) -> Context: ...
@@ -632,7 +632,7 @@ def qualify_names(node: Expression) -> Expression: ...
@overload
def qualify_names(node: ColumnBase) -> QualifiedNames: ...
@overload
-def qualify_names(node: T) -> T: ...
+def qualify_names(node: _T) -> _T: ...
class OnConflict(Node):
@overload
@@ -1001,7 +1001,7 @@ class Database(_callable_context_manager):
def begin(self) -> None: ...
def commit(self) -> None: ...
def rollback(self) -> None: ...
- def batch_commit(self, it: Iterable[T], n: int) -> Iterator[T]: ...
+ def batch_commit(self, it: Iterable[_T], n: int) -> Iterator[_T]: ...
def table_exists(self, table_name: str, schema: Optional[str] = ...) -> str: ...
def get_tables(self, schema: Optional[str] = ...) -> List[str]: ...
def get_indexes(self, table: str, schema: Optional[str] = ...) -> List[IndexMetadata]: ...
@@ -1193,24 +1193,24 @@ class _savepoint(_callable_context_manager):
def __enter__(self) -> _savepoint: ...
def __exit__(self, exc_type: Type[Exception], exc_val: Exception, exc_tb: object) -> None: ...
-class CursorWrapper(Generic[T]):
+class CursorWrapper(Generic[_T]):
cursor: __ICursor
count: int
index: int
initialized: bool
populated: bool
- row_cache: List[T]
+ row_cache: List[_T]
def __init__(self, cursor: __ICursor): ...
- def __iter__(self) -> Union[ResultIterator[T], Iterator[T]]: ...
+ def __iter__(self) -> Union[ResultIterator[_T], Iterator[_T]]: ...
@overload
- def __getitem__(self, item: int) -> T: ...
+ def __getitem__(self, item: int) -> _T: ...
@overload
- def __getitem__(self, item: slice) -> List[T]: ...
+ def __getitem__(self, item: slice) -> List[_T]: ...
def __len__(self) -> int: ...
def initialize(self) -> None: ...
- def iterate(self, cache: bool = ...) -> T: ...
- def process_row(self, row: tuple) -> T: ...
- def iterator(self) -> Iterator[T]: ...
+ def iterate(self, cache: bool = ...) -> _T: ...
+ def process_row(self, row: tuple) -> _T: ...
+ def iterator(self) -> Iterator[_T]: ...
def fill_cache(self, n: int = ...) -> None: ...
class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ...
@@ -1219,16 +1219,16 @@ class DictCursorWrapper(CursorWrapper[Mapping[str, object]]): ...
class NamedTupleCursorWrapper(CursorWrapper[tuple]):
tuple_class: Type[tuple]
-class ObjectCursorWrapper(DictCursorWrapper[T]):
- constructor: Callable[..., T]
- def __init__(self, cursor: __ICursor, constructor: Callable[..., T]): ...
- def process_row(self, row: tuple) -> T: ... # type: ignore
+class ObjectCursorWrapper(DictCursorWrapper[_T]):
+ constructor: Callable[..., _T]
+ def __init__(self, cursor: __ICursor, constructor: Callable[..., _T]): ...
+ def process_row(self, row: tuple) -> _T: ... # type: ignore
-class ResultIterator(Generic[T]):
- cursor_wrapper: CursorWrapper[T]
+class ResultIterator(Generic[_T]):
+ cursor_wrapper: CursorWrapper[_T]
index: int
- def __init__(self, cursor_wrapper: CursorWrapper[T]): ...
- def __iter__(self) -> Iterator[T]: ...
+ def __init__(self, cursor_wrapper: CursorWrapper[_T]): ...
+ def __iter__(self) -> Iterator[_T]: ...
# FIELDS
@@ -1240,7 +1240,7 @@ class FieldAccessor:
@overload
def __get__(self, instance: None, instance_type: type) -> Field: ...
@overload
- def __get__(self, instance: T, instance_type: Type[T]) -> Any: ...
+ def __get__(self, instance: _T, instance_type: Type[_T]) -> Any: ...
class ForeignKeyAccessor(FieldAccessor):
model: Type[Model]
@@ -1324,9 +1324,9 @@ class Field(ColumnBase):
def bind(self, model: Type[Model], name: str, set_attribute: bool = ...) -> None: ...
@property
def column(self) -> Column: ...
- def adapt(self, value: T) -> T: ...
- def db_value(self, value: T) -> T: ...
- def python_value(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
+ def db_value(self, value: _T) -> _T: ...
+ def python_value(self, value: _T) -> _T: ...
def to_value(self, value: Any) -> Value: ...
def get_sort_key(self, ctx: Context) -> Tuple[int, int]: ...
def __sql__(self, ctx: Context) -> Context: ...
@@ -1338,7 +1338,7 @@ class IntegerField(Field):
@overload
def adapt(self, value: Union[str, float, bool]) -> int: ... # type: ignore
@overload
- def adapt(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
class BigIntegerField(IntegerField): ...
class SmallIntegerField(IntegerField): ...
@@ -1357,7 +1357,7 @@ class FloatField(Field):
@overload
def adapt(self, value: Union[str, float, bool]) -> float: ... # type: ignore
@overload
- def adapt(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
class DoubleField(FloatField): ...
@@ -1381,7 +1381,7 @@ class DecimalField(Field):
@overload
def db_value(self, value: Union[float, decimal.Decimal]) -> decimal.Decimal: ... # type: ignore
@overload
- def db_value(self, value: T) -> T: ...
+ def db_value(self, value: _T) -> _T: ...
@overload
def python_value(self, value: None) -> None: ...
@overload
@@ -1404,7 +1404,7 @@ class BlobField(Field):
@overload
def db_value(self, value: Union[str, bytes]) -> bytearray: ...
@overload
- def db_value(self, value: T) -> T: ...
+ def db_value(self, value: _T) -> _T: ...
class BitField(BitwiseMixin, BigIntegerField):
def __init__(self, *args: object, default: Optional[int] = ..., **kwargs: object): ...
@@ -1434,13 +1434,13 @@ class BigBitField(BlobField):
@overload
def db_value(self, value: None) -> None: ...
@overload
- def db_value(self, value: T) -> bytes: ...
+ def db_value(self, value: _T) -> bytes: ...
class UUIDField(Field):
@overload
def db_value(self, value: AnyStr) -> str: ...
@overload
- def db_value(self, value: T) -> T: ...
+ def db_value(self, value: _T) -> _T: ...
@overload
def python_value(self, value: Union[uuid.UUID, AnyStr]) -> uuid.UUID: ...
@overload
@@ -1458,7 +1458,7 @@ class BinaryUUIDField(BlobField):
def format_date_time(value: str, formats: Iterable[str], post_process: Optional[_TConvFunc] = ...) -> str: ...
@overload
-def simple_date_time(value: T) -> T: ...
+def simple_date_time(value: _T) -> _T: ...
class _BaseFormattedField(Field):
# TODO (dargueta): This is a class variable that can be overridden for instances
@@ -1478,7 +1478,7 @@ class DateTimeField(_BaseFormattedField):
def minute(self) -> int: ...
@property
def second(self) -> int: ...
- def adapt(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
def to_timestamp(self) -> Function: ...
def truncate(self, part: str) -> Function: ...
@@ -1492,7 +1492,7 @@ class DateField(_BaseFormattedField):
@overload
def adapt(self, value: datetime.datetime) -> datetime.date: ...
@overload
- def adapt(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
def to_timestamp(self) -> Function: ...
def truncate(self, part: str) -> Function: ...
@@ -1500,7 +1500,7 @@ class TimeField(_BaseFormattedField):
@overload
def adapt(self, value: Union[datetime.datetime, datetime.timedelta]) -> datetime.time: ...
@overload
- def adapt(self, value: T) -> T: ...
+ def adapt(self, value: _T) -> _T: ...
@property
def hour(self) -> int: ...
@property
@@ -1525,7 +1525,7 @@ class TimestampField(BigIntegerField):
@overload
def python_value(self, value: Union[int, float]) -> datetime.datetime: ...
@overload
- def python_value(self, value: T) -> T: ...
+ def python_value(self, value: _T) -> _T: ...
def from_timestamp(self) -> float: ...
@property
def year(self) -> int: ...
@@ -1632,12 +1632,12 @@ class ManyToManyFieldAccessor(FieldAccessor):
dest_fk: ForeignKeyField
def __init__(self, model: Type[Model], field: ForeignKeyField, name: str): ...
@overload
- def __get__(self, instance: None, instance_type: Type[T] = ..., force_query: bool = ...) -> Field: ...
+ def __get__(self, instance: None, instance_type: Type[_T] = ..., force_query: bool = ...) -> Field: ...
@overload
def __get__(
- self, instance: T, instance_type: Type[T] = ..., force_query: bool = ...
+ self, instance: _T, instance_type: Type[_T] = ..., force_query: bool = ...
) -> Union[List[str], ManyToManyQuery]: ...
- def __set__(self, instance: T, value) -> None: ...
+ def __set__(self, instance: _T, value) -> None: ...
class ManyToManyField(MetaField):
accessor_class: ClassVar[Type[ManyToManyFieldAccessor]]
@@ -1679,7 +1679,7 @@ class CompositeKey(MetaField):
@overload
def __get__(self, instance: None, instance_type: type) -> CompositeKey: ...
@overload
- def __get__(self, instance: T, instance_type: Type[T]) -> tuple: ...
+ def __get__(self, instance: _T, instance_type: Type[_T]) -> tuple: ...
def __set__(self, instance: Model, value: Union[list, tuple]) -> None: ...
def __eq__(self, other: Expression) -> Expression: ...
def __ne__(self, other: Expression) -> Expression: ...
@@ -1843,7 +1843,7 @@ class Model(Node, metaclass=ModelBase):
@classmethod
def delete(cls) -> ModelDelete: ...
@classmethod
- def create(cls: Type[T], **query) -> T: ...
+ def create(cls: Type[_T], **query) -> _T: ...
@classmethod
def bulk_create(cls, model_list: Iterable[Type[Model]], batch_size: Optional[int] = ...) -> None: ...
@classmethod
@@ -1874,7 +1874,7 @@ class Model(Node, metaclass=ModelBase):
@property
def dirty_fields(self) -> List[Field]: ...
def dependencies(self, search_nullable: bool = ...) -> Iterator[Tuple[Union[bool, Node], ForeignKeyField]]: ...
- def delete_instance(self: T, recursive: bool = ..., delete_nullable: bool = ...) -> T: ...
+ def delete_instance(self: _T, recursive: bool = ..., delete_nullable: bool = ...) -> _T: ...
def __hash__(self) -> int: ...
def __eq__(self, other: object) -> bool: ...
def __ne__(self, other: object) -> bool: ...
@@ -1898,8 +1898,7 @@ class Model(Node, metaclass=ModelBase):
@classmethod
def truncate_table(cls, **options: object) -> None: ...
@classmethod
- def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex:
- return ModelIndex(cls, fields, **kwargs)
+ def index(cls, *fields: Union[Field, Node, str], **kwargs: object) -> ModelIndex: ...
@classmethod
def add_index(cls, *fields: Union[str, SQL, Index], **kwargs: object) -> None: ... |
I'll check it out, thanks |
Never mind, I just did it manually. Hopefully I didn't miss anything. |
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
This comment has been minimized.
@dargueta, would you be able to look at the remaining mypy errors? In some cases the "overlapping overloads" errors can probably be |
We could split this up into multiple smaller PRs by first auto-generating peewee stubs (there are instructions in CONTRIBUTING.md), and then improving them by applying individual things from this PR. That might be easier than trying to get the whole CI to pass at once, but then dargueta likely won't show up as an author of most of the improvement commits. |
According to mypy_primer, this change has no effect on the checked open source code. 🤖🎉 |