Skip to content

Commit

Permalink
Improve typing
Browse files Browse the repository at this point in the history
Change-Id: I9fc86c4a92e1b76d19c9e891ff08ce8a46ad4e35
  • Loading branch information
CaselIT committed Sep 12, 2022
1 parent 747ec30 commit 0e83fdd
Show file tree
Hide file tree
Showing 18 changed files with 133 additions and 80 deletions.
6 changes: 3 additions & 3 deletions alembic/autogenerate/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,20 +523,20 @@ def _to_script(

def run_autogenerate(
self, rev: tuple, migration_context: "MigrationContext"
):
) -> None:
self._run_environment(rev, migration_context, True)

def run_no_autogenerate(
self, rev: tuple, migration_context: "MigrationContext"
):
) -> None:
self._run_environment(rev, migration_context, False)

def _run_environment(
self,
rev: tuple,
migration_context: "MigrationContext",
autogenerate: bool,
):
) -> None:
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
Expand Down
7 changes: 4 additions & 3 deletions alembic/autogenerate/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,8 +616,8 @@ def _compare_indexes_and_uniques(
# we know are either added implicitly by the DB or that the DB
# can't accurately report on
autogen_context.migration_context.impl.correct_for_autogen_constraints(
conn_uniques,
conn_indexes,
conn_uniques, # type: ignore[arg-type]
conn_indexes, # type: ignore[arg-type]
metadata_unique_constraints,
metadata_indexes,
)
Expand Down Expand Up @@ -1274,7 +1274,8 @@ def _compare_foreign_keys(
)

conn_fks = set(
_make_foreign_key(const, conn_table) for const in conn_fks_list
_make_foreign_key(const, conn_table) # type: ignore[arg-type]
for const in conn_fks_list
)

# give the dialect a chance to correct the FKs to match more
Expand Down
2 changes: 1 addition & 1 deletion alembic/autogenerate/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,7 +989,7 @@ def _fk_colspec(
if table_fullname in namespace_metadata.tables:
col = namespace_metadata.tables[table_fullname].c.get(colname)
if col is not None:
colname = _ident(col.name)
colname = _ident(col.name) # type: ignore[assignment]

colspec = "%s.%s" % (table_fullname, colname)

Expand Down
15 changes: 11 additions & 4 deletions alembic/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ from __future__ import annotations
from typing import Any
from typing import Callable
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Optional
from typing import TextIO
from typing import Tuple
Expand All @@ -13,6 +15,7 @@ from typing import Union

if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData

from .config import Config
Expand Down Expand Up @@ -530,7 +533,9 @@ def configure(
"""

def execute(sql, execution_options=None):
def execute(
sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
) -> None:
"""Execute the given SQL using the current change context.
The behavior of :meth:`.execute` is the same
Expand All @@ -543,7 +548,7 @@ def execute(sql, execution_options=None):
"""

def get_bind():
def get_bind() -> Connection:
"""Return the current 'bind'.
In "online" mode, this is the
Expand Down Expand Up @@ -635,7 +640,9 @@ def get_tag_argument() -> Optional[str]:
"""

def get_x_argument(as_dictionary: bool = False):
def get_x_argument(
as_dictionary: bool = False,
) -> Union[List[str], Dict[str, str]]:
"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
Expand Down Expand Up @@ -723,7 +730,7 @@ def run_migrations(**kw: Any) -> None:

script: ScriptDirectory

def static_output(text):
def static_output(text: str) -> None:
"""Emit text directly to the "offline" SQL stream.
Typically this is for emitting comments that
Expand Down
2 changes: 1 addition & 1 deletion alembic/ddl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def format_table_name(
def format_column_name(
compiler: "DDLCompiler", name: Optional[Union["quoted_name", str]]
) -> Union["quoted_name", str]:
return compiler.preparer.quote(name)
return compiler.preparer.quote(name) # type: ignore[arg-type]


def format_server_default(
Expand Down
45 changes: 23 additions & 22 deletions alembic/ddl/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@
from ..util import sqla_compat

if TYPE_CHECKING:
from io import StringIO
from typing import Literal
from typing import TextIO

from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.dml import Update
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
Expand All @@ -60,11 +57,11 @@ def __init__(
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls
_impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype


_impls: dict = {}
_impls: Dict[str, Type["DefaultImpl"]] = {}

Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])

Expand Down Expand Up @@ -98,7 +95,7 @@ def __init__(
connection: Optional["Connection"],
as_sql: bool,
transactional_ddl: Optional[bool],
output_buffer: Optional["StringIO"],
output_buffer: Optional["TextIO"],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
Expand All @@ -119,7 +116,7 @@ def __init__(
)

@classmethod
def get_by_dialect(cls, dialect: "Dialect") -> Any:
def get_by_dialect(cls, dialect: "Dialect") -> Type["DefaultImpl"]:
return _impls[dialect.name]

def static_output(self, text: str) -> None:
Expand Down Expand Up @@ -158,10 +155,10 @@ def bind(self) -> Optional["Connection"]:
def _exec(
self,
construct: Union["ClauseElement", str],
execution_options: None = None,
execution_options: Optional[dict] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, int] = util.immutabledict(),
) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
) -> Optional["CursorResult"]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
Expand All @@ -176,10 +173,11 @@ def _exec(
else:
compile_kw = {}

compiled = construct.compile(
dialect=self.dialect, **compile_kw # type: ignore[arg-type]
)
self.static_output(
str(construct.compile(dialect=self.dialect, **compile_kw))
.replace("\t", " ")
.strip()
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
Expand All @@ -192,11 +190,13 @@ def _exec(
assert isinstance(multiparams, tuple)
multiparams += (params,)

return conn.execute(construct, multiparams)
return conn.execute( # type: ignore[call-overload]
construct, multiparams
)

def execute(
self,
sql: Union["Update", "TextClause", str],
sql: Union["ClauseElement", str],
execution_options: None = None,
) -> None:
self._exec(sql, execution_options)
Expand Down Expand Up @@ -424,9 +424,6 @@ def bulk_insert(
)
)
else:
# work around http://www.sqlalchemy.org/trac/ticket/2461
if not hasattr(table, "_autoincrement_column"):
table._autoincrement_column = None
if rows:
if multiinsert:
self._exec(
Expand Down Expand Up @@ -572,7 +569,7 @@ def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
)

def render_ddl_sql_expr(
self, expr: "ClauseElement", is_server_default: bool = False, **kw
self, expr: "ClauseElement", is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
Expand All @@ -581,10 +578,14 @@ def render_ddl_sql_expr(
"""

compile_kw = dict(
compile_kwargs={"literal_binds": True, "include_table": False}
compile_kw = {
"compile_kwargs": {"literal_binds": True, "include_table": False}
}
return str(
expr.compile(
dialect=self.dialect, **compile_kw # type: ignore[arg-type]
)
)
return str(expr.compile(dialect=self.dialect, **compile_kw))

def _compat_autogen_column_reflect(
self, inspector: "Inspector"
Expand Down
7 changes: 2 additions & 5 deletions alembic/ddl/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
Expand Down Expand Up @@ -68,9 +67,7 @@ def __init__(self, *arg, **kw) -> None:
"mssql_batch_separator", self.batch_separator
)

def _exec(
self, construct: Any, *args, **kw
) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(MSSQLImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
Expand Down Expand Up @@ -359,7 +356,7 @@ def visit_column_nullable(
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type),
format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)

Expand Down
6 changes: 1 addition & 5 deletions alembic/ddl/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
Expand All @@ -26,7 +25,6 @@
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.cursor import LegacyCursorResult
from sqlalchemy.sql.schema import Column


Expand All @@ -48,9 +46,7 @@ def __init__(self, *arg, **kw) -> None:
"oracle_batch_separator", self.batch_separator
)

def _exec(
self, construct: Any, *args, **kw
) -> Optional[Union["LegacyCursorResult", "CursorResult"]]:
def _exec(self, construct: Any, *args, **kw) -> Optional["CursorResult"]:
result = super(OracleImpl, self)._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
Expand Down
11 changes: 7 additions & 4 deletions alembic/op.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ if TYPE_CHECKING:
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Computed
from sqlalchemy.sql.schema import Identity
from sqlalchemy.sql.schema import SchemaItem
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from sqlalchemy.util import immutabledict
Expand Down Expand Up @@ -94,7 +95,7 @@ def alter_column(
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
comment: Union[str, bool, None] = False,
comment: Union[str, Literal[False], None] = False,
server_default: Any = False,
new_column_name: Optional[str] = None,
type_: Union[TypeEngine, Type[TypeEngine], None] = None,
Expand Down Expand Up @@ -202,13 +203,13 @@ def batch_alter_table(
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
partial_reordering: Optional[tuple] = None,
copy_from: Optional["Table"] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = immutabledict({}),
reflect_args: Tuple[Any, ...] = (),
reflect_kwargs: Mapping[str, Any] = immutabledict({}),
naming_convention: Optional[Dict[str, str]] = None,
) -> Iterator["BatchOperations"]:
) -> Iterator[BatchOperations]:
"""Invoke a series of per-table migrations in batch.
Batch mode allows a series of operations specific to a table
Expand Down Expand Up @@ -667,7 +668,9 @@ def create_primary_key(
"""

def create_table(table_name: str, *columns, **kw: Any) -> Optional[Table]:
def create_table(
table_name: str, *columns: SchemaItem, **kw: Any
) -> Optional[Table]:
"""Issue a "create table" instruction using the current migration
context.
Expand Down
6 changes: 5 additions & 1 deletion alembic/operations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from .batch import BatchOperationsImpl
from .ops import MigrateOperation
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
from ..util.sqla_compat import _literal_bindparam

Expand Down Expand Up @@ -74,6 +75,7 @@ class Operations(util.ModuleClsProxy):
"""

impl: Union["DefaultImpl", "BatchOperationsImpl"]
_to_impl = util.Dispatcher()

def __init__(
Expand Down Expand Up @@ -492,7 +494,7 @@ def get_bind(self) -> "Connection":
In a SQL script context, this value is ``None``. [TODO: verify this]
"""
return self.migration_context.impl.bind
return self.migration_context.impl.bind # type: ignore[return-value]


class BatchOperations(Operations):
Expand All @@ -512,6 +514,8 @@ class BatchOperations(Operations):
"""

impl: "BatchOperationsImpl"

def _noop(self, operation):
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
Expand Down
Loading

0 comments on commit 0e83fdd

Please sign in to comment.