Skip to content

Commit

Permalink
Make literal exprs have inferred type of 'Literal' based on context (#…
Browse files Browse the repository at this point in the history
…5990)

This pull request modifies the type checking logic so that literal
expressions will have an inferred type of 'Literal' if the context asks
for a literal type. That is, it implements support for this:

    x: Literal[1] = 1
    y = 1

    reveal_type(x)  # E: Revealed type is 'Literal[1]'
    reveal_type(y)  # E: Revealed type is 'builtins.int'

This pull requests also implements the `visit_literal_type` method
in the  `constraints.ConstraintBuilderVisitor` and `join.TypeJoinVisitor`
methods. Both visitors are exercised indirectly through the "let's use
literal types in collection contexts" code, but only the latter is
tested directly: I wasn't really sure how to directly test
`ConstraintBuilderVisitor`.

The implementation is simple though -- I'm pretty sure literal types
count as a "leaf type" so it's fine to return an empty list
(no constraints).
  • Loading branch information
Michael0x2a authored Dec 5, 2018
1 parent 1c824b6 commit ad2d4ba
Show file tree
Hide file tree
Showing 6 changed files with 559 additions and 24 deletions.
29 changes: 26 additions & 3 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:

def analyze_var_ref(self, var: Var, context: Context) -> Type:
if var.type:
return var.type
if is_literal_type_like(self.type_context[-1]) and var.name() in {'True', 'False'}:
return LiteralType(var.name() == 'True', self.named_type('builtins.bool'))
else:
return var.type
else:
if not var.is_ready and self.chk.in_checked_function():
self.chk.handle_cannot_determine_type(var.name(), context)
Expand Down Expand Up @@ -1721,11 +1724,17 @@ def analyze_external_member_access(self, member: str, base_type: Type,

def visit_int_expr(self, e: IntExpr) -> Type:
"""Type check an integer literal (trivial)."""
return self.named_type('builtins.int')
typ = self.named_type('builtins.int')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ

def visit_str_expr(self, e: StrExpr) -> Type:
"""Type check a string literal (trivial)."""
return self.named_type('builtins.str')
typ = self.named_type('builtins.str')
if is_literal_type_like(self.type_context[-1]):
return LiteralType(value=e.value, fallback=typ)
return typ

def visit_bytes_expr(self, e: BytesExpr) -> Type:
"""Type check a bytes literal (trivial)."""
Expand Down Expand Up @@ -3583,3 +3592,17 @@ def merge_typevars_in_callables_by_name(
output.append(target)

return output, variables


def is_literal_type_like(t: Optional[Type]) -> bool:
"""Returns 'true' if the given type context is potentially either a LiteralType,
a Union of LiteralType, or something similar.
"""
if t is None:
return False
elif isinstance(t, LiteralType):
return True
elif isinstance(t, UnionType):
return any(is_literal_type_like(item) for item in t.items)
else:
return False
6 changes: 3 additions & 3 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def visit_erased_type(self, template: ErasedType) -> List[Constraint]:
def visit_deleted_type(self, template: DeletedType) -> List[Constraint]:
return []

def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
return []

# Errors

def visit_partial_type(self, template: PartialType) -> List[Constraint]:
Expand Down Expand Up @@ -472,9 +475,6 @@ def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]:
else:
return []

def visit_literal_type(self, template: LiteralType) -> List[Constraint]:
raise NotImplementedError()

def visit_union_type(self, template: UnionType) -> List[Constraint]:
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
" (should have been handled in infer_constraints)")
Expand Down
10 changes: 9 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ def visit_instance(self, t: Instance) -> Type:
return join_types(t, self.s)
elif isinstance(self.s, TypedDictType):
return join_types(t, self.s)
elif isinstance(self.s, LiteralType):
return join_types(t, self.s)
else:
return self.default(self.s)

Expand Down Expand Up @@ -268,7 +270,13 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
return self.default(self.s)

def visit_literal_type(self, t: LiteralType) -> Type:
raise NotImplementedError()
if isinstance(self.s, LiteralType):
if t == self.s:
return t
else:
return join_types(self.s.fallback, t.fallback)
else:
return join_types(self.s, t.fallback)

def visit_partial_type(self, t: PartialType) -> Type:
# We only have partial information so we can't decide the join result. We should
Expand Down
43 changes: 39 additions & 4 deletions mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
fx = self.fx

lit1 = LiteralType(1, fx.a)
lit2 = LiteralType("foo", fx.b)
lit3 = LiteralType("bar", fx.b)
lit2 = LiteralType("foo", fx.d)
lit3 = LiteralType("bar", fx.d)

assert_true(is_proper_subtype(lit1, fx.a))
assert_false(is_proper_subtype(lit1, fx.b))
assert_false(is_proper_subtype(lit1, fx.d))
assert_false(is_proper_subtype(fx.a, lit1))
assert_true(is_proper_subtype(fx.uninhabited, lit1))
assert_false(is_proper_subtype(lit1, fx.uninhabited))
Expand All @@ -262,7 +262,7 @@ def test_is_proper_subtype_and_subtype_literal_types(self) -> None:
assert_false(is_proper_subtype(lit2, lit3))

assert_true(is_subtype(lit1, fx.a))
assert_false(is_subtype(lit1, fx.b))
assert_false(is_subtype(lit1, fx.d))
assert_false(is_subtype(fx.a, lit1))
assert_true(is_subtype(fx.uninhabited, lit1))
assert_false(is_subtype(lit1, fx.uninhabited))
Expand Down Expand Up @@ -621,6 +621,41 @@ def test_type_type(self) -> None:
self.assert_join(self.fx.type_type, self.fx.type_any, self.fx.type_type)
self.assert_join(self.fx.type_b, self.fx.anyt, self.fx.anyt)

def test_literal_type(self) -> None:
a = self.fx.a
d = self.fx.d
lit1 = LiteralType(1, a)
lit2 = LiteralType(2, a)
lit3 = LiteralType("foo", d)

self.assert_join(lit1, lit1, lit1)
self.assert_join(lit1, a, a)
self.assert_join(lit1, d, self.fx.o)
self.assert_join(lit1, lit2, a)
self.assert_join(lit1, lit3, self.fx.o)
self.assert_join(lit1, self.fx.anyt, self.fx.anyt)
self.assert_join(UnionType([lit1, lit2]), lit2, UnionType([lit1, lit2]))
self.assert_join(UnionType([lit1, lit2]), a, a)
self.assert_join(UnionType([lit1, lit3]), a, UnionType([a, lit3]))
self.assert_join(UnionType([d, lit3]), lit3, UnionType([d, lit3]))
self.assert_join(UnionType([d, lit3]), d, UnionType([d, lit3]))
self.assert_join(UnionType([a, lit1]), lit1, UnionType([a, lit1]))
self.assert_join(UnionType([a, lit1]), lit2, UnionType([a, lit1]))
self.assert_join(UnionType([lit1, lit2]),
UnionType([lit1, lit2]),
UnionType([lit1, lit2]))

# The order in which we try joining two unions influences the
# ordering of the items in the final produced unions. So, we
# manually call 'assert_simple_join' and tune the output
# after swapping the arguments here.
self.assert_simple_join(UnionType([lit1, lit2]),
UnionType([lit2, lit3]),
UnionType([lit1, lit2, lit3]))
self.assert_simple_join(UnionType([lit2, lit3]),
UnionType([lit1, lit2]),
UnionType([lit2, lit3, lit1]))

# There are additional test cases in check-inference.test.

# TODO: Function types + varargs and default args.
Expand Down
2 changes: 2 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,8 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarList:
return [(name, node.node)]
elif not self.include_callables and self._seems_like_callable(t):
return []
elif node and node.fullname in ('typing_extensions.Literal', 'typing.Literal'):
return []
else:
return super().visit_unbound_type(t)

Expand Down
Loading

0 comments on commit ad2d4ba

Please sign in to comment.