From 4504eb84a35a67d591b46188442cbddddd6edfa8 Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 8 Mar 2023 18:27:47 +0000 Subject: [PATCH 1/3] Use `assert`s instead of `cast`s where possible --- mypy/api.py | 5 +++-- mypy/checker.py | 27 ++++++++++++++++----------- mypy/checkexpr.py | 4 +++- mypy/checkmember.py | 6 ++++-- mypy/expandtype.py | 5 +++-- mypy/fastparse.py | 4 +++- mypy/nodes.py | 3 ++- mypy/report.py | 5 +++-- mypy/semanal.py | 6 ++++-- mypy/semanal_enum.py | 3 ++- mypy/semanal_namedtuple.py | 3 ++- mypy/server/astdiff.py | 6 ++++-- mypy/server/astmerge.py | 3 ++- mypy/stats.py | 8 +++++--- mypy/test/testfinegrained.py | 5 +++-- mypyc/irbuild/expression.py | 6 ++++-- 16 files changed, 63 insertions(+), 36 deletions(-) diff --git a/mypy/api.py b/mypy/api.py index 589bfbbfa1a7..612fd0442276 100644 --- a/mypy/api.py +++ b/mypy/api.py @@ -47,7 +47,7 @@ import sys from io import StringIO -from typing import Callable, TextIO, cast +from typing import Callable, TextIO def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]: @@ -59,7 +59,8 @@ def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int] main_wrapper(stdout, stderr) exit_status = 0 except SystemExit as system_exit: - exit_status = cast(int, system_exit.code) + assert isinstance(system_exit.code, int) + exit_status = system_exit.code return stdout.getvalue(), stderr.getvalue(), exit_status diff --git a/mypy/checker.py b/mypy/checker.py index bd762942da48..f0271002870d 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -629,7 +629,8 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: if defn.is_property: # HACK: Infer the type of the property. - self.visit_decorator(cast(Decorator, defn.items[0])) + assert isinstance(defn.items[0], Decorator) + self.visit_decorator(defn.items[0]) for fdef in defn.items: assert isinstance(fdef, Decorator) self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True) @@ -1753,7 +1754,8 @@ def expand_typevars( result: list[tuple[FuncItem, CallableType]] = [] for substitutions in itertools.product(*subst): mapping = dict(substitutions) - expanded = cast(CallableType, expand_type(typ, mapping)) + expanded = expand_type(typ, mapping) + assert isinstance(expanded, CallableType) result.append((expand_func(defn, mapping), expanded)) return result else: @@ -2576,9 +2578,8 @@ def check_import(self, node: ImportBase) -> None: if lvalue_type is None: # TODO: This is broken. lvalue_type = AnyType(TypeOfAny.special_form) - message = message_registry.INCOMPATIBLE_IMPORT_OF.format( - cast(NameExpr, assign.rvalue).name - ) + assert isinstance(assign.rvalue, NameExpr) + message = message_registry.INCOMPATIBLE_IMPORT_OF.format(assign.rvalue.name) self.check_simple_assignment( lvalue_type, assign.rvalue, @@ -3658,8 +3659,8 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var) ): if isinstance(lvalue, NameExpr): - inferred = cast(Var, lvalue.node) - assert isinstance(inferred, Var) + assert isinstance(lvalue.node, Var) + inferred = lvalue.node else: assert isinstance(lvalue, MemberExpr) self.expr_checker.accept(lvalue.expr) @@ -4985,7 +4986,8 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType # In order for this to work in incremental mode, the type we generate needs to # have a valid fullname and a corresponding entry in a symbol table. We generate # a unique name inside the symbol table of the current module. - cur_module = cast(MypyFile, self.scope.stack[0]) + cur_module = self.scope.stack[0] + assert isinstance(cur_module, MypyFile) gen_name = gen_unique_name(f"", cur_module.names) # Synthesize a fake TypeInfo @@ -6197,7 +6199,8 @@ def lookup(self, name: str) -> SymbolTableNode: else: b = self.globals.get("__builtins__", None) if b: - table = cast(MypyFile, b.node).names + assert isinstance(b.node, MypyFile) + table = b.node.names if name in table: return table[name] raise KeyError(f"Failed lookup: {name}") @@ -6211,7 +6214,8 @@ def lookup_qualified(self, name: str) -> SymbolTableNode: for i in range(1, len(parts) - 1): sym = n.names.get(parts[i]) assert sym is not None, "Internal error: attempted lookup of unknown name" - n = cast(MypyFile, sym.node) + assert isinstance(sym.node, MypyFile) + n = sym.node last = parts[-1] if last in n.names: return n.names[last] @@ -6515,7 +6519,8 @@ def is_writable_attribute(self, node: Node) -> bool: return False return True elif isinstance(node, OverloadedFuncDef) and node.is_property: - first_item = cast(Decorator, node.items[0]) + first_item = node.items[0] + assert isinstance(first_item, Decorator) return first_item.var.is_settable_property return False diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 38b5c2419d95..38a884c3b2a8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5518,7 +5518,9 @@ def merge_typevars_in_callables_by_name( variables.append(tv) rename[tv.id] = unique_typevars[name] - target = cast(CallableType, expand_type(target, rename)) + t = expand_type(target, rename) + assert isinstance(t, CallableType) + target = t output.append(target) return output, variables diff --git a/mypy/checkmember.py b/mypy/checkmember.py index a2c580e13446..0514ed0defdd 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -312,7 +312,8 @@ def analyze_instance_member_access( if method.is_property: assert isinstance(method, OverloadedFuncDef) - first_item = cast(Decorator, method.items[0]) + first_item = method.items[0] + assert isinstance(first_item, Decorator) return analyze_var(name, first_item.var, typ, info, mx) if mx.is_lvalue: mx.msg.cant_assign_to_method(mx.context) @@ -1150,7 +1151,8 @@ class B(A[str]): pass t = freshen_all_functions_type_vars(t) t = bind_self(t, original_type, is_classmethod=True) assert isuper is not None - t = cast(CallableType, expand_type_by_instance(t, isuper)) + t = expand_type_by_instance(t, isuper) + assert isinstance(t, CallableType) freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 7933283b24d6..841e44a8135b 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -133,8 +133,9 @@ def freshen_function_type_vars(callee: F) -> F: tv = ParamSpecType.new_unification_variable(v) tvs.append(tv) tvmap[v.id] = tv - fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs) - return cast(F, fresh) + expanded = expand_type(callee, tvmap) + assert isinstance(expanded, CallableType) + return cast(F, expanded.copy_modified(variables=tvs)) else: assert isinstance(callee, Overloaded) fresh_overload = Overloaded([freshen_function_type_vars(item) for item in callee.items]) diff --git a/mypy/fastparse.py b/mypy/fastparse.py index a993bd287f06..cc2b47cd6034 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -665,7 +665,9 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]: if current_overload and current_overload_name == last_if_stmt_overload_name: # Remove last stmt (IfStmt) from ret if the overload names matched # Only happens if no executable block had been found in IfStmt - skipped_if_stmts.append(cast(IfStmt, ret.pop())) + popped = ret.pop() + assert isinstance(popped, IfStmt) + skipped_if_stmts.append(popped) if current_overload and skipped_if_stmts: # Add bare IfStmt (without overloads) to ret # Required for mypy to be able to still check conditions diff --git a/mypy/nodes.py b/mypy/nodes.py index e4d8514ad6e2..1157e52ba29e 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2179,7 +2179,8 @@ def name(self) -> str: def expr(self) -> Expression: """Return the expression (the body) of the lambda.""" - ret = cast(ReturnStmt, self.body.body[-1]) + ret = self.body.body[-1] + assert isinstance(ret, ReturnStmt) expr = ret.expr assert expr is not None # lambda can't have empty body return expr diff --git a/mypy/report.py b/mypy/report.py index 75c372200ca3..2edd0957254e 100644 --- a/mypy/report.py +++ b/mypy/report.py @@ -12,7 +12,7 @@ import tokenize from abc import ABCMeta, abstractmethod from operator import attrgetter -from typing import Any, Callable, Dict, Iterator, Tuple, cast +from typing import Any, Callable, Dict, Iterator, Tuple from typing_extensions import Final, TypeAlias as _TypeAlias from urllib.request import pathname2url @@ -704,8 +704,9 @@ def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) memory_reporter = reports.add_report("memory-xml", "") + assert isinstance(memory_reporter, MemoryXmlReporter) # The dependency will be called first. - self.memory_xml = cast(MemoryXmlReporter, memory_reporter) + self.memory_xml = memory_reporter class XmlReporter(AbstractXmlReporter): diff --git a/mypy/semanal.py b/mypy/semanal.py index 2720d2606e92..50ff24cf1f2d 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1314,7 +1314,8 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) - """ defn.is_property = True items = defn.items - first_item = cast(Decorator, defn.items[0]) + first_item = defn.items[0] + assert isinstance(first_item, Decorator) deleted_items = [] for i, item in enumerate(items[1:]): if isinstance(item, Decorator): @@ -1357,7 +1358,8 @@ def analyze_function_body(self, defn: FuncItem) -> None: # Bind the type variables again to visit the body. if defn.type: a = self.type_analyzer() - typ = cast(CallableType, defn.type) + typ = defn.type + assert isinstance(typ, CallableType) a.bind_function_type_variables(typ, defn) for i in range(len(typ.arg_types)): store_argument_type(defn, i, typ, self.named_type) diff --git a/mypy/semanal_enum.py b/mypy/semanal_enum.py index c7b8e44f65aa..a76ba6869093 100644 --- a/mypy/semanal_enum.py +++ b/mypy/semanal_enum.py @@ -108,7 +108,8 @@ class A(enum.Enum): # Error. Construct dummy return value. info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line) else: - name = cast(StrExpr, call.args[0]).value + assert isinstance(call.args[0], StrExpr) + name = call.args[0].value if name != var_name or is_func_scope: # Give it a unique name derived from the line number. name += "@" + str(call.line) diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 1194557836b1..78b51ee012e7 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -373,7 +373,8 @@ def parse_namedtuple_args( if not isinstance(args[0], StrExpr): self.fail(f'"{type_name}()" expects a string literal as the first argument', call) return None - typename = cast(StrExpr, call.args[0]).value + assert isinstance(call.args[0], StrExpr) + typename = call.args[0].value types: list[Type] = [] tvar_defs = [] if not isinstance(args[1], (ListExpr, TupleExpr)): diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index c942a5eb3b0f..d9fed2dc969b 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from __future__ import annotations -from typing import Sequence, Tuple, Union, cast +from typing import Sequence, Tuple, Union from typing_extensions import TypeAlias as _TypeAlias from mypy.expandtype import expand_type @@ -442,7 +442,9 @@ def normalize_callable_variables(self, typ: CallableType) -> CallableType: tv = v.copy_modified(id=tid) tvs.append(tv) tvmap[v.id] = tv - return cast(CallableType, expand_type(typ, tvmap)).copy_modified(variables=tvs) + expanded = expand_type(typ, tvmap) + assert isinstance(expanded, CallableType) + return expanded.copy_modified(variables=tvs) def visit_tuple_type(self, typ: TupleType) -> SnapshotItem: return ("TupleType", snapshot_types(typ.items)) diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 1ec6d572a82c..0cc6377bfb0f 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -358,7 +358,8 @@ def fixup_and_reset_typeinfo(self, node: TypeInfo) -> TypeInfo: if node in self.replacements: # The subclass relationships may change, so reset all caches relevant to the # old MRO. - new = cast(TypeInfo, self.replacements[node]) + new = self.replacements[node] + assert isinstance(new, TypeInfo) type_state.reset_all_subtype_caches_for(new) return self.fixup(node) diff --git a/mypy/stats.py b/mypy/stats.py index b3a32c1ce72c..5f4b9d4d201f 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -5,7 +5,7 @@ import os from collections import Counter from contextlib import contextmanager -from typing import Iterator, cast +from typing import Iterator from typing_extensions import Final from mypy import nodes @@ -154,10 +154,12 @@ def visit_func_def(self, o: FuncDef) -> None: ) return for defn in o.expanded: - self.visit_func_def(cast(FuncDef, defn)) + assert isinstance(defn, FuncDef) + self.visit_func_def(defn) else: if o.type: - sig = cast(CallableType, o.type) + assert isinstance(o.type, CallableType) + sig = o.type arg_types = sig.arg_types if sig.arg_names and sig.arg_names[0] == "self" and not self.inferred: arg_types = arg_types[1:] diff --git a/mypy/test/testfinegrained.py b/mypy/test/testfinegrained.py index b19c49bf60bc..5b4c816b5c38 100644 --- a/mypy/test/testfinegrained.py +++ b/mypy/test/testfinegrained.py @@ -18,7 +18,7 @@ import re import sys import unittest -from typing import Any, cast +from typing import Any import pytest @@ -169,7 +169,8 @@ def get_options(self, source: str, testcase: DataDrivenTestCase, build_cache: bo def run_check(self, server: Server, sources: list[BuildSource]) -> list[str]: response = server.check(sources, export_types=True, is_tty=False, terminal_width=-1) - out = cast(str, response["out"] or response["err"]) + out = response["out"] or response["err"] + assert isinstance(out, str) return out.splitlines() def build(self, options: Options, sources: list[BuildSource]) -> list[str]: diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 5997bdbd0a43..0b230c465938 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Callable, Sequence, cast +from typing import Callable, Sequence from mypy.nodes import ( ARG_POS, @@ -704,7 +704,9 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: lhs = e.operands[0] mypy_file = builder.graph["builtins"].tree assert mypy_file is not None - bool_type = Instance(cast(TypeInfo, mypy_file.names["bool"].node), []) + info = mypy_file.names["bool"].node + assert isinstance(info, TypeInfo) + bool_type = Instance(info, []) exprs = [] for item in items: expr = ComparisonExpr([cmp_op], [lhs, item]) From 9f12cae24cef4c390faafb37330ef8f180874ccd Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 8 Mar 2023 19:20:16 +0000 Subject: [PATCH 2/3] Revert changes to `semanal_enum` --- mypy/semanal_enum.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/semanal_enum.py b/mypy/semanal_enum.py index a76ba6869093..c7b8e44f65aa 100644 --- a/mypy/semanal_enum.py +++ b/mypy/semanal_enum.py @@ -108,8 +108,7 @@ class A(enum.Enum): # Error. Construct dummy return value. info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line) else: - assert isinstance(call.args[0], StrExpr) - name = call.args[0].value + name = cast(StrExpr, call.args[0]).value if name != var_name or is_func_scope: # Give it a unique name derived from the line number. name += "@" + str(call.line) From 46914f9a8e2cc228c43bf8f1ec445e5d88901b0b Mon Sep 17 00:00:00 2001 From: AlexWaygood Date: Wed, 8 Mar 2023 19:23:45 +0000 Subject: [PATCH 3/3] Better solution in `semanal_namedtuple` --- mypy/semanal_namedtuple.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index 78b51ee012e7..cd5a5e494cd4 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -373,8 +373,7 @@ def parse_namedtuple_args( if not isinstance(args[0], StrExpr): self.fail(f'"{type_name}()" expects a string literal as the first argument', call) return None - assert isinstance(call.args[0], StrExpr) - typename = call.args[0].value + typename = args[0].value types: list[Type] = [] tvar_defs = [] if not isinstance(args[1], (ListExpr, TupleExpr)):