Skip to content

Commit 2081b9d

Browse files
[PyMySQL] Fix stubs for pymysql.connections.Connection.__init__ (#14724)
1 parent e02a247 commit 2081b9d

File tree

3 files changed

+22
-106
lines changed

3 files changed

+22
-106
lines changed
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
# DictCursorMixin changes method types of inherited classes, but doesn't contain much at runtime
22
pymysql.cursors.DictCursorMixin.__iter__
33
pymysql.cursors.DictCursorMixin.fetch[a-z]*
4-
5-
# FIXME: new stubtest errors from mypy v1.18.1 that need to be looked at more closely.
6-
# See https://github.com/python/typeshed/pull/14699
7-
pymysql.connections.Connection.__init__
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from typing_extensions import assert_type
2+
3+
from pymysql.connections import Connection
4+
from pymysql.cursors import Cursor
5+
6+
7+
class MyCursor(Cursor):
8+
pass
9+
10+
11+
assert_type(Connection(), Connection[Cursor])
12+
assert_type(Connection(cursorclass=Cursor), Connection[Cursor])
13+
assert_type(Connection(cursorclass=MyCursor), Connection[MyCursor])
14+
15+
Connection(cursorclass=None) # type: ignore

stubs/PyMySQL/pymysql/connections.pyi

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ from _typeshed import FileDescriptorOrPath, Incomplete, Unused
22
from collections.abc import Callable, Mapping
33
from socket import _Address, socket as _socket
44
from ssl import SSLContext, _PasswordType
5-
from typing import Any, AnyStr, Generic, TypeVar, overload
6-
from typing_extensions import Self, deprecated
5+
from typing import Any, AnyStr, Generic, overload
6+
from typing_extensions import Self, TypeVar, deprecated
77

88
from .charset import charset_by_id as charset_by_id, charset_by_name as charset_by_name
99
from .constants import CLIENT as CLIENT, COMMAND as COMMAND, FIELD_TYPE as FIELD_TYPE, SERVER_STATUS as SERVER_STATUS
@@ -21,7 +21,7 @@ from .err import (
2121
Warning,
2222
)
2323

24-
_C = TypeVar("_C", bound=Cursor)
24+
_C = TypeVar("_C", bound=Cursor, default=Cursor)
2525
_C2 = TypeVar("_C2", bound=Cursor)
2626

2727
SSL_ENABLED: bool
@@ -61,7 +61,7 @@ class Connection(Generic[_C]):
6161

6262
@overload
6363
def __init__(
64-
self: Connection[Cursor], # different between overloads
64+
self,
6565
*,
6666
user: str | bytes | None = None,
6767
password: str | bytes = "",
@@ -76,7 +76,7 @@ class Connection(Generic[_C]):
7676
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
7777
use_unicode: bool = True,
7878
client_flag: int = 0,
79-
cursorclass: None = None, # different between overloads
79+
cursorclass: type[_C] = ...,
8080
init_command: str | None = None,
8181
connect_timeout: float = 10,
8282
read_default_group: str | None = None,
@@ -106,104 +106,9 @@ class Connection(Generic[_C]):
106106
db: None = None, # deprecated
107107
) -> None: ...
108108
@overload
109-
def __init__(
110-
# different between overloads
111-
self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780
112-
*,
113-
user: str | bytes | None = None,
114-
password: str | bytes = "",
115-
host: str | None = None,
116-
database: str | bytes | None = None,
117-
unix_socket: _Address | None = None,
118-
port: int = 0,
119-
charset: str = "",
120-
collation: str | None = None,
121-
sql_mode: str | None = None,
122-
read_default_file: str | None = None,
123-
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
124-
use_unicode: bool = True,
125-
client_flag: int = 0,
126-
cursorclass: type[_C] = ..., # different between overloads
127-
init_command: str | None = None,
128-
connect_timeout: float = 10,
129-
read_default_group: str | None = None,
130-
autocommit: bool | None = False,
131-
local_infile: bool = False,
132-
max_allowed_packet: int = 16_777_216,
133-
defer_connect: bool = False,
134-
auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None,
135-
read_timeout: float | None = None,
136-
write_timeout: float | None = None,
137-
bind_address: str | None = None,
138-
binary_prefix: bool = False,
139-
program_name: str | None = None,
140-
server_public_key: bytes | None = None,
141-
ssl: dict[str, Incomplete] | SSLContext | None = None,
142-
ssl_ca: str | None = None,
143-
ssl_cert: str | None = None,
144-
ssl_disabled: bool | None = None,
145-
ssl_key: str | None = None,
146-
ssl_key_password: _PasswordType | None = None,
147-
ssl_verify_cert: bool | None = None,
148-
ssl_verify_identity: bool | None = None,
149-
compress: Unused = None,
150-
named_pipe: Unused = None,
151-
# different between overloads:
152-
passwd: None = None, # deprecated
153-
db: None = None, # deprecated
154-
) -> None: ...
155-
@overload
156-
@deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.")
157-
def __init__(
158-
self: Connection[Cursor], # different between overloads
159-
*,
160-
user: str | bytes | None = None,
161-
password: str | bytes = "",
162-
host: str | None = None,
163-
database: str | bytes | None = None,
164-
unix_socket: _Address | None = None,
165-
port: int = 0,
166-
charset: str = "",
167-
collation: str | None = None,
168-
sql_mode: str | None = None,
169-
read_default_file: str | None = None,
170-
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
171-
use_unicode: bool = True,
172-
client_flag: int = 0,
173-
cursorclass: None = None, # different between overloads
174-
init_command: str | None = None,
175-
connect_timeout: float = 10,
176-
read_default_group: str | None = None,
177-
autocommit: bool | None = False,
178-
local_infile: bool = False,
179-
max_allowed_packet: int = 16_777_216,
180-
defer_connect: bool = False,
181-
auth_plugin_map: dict[str, Callable[[Connection[Any]], Any]] | None = None,
182-
read_timeout: float | None = None,
183-
write_timeout: float | None = None,
184-
bind_address: str | None = None,
185-
binary_prefix: bool = False,
186-
program_name: str | None = None,
187-
server_public_key: bytes | None = None,
188-
ssl: dict[str, Incomplete] | SSLContext | None = None,
189-
ssl_ca: str | None = None,
190-
ssl_cert: str | None = None,
191-
ssl_disabled: bool | None = None,
192-
ssl_key: str | None = None,
193-
ssl_key_password: _PasswordType | None = None,
194-
ssl_verify_cert: bool | None = None,
195-
ssl_verify_identity: bool | None = None,
196-
compress: Unused = None,
197-
named_pipe: Unused = None,
198-
# different between overloads:
199-
passwd: str | bytes | None = None, # deprecated
200-
db: str | bytes | None = None, # deprecated
201-
) -> None: ...
202-
@overload
203109
@deprecated("'passwd' and 'db' arguments are deprecated. Use 'password' and 'database' instead.")
204110
def __init__(
205-
# different between overloads
206-
self: Connection[_C], # pyright: ignore[reportInvalidTypeVarUse] #11780
111+
self,
207112
*,
208113
user: str | bytes | None = None,
209114
password: str | bytes = "",
@@ -218,7 +123,7 @@ class Connection(Generic[_C]):
218123
conv: dict[int | type[Any], Callable[[Any], str] | Callable[[str], Any]] | None = None,
219124
use_unicode: bool = True,
220125
client_flag: int = 0,
221-
cursorclass: type[_C] = ..., # different between overloads
126+
cursorclass: type[_C] = ...,
222127
init_command: str | None = None,
223128
connect_timeout: float = 10,
224129
read_default_group: str | None = None,

0 commit comments

Comments
 (0)