Skip to content

Commit

Permalink
Use asserts instead of casts where possible (#14860)
Browse files Browse the repository at this point in the history
There are many places in mypy's code where `cast`s are currently used
unnecessarily. These can be replaced with `assert`s, which are much more
type-safe, and more mypyc-friendly.
  • Loading branch information
AlexWaygood authored Mar 11, 2023
1 parent 4b3722f commit fddd5c5
Show file tree
Hide file tree
Showing 11 changed files with 45 additions and 27 deletions.
5 changes: 3 additions & 2 deletions mypy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

import sys
from io import StringIO
from typing import Callable, TextIO, cast
from typing import Callable, TextIO


def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]:
Expand All @@ -59,7 +59,8 @@ def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]
main_wrapper(stdout, stderr)
exit_status = 0
except SystemExit as system_exit:
exit_status = cast(int, system_exit.code)
assert isinstance(system_exit.code, int)
exit_status = system_exit.code

return stdout.getvalue(), stderr.getvalue(), exit_status

Expand Down
24 changes: 14 additions & 10 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,8 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:

if defn.is_property:
# HACK: Infer the type of the property.
self.visit_decorator(cast(Decorator, defn.items[0]))
assert isinstance(defn.items[0], Decorator)
self.visit_decorator(defn.items[0])
for fdef in defn.items:
assert isinstance(fdef, Decorator)
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
Expand Down Expand Up @@ -2575,9 +2576,8 @@ def check_import(self, node: ImportBase) -> None:
if lvalue_type is None:
# TODO: This is broken.
lvalue_type = AnyType(TypeOfAny.special_form)
message = message_registry.INCOMPATIBLE_IMPORT_OF.format(
cast(NameExpr, assign.rvalue).name
)
assert isinstance(assign.rvalue, NameExpr)
message = message_registry.INCOMPATIBLE_IMPORT_OF.format(assign.rvalue.name)
self.check_simple_assignment(
lvalue_type,
assign.rvalue,
Expand Down Expand Up @@ -3657,8 +3657,8 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V
not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var)
):
if isinstance(lvalue, NameExpr):
inferred = cast(Var, lvalue.node)
assert isinstance(inferred, Var)
assert isinstance(lvalue.node, Var)
inferred = lvalue.node
else:
assert isinstance(lvalue, MemberExpr)
self.expr_checker.accept(lvalue.expr)
Expand Down Expand Up @@ -4984,7 +4984,8 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType
# In order for this to work in incremental mode, the type we generate needs to
# have a valid fullname and a corresponding entry in a symbol table. We generate
# a unique name inside the symbol table of the current module.
cur_module = cast(MypyFile, self.scope.stack[0])
cur_module = self.scope.stack[0]
assert isinstance(cur_module, MypyFile)
gen_name = gen_unique_name(f"<callable subtype of {typ.type.name}>", cur_module.names)

# Synthesize a fake TypeInfo
Expand Down Expand Up @@ -6196,7 +6197,8 @@ def lookup(self, name: str) -> SymbolTableNode:
else:
b = self.globals.get("__builtins__", None)
if b:
table = cast(MypyFile, b.node).names
assert isinstance(b.node, MypyFile)
table = b.node.names
if name in table:
return table[name]
raise KeyError(f"Failed lookup: {name}")
Expand All @@ -6210,7 +6212,8 @@ def lookup_qualified(self, name: str) -> SymbolTableNode:
for i in range(1, len(parts) - 1):
sym = n.names.get(parts[i])
assert sym is not None, "Internal error: attempted lookup of unknown name"
n = cast(MypyFile, sym.node)
assert isinstance(sym.node, MypyFile)
n = sym.node
last = parts[-1]
if last in n.names:
return n.names[last]
Expand Down Expand Up @@ -6514,7 +6517,8 @@ def is_writable_attribute(self, node: Node) -> bool:
return False
return True
elif isinstance(node, OverloadedFuncDef) and node.is_property:
first_item = cast(Decorator, node.items[0])
first_item = node.items[0]
assert isinstance(first_item, Decorator)
return first_item.var.is_settable_property
return False

Expand Down
3 changes: 2 additions & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ def analyze_instance_member_access(

if method.is_property:
assert isinstance(method, OverloadedFuncDef)
first_item = cast(Decorator, method.items[0])
first_item = method.items[0]
assert isinstance(first_item, Decorator)
return analyze_var(name, first_item.var, typ, info, mx)
if mx.is_lvalue:
mx.msg.cant_assign_to_method(mx.context)
Expand Down
4 changes: 3 additions & 1 deletion mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,9 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
if current_overload and current_overload_name == last_if_stmt_overload_name:
# Remove last stmt (IfStmt) from ret if the overload names matched
# Only happens if no executable block had been found in IfStmt
skipped_if_stmts.append(cast(IfStmt, ret.pop()))
popped = ret.pop()
assert isinstance(popped, IfStmt)
skipped_if_stmts.append(popped)
if current_overload and skipped_if_stmts:
# Add bare IfStmt (without overloads) to ret
# Required for mypy to be able to still check conditions
Expand Down
3 changes: 2 additions & 1 deletion mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2183,7 +2183,8 @@ def name(self) -> str:

def expr(self) -> Expression:
"""Return the expression (the body) of the lambda."""
ret = cast(ReturnStmt, self.body.body[-1])
ret = self.body.body[-1]
assert isinstance(ret, ReturnStmt)
expr = ret.expr
assert expr is not None # lambda can't have empty body
return expr
Expand Down
5 changes: 3 additions & 2 deletions mypy/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tokenize
from abc import ABCMeta, abstractmethod
from operator import attrgetter
from typing import Any, Callable, Dict, Iterator, Tuple, cast
from typing import Any, Callable, Dict, Iterator, Tuple
from typing_extensions import Final, TypeAlias as _TypeAlias
from urllib.request import pathname2url

Expand Down Expand Up @@ -704,8 +704,9 @@ def __init__(self, reports: Reports, output_dir: str) -> None:
super().__init__(reports, output_dir)

memory_reporter = reports.add_report("memory-xml", "<memory>")
assert isinstance(memory_reporter, MemoryXmlReporter)
# The dependency will be called first.
self.memory_xml = cast(MemoryXmlReporter, memory_reporter)
self.memory_xml = memory_reporter


class XmlReporter(AbstractXmlReporter):
Expand Down
6 changes: 4 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,8 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) -
"""
defn.is_property = True
items = defn.items
first_item = cast(Decorator, defn.items[0])
first_item = defn.items[0]
assert isinstance(first_item, Decorator)
deleted_items = []
for i, item in enumerate(items[1:]):
if isinstance(item, Decorator):
Expand Down Expand Up @@ -1357,7 +1358,8 @@ def analyze_function_body(self, defn: FuncItem) -> None:
# Bind the type variables again to visit the body.
if defn.type:
a = self.type_analyzer()
typ = cast(CallableType, defn.type)
typ = defn.type
assert isinstance(typ, CallableType)
a.bind_function_type_variables(typ, defn)
for i in range(len(typ.arg_types)):
store_argument_type(defn, i, typ, self.named_type)
Expand Down
3 changes: 2 additions & 1 deletion mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ def fixup_and_reset_typeinfo(self, node: TypeInfo) -> TypeInfo:
if node in self.replacements:
# The subclass relationships may change, so reset all caches relevant to the
# old MRO.
new = cast(TypeInfo, self.replacements[node])
new = self.replacements[node]
assert isinstance(new, TypeInfo)
type_state.reset_all_subtype_caches_for(new)
return self.fixup(node)

Expand Down
8 changes: 5 additions & 3 deletions mypy/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from collections import Counter
from contextlib import contextmanager
from typing import Iterator, cast
from typing import Iterator
from typing_extensions import Final

from mypy import nodes
Expand Down Expand Up @@ -154,10 +154,12 @@ def visit_func_def(self, o: FuncDef) -> None:
)
return
for defn in o.expanded:
self.visit_func_def(cast(FuncDef, defn))
assert isinstance(defn, FuncDef)
self.visit_func_def(defn)
else:
if o.type:
sig = cast(CallableType, o.type)
assert isinstance(o.type, CallableType)
sig = o.type
arg_types = sig.arg_types
if sig.arg_names and sig.arg_names[0] == "self" and not self.inferred:
arg_types = arg_types[1:]
Expand Down
5 changes: 3 additions & 2 deletions mypy/test/testfinegrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import re
import sys
import unittest
from typing import Any, cast
from typing import Any

import pytest

Expand Down Expand Up @@ -169,7 +169,8 @@ def get_options(self, source: str, testcase: DataDrivenTestCase, build_cache: bo

def run_check(self, server: Server, sources: list[BuildSource]) -> list[str]:
response = server.check(sources, export_types=True, is_tty=False, terminal_width=-1)
out = cast(str, response["out"] or response["err"])
out = response["out"] or response["err"]
assert isinstance(out, str)
return out.splitlines()

def build(self, options: Options, sources: list[BuildSource]) -> list[str]:
Expand Down
6 changes: 4 additions & 2 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Callable, Sequence, cast
from typing import Callable, Sequence

from mypy.nodes import (
ARG_POS,
Expand Down Expand Up @@ -704,7 +704,9 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
lhs = e.operands[0]
mypy_file = builder.graph["builtins"].tree
assert mypy_file is not None
bool_type = Instance(cast(TypeInfo, mypy_file.names["bool"].node), [])
info = mypy_file.names["bool"].node
assert isinstance(info, TypeInfo)
bool_type = Instance(info, [])
exprs = []
for item in items:
expr = ComparisonExpr([cmp_op], [lhs, item])
Expand Down

0 comments on commit fddd5c5

Please sign in to comment.