Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions stubs/PyMySQL/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# DictCursorMixin changes method types of inherited classes, but doesn't contain much at runtime
pymysql.cursors.DictCursorMixin.__iter__
pymysql.cursors.DictCursorMixin.fetch[a-z]*

# FIXME: new stubtest errors from mypy v1.18.1 that need to be looked at more closely.
# See https://github.com/python/typeshed/pull/14699
pymysql.connections.Connection.__init__
15 changes: 15 additions & 0 deletions stubs/PyMySQL/@tests/test_cases/check_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing_extensions import assert_type

from pymysql.connections import Connection
from pymysql.cursors import Cursor


class MyCursor(Cursor):
pass


assert_type(Connection(), Connection[Cursor])
assert_type(Connection(cursorclass=Cursor), Connection[Cursor])
assert_type(Connection(cursorclass=MyCursor), Connection[MyCursor])

Connection(cursorclass=None) # type: ignore
109 changes: 7 additions & 102 deletions stubs/PyMySQL/pymysql/connections.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ from _typeshed import FileDescriptorOrPath, Incomplete, Unused
from collections.abc import Callable, Mapping
from socket import _Address, socket as _socket
from ssl import SSLContext, _PasswordType
from typing import Any, AnyStr, Generic, TypeVar, overload
from typing_extensions import Self, deprecated
from typing import Any, AnyStr, Generic, overload
from typing_extensions import Self, TypeVar, deprecated

from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name
from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS
Expand All @@ -21,7 +21,7 @@ from .err import (
Warning,
)

_C = TypeVar("_C", bound=Cursor)
_C = TypeVar("_C", bound=Cursor, default=Cursor)
_C2 = TypeVar("_C2", bound=Cursor)

SSL_ENABLED: bool
Expand Down Expand Up @@ -61,7 +61,7 @@ class Connection(Generic[_C]):

@overload
def __init__(
self: Connection[Cursor], # different between overloads
self,
*,
user: str | bytes | None = None,
password: str | bytes = "",
Expand All @@ -76,7 +76,7 @@ class Connection(Generic[_C]):
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
use_unicode: bool = True,
client_flag: int = 0,
cursorclass: None = None, # different between overloads
cursorclass: type[_C] = ...,
init_command: str | None = None,
connect_timeout: float = 10,
read_default_group: str | None = None,
Expand Down Expand Up @@ -106,104 +106,9 @@ class Connection(Generic[_C]):
db: None = None, # deprecated
) -> None: ...
@overload
def __init__(
# different between overloads
self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780
*,
user: str | bytes | None = None,
password: str | bytes = "",
host: str | None = None,
database: str | bytes | None = None,
unix_socket: _Address | None = None,
port: int = 0,
charset: str = "",
collation: str | None = None,
sql_mode: str | None = None,
read_default_file: str | None = None,
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
use_unicode: bool = True,
client_flag: int = 0,
cursorclass: type[_C] = ..., # different between overloads
init_command: str | None = None,
connect_timeout: float = 10,
read_default_group: str | None = None,
autocommit: bool | None = False,
local_infile: bool = False,
max_allowed_packet: int = 16_777_216,
defer_connect: bool = False,
auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None,
read_timeout: float | None = None,
write_timeout: float | None = None,
bind_address: str | None = None,
binary_prefix: bool = False,
program_name: str | None = None,
server_public_key: bytes | None = None,
ssl: dict[str, Incomplete] | SSLContext | None = None,
ssl_ca: str | None = None,
ssl_cert: str | None = None,
ssl_disabled: bool | None = None,
ssl_key: str | None = None,
ssl_key_password: _PasswordType | None = None,
ssl_verify_cert: bool | None = None,
ssl_verify_identity: bool | None = None,
compress: Unused = None,
named_pipe: Unused = None,
# different between overloads:
passwd: None = None, # deprecated
db: None = None, # deprecated
) -> None: ...
@overload
@deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.")
def __init__(
self: Connection[Cursor], # different between overloads
*,
user: str | bytes | None = None,
password: str | bytes = "",
host: str | None = None,
database: str | bytes | None = None,
unix_socket: _Address | None = None,
port: int = 0,
charset: str = "",
collation: str | None = None,
sql_mode: str | None = None,
read_default_file: str | None = None,
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
use_unicode: bool = True,
client_flag: int = 0,
cursorclass: None = None, # different between overloads
init_command: str | None = None,
connect_timeout: float = 10,
read_default_group: str | None = None,
autocommit: bool | None = False,
local_infile: bool = False,
max_allowed_packet: int = 16_777_216,
defer_connect: bool = False,
auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None,
read_timeout: float | None = None,
write_timeout: float | None = None,
bind_address: str | None = None,
binary_prefix: bool = False,
program_name: str | None = None,
server_public_key: bytes | None = None,
ssl: dict[str, Incomplete] | SSLContext | None = None,
ssl_ca: str | None = None,
ssl_cert: str | None = None,
ssl_disabled: bool | None = None,
ssl_key: str | None = None,
ssl_key_password: _PasswordType | None = None,
ssl_verify_cert: bool | None = None,
ssl_verify_identity: bool | None = None,
compress: Unused = None,
named_pipe: Unused = None,
# different between overloads:
passwd: str | bytes | None = None, # deprecated
db: str | bytes | None = None, # deprecated
) -> None: ...
@overload
@deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.")
def __init__(
# different between overloads
self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780
self,
*,
user: str | bytes | None = None,
password: str | bytes = "",
Expand All @@ -218,7 +123,7 @@ class Connection(Generic[_C]):
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
use_unicode: bool = True,
client_flag: int = 0,
cursorclass: type[_C] = ..., # different between overloads
cursorclass: type[_C] = ...,
init_command: str | None = None,
connect_timeout: float = 10,
read_default_group: str | None = None,
Expand Down
Loading