Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make literal exprs have inferred type of 'Literal' based on context #5990

Merged
merged 5 commits into from
Dec 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
return var.type
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
else:
if not var.is_ready and self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(var.name(), context)
Expand Down Expand Up @@ -1721,11 +1724,17 @@ def analyze_external_member_access(self, member: str, base_type: Type,

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
return self.named_type('builtins.int')
typ = self.named_type('builtins.int')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
return self.named_type('builtins.str')
typ = self.named_type('builtins.str')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ

def visit_bytes_expr(self, e: BytesExpr) -> Type:
"""Type check a bytes literal (trivial)."""
Expand Down Expand Up @@ -3583,3 +3592,17 @@ def merge_typevars_in_callables_by_name(
output.append(target)

return output, variables


def is_literal_type_like(t: Optional[Type]) -> bool:
"""Returns 'true' if the given type context is potentially either a LiteralType,
a Union of LiteralType, or something similar.
"""
if t is None:
return False
elif isinstance(t, LiteralType):
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
else:
return False
6 changes: 3 additions & 3 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
def visit_deleted_type(self, template: DeletedType) -> List[Constraint]:
return []

def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
return []

# Errors

def visit_partial_type(self, template: PartialType) -> List[Constraint]:
Expand Down Expand Up @@ -472,9 +475,6 @@ def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
else:
return []

def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
raise NotImplementedError()

def visit_union_type(self, template: UnionType) -> List[Constraint]:
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
" (should have been handled in infer_constraints)")
Expand Down
10 changes: 9 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def visit_instance(self, t: Instance) -> Type:
return join_types(t, self.s)
elif isinstance(self.s, TypedDictType):
return join_types(t, self.s)
elif isinstance(self.s, LiteralType):
return join_types(t, self.s)
else:
return self.default(self.s)

Expand Down Expand Up @@ -268,7 +270,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
return self.default(self.s)

def visit_literal_type(self, t: LiteralType) -> Type:
raise NotImplementedError()
if isinstance(self.s, LiteralType):
if t == self.s:
return t
else:
return join_types(self.s.fallback, t.fallback)
else:
return join_types(self.s, t.fallback)

def visit_partial_type(self, t: PartialType) -> Type:
# We only have partial information so we can't decide the join result. We should
Expand Down
43 changes: 39 additions & 4 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
fx = self.fx

lit1 = LiteralType(1, fx.a)
lit2 = LiteralType("foo", fx.b)
lit3 = LiteralType("bar", fx.b)
lit2 = LiteralType("foo", fx.d)
lit3 = LiteralType("bar", fx.d)

assert_true(is_proper_subtype(lit1, fx.a))
assert_false(is_proper_subtype(lit1, fx.b))
assert_false(is_proper_subtype(lit1, fx.d))
assert_false(is_proper_subtype(fx.a, lit1))
assert_true(is_proper_subtype(fx.uninhabited, lit1))
assert_false(is_proper_subtype(lit1, fx.uninhabited))
Expand All @@ -262,7 +262,7 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
assert_false(is_proper_subtype(lit2, lit3))

assert_true(is_subtype(lit1, fx.a))
assert_false(is_subtype(lit1, fx.b))
assert_false(is_subtype(lit1, fx.d))
assert_false(is_subtype(fx.a, lit1))
assert_true(is_subtype(fx.uninhabited, lit1))
assert_false(is_subtype(lit1, fx.uninhabited))
Expand Down Expand Up @@ -621,6 +621,41 @@ def test_type_type(self) -> None:
self.assert_join(self.fx.type_type, self.fx.type_any, self.fx.type_type)
self.assert_join(self.fx.type_b, self.fx.anyt, self.fx.anyt)

def test_literal_type(self) -> None:
a = self.fx.a
d = self.fx.d
lit1 = LiteralType(1, a)
lit2 = LiteralType(2, a)
lit3 = LiteralType("foo", d)

self.assert_join(lit1, lit1, lit1)
self.assert_join(lit1, a, a)
self.assert_join(lit1, d, self.fx.o)
self.assert_join(lit1, lit2, a)
self.assert_join(lit1, lit3, self.fx.o)
self.assert_join(lit1, self.fx.anyt, self.fx.anyt)
self.assert_join(UnionType([lit1, lit2]), lit2, UnionType([lit1, lit2]))
self.assert_join(UnionType([lit1, lit2]), a, a)
self.assert_join(UnionType([lit1, lit3]), a, UnionType([a, lit3]))
self.assert_join(UnionType([d, lit3]), lit3, UnionType([d, lit3]))
self.assert_join(UnionType([d, lit3]), d, UnionType([d, lit3]))
self.assert_join(UnionType([a, lit1]), lit1, UnionType([a, lit1]))
self.assert_join(UnionType([a, lit1]), lit2, UnionType([a, lit1]))
self.assert_join(UnionType([lit1, lit2]),
UnionType([lit1, lit2]),
UnionType([lit1, lit2]))

# The order in which we try joining two unions influences the
# ordering of the items in the final produced unions. So, we
# manually call 'assert_simple_join' and tune the output
# after swapping the arguments here.
self.assert_simple_join(UnionType([lit1, lit2]),
UnionType([lit2, lit3]),
UnionType([lit1, lit2, lit3]))
self.assert_simple_join(UnionType([lit2, lit3]),
UnionType([lit1, lit2]),
UnionType([lit2, lit3, lit1]))

# There are additional test cases in check-inference.test.

# TODO: Function types + varargs and default args.
Expand Down
2 changes: 2 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,8 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarList:
return [(name, node.node)]
elif not self.include_callables and self._seems_like_callable(t):
return []
elif node and node.fullname in ('typing_extensions.Literal', 'typing.Literal'):
return []
else:
return super().visit_unbound_type(t)

Expand Down
Loading