Skip to content

[clean strict optional] Fix some strict optional errors in mypy #3228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions mypy/checkstrformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re

from typing import cast, List, Tuple, Dict, Callable, Union
from typing import cast, List, Tuple, Dict, Callable, Union, Optional

from mypy.types import (
Type, AnyType, TupleType, Instance, UnionType
Expand All @@ -18,6 +18,7 @@
from mypy.messages import MessageBuilder

FormatStringExpr = Union[StrExpr, BytesExpr, UnicodeExpr]
Checkers = Tuple[Callable[[Expression], None], Callable[[Type], None]]


class ConversionSpecifier:
Expand Down Expand Up @@ -105,7 +106,7 @@ def parse_conversion_specifiers(self, format: str) -> List[ConversionSpecifier]:
return specifiers

def analyze_conversion_specifiers(self, specifiers: List[ConversionSpecifier],
context: Context) -> bool:
context: Context) -> Optional[bool]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code below distinguishes False and None

has_star = any(specifier.has_star() for specifier in specifiers)
has_key = any(specifier.has_key() for specifier in specifiers)
all_have_keys = all(
Expand Down Expand Up @@ -192,9 +193,8 @@ def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier],

def build_replacement_checkers(self, specifiers: List[ConversionSpecifier],
context: Context, expr: FormatStringExpr
) -> List[Tuple[Callable[[Expression], None],
Callable[[Type], None]]]:
checkers = [] # type: List[Tuple[Callable[[Expression], None], Callable[[Type], None]]]
) -> Optional[List[Checkers]]:
checkers = [] # type: List[Checkers]
for specifier in specifiers:
checker = self.replacement_checkers(specifier, context, expr)
if checker is None:
Expand All @@ -203,13 +203,12 @@ def build_replacement_checkers(self, specifiers: List[ConversionSpecifier],
return checkers

def replacement_checkers(self, specifier: ConversionSpecifier, context: Context,
expr: FormatStringExpr) -> List[Tuple[Callable[[Expression], None],
Callable[[Type], None]]]:
expr: FormatStringExpr) -> Optional[List[Checkers]]:
"""Returns a list of tuples of two functions that check whether a replacement is
of the right type for the specifier. The first functions take a node and checks
its type in the right type context. The second function just checks a type.
"""
checkers = [] # type: List[Tuple[Callable[[Expression], None], Callable[[Type], None]]]
checkers = [] # type: List[Checkers]

if specifier.width == '*':
checkers.append(self.checkers_for_star(context))
Expand All @@ -227,14 +226,13 @@ def replacement_checkers(self, specifier: ConversionSpecifier, context: Context,
checkers.append(c)
return checkers

def checkers_for_star(self, context: Context) -> Tuple[Callable[[Expression], None],
Callable[[Type], None]]:
def checkers_for_star(self, context: Context) -> Checkers:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with a star in a conversion specifier
"""
expected = self.named_type('builtins.int')

def check_type(type: Type = None) -> None:
def check_type(type: Type) -> None:
expected = self.named_type('builtins.int')
self.chk.check_subtype(type, expected, context, '* wants int')

Expand All @@ -246,16 +244,16 @@ def check_expr(expr: Expression) -> None:

def checkers_for_regular_type(self, type: str,
context: Context,
expr: FormatStringExpr) -> Tuple[Callable[[Expression], None],
Callable[[Type], None]]:
expr: FormatStringExpr) -> Optional[Checkers]:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with 'type'. Return None in case of an
"""
expected_type = self.conversion_type(type, context, expr)
if expected_type is None:
return None

def check_type(type: Type = None) -> None:
def check_type(type: Type) -> None:
assert expected_type is not None
self.chk.check_subtype(type, expected_type, context,
messages.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION,
'expression has type', 'placeholder has type')
Expand All @@ -268,16 +266,16 @@ def check_expr(expr: Expression) -> None:

def checkers_for_c_type(self, type: str,
context: Context,
expr: FormatStringExpr) -> Tuple[Callable[[Expression], None],
Callable[[Type], None]]:
expr: FormatStringExpr) -> Optional[Checkers]:
"""Returns a tuple of check functions that check whether, respectively,
a node or a type is compatible with 'type' that is a character type
"""
expected_type = self.conversion_type(type, context, expr)
if expected_type is None:
return None

def check_type(type: Type = None) -> None:
def check_type(type: Type) -> None:
assert expected_type is not None
self.chk.check_subtype(type, expected_type, context,
messages.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION,
'expression has type', 'placeholder has type')
Expand All @@ -291,7 +289,7 @@ def check_expr(expr: Expression) -> None:

return check_expr, check_type

def conversion_type(self, p: str, context: Context, expr: FormatStringExpr) -> Type:
def conversion_type(self, p: str, context: Context, expr: FormatStringExpr) -> Optional[Type]:
"""Return the type that is accepted for a string interpolation
conversion specifier type.

Expand Down
2 changes: 1 addition & 1 deletion mypy/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def parse(source: Union[str, bytes],
fnam: str,
errors: Errors,
errors: Optional[Errors],
options: Options) -> MypyFile:
"""Parse a source file, without doing any semantic analysis.

Expand Down
9 changes: 6 additions & 3 deletions mypy/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
if cur_indent is None:
# Consume the line, but don't mark it as belonging to the function yet.
cur_line += 1
elif cur_indent > start_indent:
elif start_indent is not None and cur_indent > start_indent:
# A non-blank line that belongs to the function.
cur_line += 1
end_line = cur_line
Expand All @@ -211,7 +211,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
is_typed = defn.type is not None
for line in range(start_line, end_line):
old_indent, _ = self.lines_covered[line]
assert start_indent > old_indent
assert start_indent is not None and start_indent > old_indent
self.lines_covered[line] = (start_indent, is_typed)

# Visit the body, in case there are nested functions
Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(self, reports: Reports, output_dir: str) -> None:
self.css_html_path = os.path.join(reports.data_dir, 'xml', 'mypy-html.css')
xsd_path = os.path.join(reports.data_dir, 'xml', 'mypy.xsd')
self.schema = etree.XMLSchema(etree.parse(xsd_path))
self.last_xml = None # type: etree._ElementTree
self.last_xml = None # type: Optional[etree._ElementTree]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if it wouldn't be better to keep this as non-optional and remove the None assignments? Not sure I follow the logic there.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand the logic correctly, those None assignments are needed, but the attribute should be always non-None at the end, so that I replaced ifs with asserts, I checked locally it works correctly.

self.files = [] # type: List[FileInfo]

def on_file(self,
Expand Down Expand Up @@ -532,6 +532,7 @@ def on_file(self,

def on_finish(self) -> None:
last_xml = self.memory_xml.last_xml
assert last_xml is not None
out_path = os.path.join(self.output_dir, 'index.xml')
out_xslt = os.path.join(self.output_dir, 'mypy-html.xslt')
out_css = os.path.join(self.output_dir, 'mypy-html.css')
Expand Down Expand Up @@ -575,6 +576,7 @@ def on_file(self,

def on_finish(self) -> None:
last_xml = self.memory_xml.last_xml
assert last_xml is not None
out_path = os.path.join(self.output_dir, 'index.html')
out_css = os.path.join(self.output_dir, 'mypy-html.css')
transformed_html = bytes(self.xslt_html(last_xml, ext=self.param_html))
Expand Down Expand Up @@ -606,6 +608,7 @@ def on_file(self,

def on_finish(self) -> None:
last_xml = self.memory_xml.last_xml
assert last_xml is not None
out_path = os.path.join(self.output_dir, 'index.txt')
stats.ensure_dir_exists(os.path.dirname(out_path))
transformed_txt = bytes(self.xslt_txt(last_xml))
Expand Down
12 changes: 7 additions & 5 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ class StrConv(NodeVisitor[str]):

def __init__(self, show_ids: bool = False) -> None:
self.show_ids = show_ids
self.id_mapper = None # type: Optional[IdMapper]
if show_ids:
self.id_mapper = IdMapper()
else:
self.id_mapper = None

def get_id(self, o: object) -> int:
return self.id_mapper.id(o)
def get_id(self, o: object) -> Optional[int]:
if self.id_mapper:
return self.id_mapper.id(o)
return None

def format_id(self, o: object) -> str:
if self.id_mapper:
Expand All @@ -47,6 +48,7 @@ def dump(self, nodes: Sequence[object], obj: 'mypy.nodes.Context') -> str:
"""
tag = short_type(obj) + ':' + str(obj.get_line())
if self.show_ids:
assert self.id_mapper is not None
tag += '<{}>'.format(self.get_id(obj))
return dump_tagged(nodes, tag, self)

Expand Down Expand Up @@ -504,7 +506,7 @@ def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str:
return self.dump([o.expr], o)


def dump_tagged(nodes: Sequence[object], tag: str, str_conv: 'StrConv') -> str:
def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str:
"""Convert an array into a pretty-printed multiline string representation.

The format is
Expand Down
9 changes: 5 additions & 4 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,10 @@ def visit_decorator(self, o: Decorator) -> None:
super().visit_decorator(o)

def visit_class_def(self, o: ClassDef) -> None:
sep = None # type: Optional[int]
if not self._indent and self._state != EMPTY:
sep = len(self._output)
self.add('\n')
else:
sep = None
self.add('%sclass %s' % (self._indent, o.name))
self.record_name(o.name)
base_types = self.get_base_types(o)
Expand Down Expand Up @@ -465,7 +464,7 @@ def visit_import(self, o: Import) -> None:
self.add_import_line('import %s as %s\n' % (id, target_name))
self.record_name(target_name)

def get_init(self, lvalue: str, rvalue: Expression) -> str:
def get_init(self, lvalue: str, rvalue: Expression) -> Optional[str]:
"""Return initializer for a variable.

Return None if we've generated one already or if the variable is internal.
Expand Down Expand Up @@ -511,7 +510,9 @@ def output(self) -> str:
def is_not_in_all(self, name: str) -> bool:
if self.is_private_name(name):
return False
return self.is_top_level() and bool(self._all_) and name not in self._all_
if self._all_:
return self.is_top_level() and name not in self._all_
return False

def is_private_name(self, name: str) -> bool:
return name.startswith('_') and (not name.endswith('__')
Expand Down
9 changes: 6 additions & 3 deletions mypy/stubgenc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ def generate_c_function_stub(module: ModuleType,
self_arg = '%s, ' % self_var
else:
self_arg = ''
if name in ('__new__', '__init__') and name not in sigs and class_name in class_sigs:
if (name in ('__new__', '__init__') and name not in sigs and class_name and
class_name in class_sigs):
sig = class_sigs[class_name]
else:
docstr = getattr(obj, '__doc__', None)
sig = infer_sig_from_docstring(docstr, name)
if not sig:
inferred = infer_sig_from_docstring(docstr, name)
if inferred:
sig = inferred
else:
if class_name and name not in sigs:
sig = infer_method_sig(name)
else:
Expand Down
3 changes: 2 additions & 1 deletion mypy/stubutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def is_c_module(module: ModuleType) -> bool:
return '__file__' not in module.__dict__ or module.__dict__['__file__'].endswith('.so')


def write_header(file: IO[str], module_name: str, pyversion: Tuple[int, int] = (3, 5)) -> None:
def write_header(file: IO[str], module_name: Optional[str] = None,
pyversion: Tuple[int, int] = (3, 5)) -> None:
if module_name:
if pyversion[0] >= 3:
version = '%d.%d' % (sys.version_info.major,
Expand Down