diff --git a/mypy/checker.py b/mypy/checker.py index 6390b381d918..c6d82e580b7f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -53,7 +53,7 @@ from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any from mypy.semanal import set_callable_name, refers_to_fullname from mypy.mro import calculate_mro -from mypy.erasetype import erase_typevars +from mypy.erasetype import erase_typevars, remove_instance_last_known_values from mypy.expandtype import expand_type, expand_type_by_instance from mypy.visitor import NodeVisitor from mypy.join import join_types @@ -1868,10 +1868,9 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type self.check_indexed_assignment(index_lvalue, rvalue, lvalue) if inferred: - rvalue_type = self.expr_checker.accept( - rvalue, - in_final_declaration=inferred.is_final, - ) + rvalue_type = self.expr_checker.accept(rvalue) + if not inferred.is_final: + rvalue_type = remove_instance_last_known_values(rvalue_type) self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type], diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b079b82a6a68..856a96d271d6 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -141,15 +141,6 @@ def __init__(self, self.plugin = plugin self.type_context = [None] - # Set to 'True' whenever we are checking the expression in some 'Final' declaration. - # For example, if we're checking the "3" in a statement like "var: Final = 3". - # - # This flag changes the type that eventually gets inferred for "var". Instead of - # inferring *just* a 'builtins.int' instance, we infer an instance that keeps track - # of the underlying literal value. See the comments in Instance's constructors for - # more details. - self.in_final_declaration = False - # Temporary overrides for expression types. This is currently # used by the union math in overloads. # TODO: refactor this to use a pattern similar to one in @@ -224,8 +215,8 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type: def analyze_var_ref(self, var: Var, context: Context) -> Type: if var.type: if isinstance(var.type, Instance): - if self.is_literal_context() and var.type.final_value is not None: - return var.type.final_value + if self.is_literal_context() and var.type.last_known_value is not None: + return var.type.last_known_value if var.name() in {'True', 'False'}: return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool') return var.type @@ -1812,15 +1803,13 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty typ = self.named_type(fallback_name) if self.is_literal_context(): return LiteralType(value=value, fallback=typ) - elif self.in_final_declaration: + else: return typ.copy_modified(final_value=LiteralType( value=value, fallback=typ, line=typ.line, column=typ.column, )) - else: - return typ def visit_int_expr(self, e: IntExpr) -> Type: """Type check an integer literal (trivial).""" @@ -2450,7 +2439,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type: It may also represent type application. """ result = self.visit_index_expr_helper(e) - return self.narrow_type_from_binder(e, result) + result = self.narrow_type_from_binder(e, result) + if (self.is_literal_context() and isinstance(result, Instance) + and result.last_known_value is not None): + result = result.last_known_value + return result def visit_index_expr_helper(self, e: IndexExpr) -> Type: if e.analyzed: @@ -2542,8 +2535,8 @@ def _get_value(self, index: Expression) -> Optional[int]: if isinstance(operand, IntExpr): return -1 * operand.value typ = self.accept(index) - if isinstance(typ, Instance) and typ.final_value is not None: - typ = typ.final_value + if isinstance(typ, Instance) and typ.last_known_value is not None: + typ = typ.last_known_value if isinstance(typ, LiteralType) and isinstance(typ.value, int): return typ.value return None @@ -2553,8 +2546,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) item_name = index.value else: typ = self.accept(index) - if isinstance(typ, Instance) and typ.final_value is not None: - typ = typ.final_value + if isinstance(typ, Instance) and typ.last_known_value is not None: + typ = typ.last_known_value if isinstance(typ, LiteralType) and isinstance(typ.value, str): item_name = typ.value @@ -3253,7 +3246,6 @@ def accept(self, type_context: Optional[Type] = None, allow_none_return: bool = False, always_allow_any: bool = False, - in_final_declaration: bool = False, ) -> Type: """Type check a node in the given type context. If allow_none_return is True and this expression is a call, allow it to return None. This @@ -3261,8 +3253,6 @@ def accept(self, """ if node in self.type_overrides: return self.type_overrides[node] - old_in_final_declaration = self.in_final_declaration - self.in_final_declaration = in_final_declaration self.type_context.append(type_context) try: if allow_none_return and isinstance(node, CallExpr): @@ -3274,8 +3264,8 @@ def accept(self, except Exception as err: report_internal_error(err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options) + self.type_context.pop() - self.in_final_declaration = old_in_final_declaration assert typ is not None self.chk.store_type(node, typ) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 9c9e8e24518c..e17d5f044d84 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -101,8 +101,8 @@ def analyze_member_access(name: str, msg, chk=chk) result = _analyze_member_access(name, typ, mx, override_info) - if in_literal_context and isinstance(result, Instance) and result.final_value is not None: - return result.final_value + if in_literal_context and isinstance(result, Instance) and result.last_known_value is not None: + return result.last_known_value else: return result diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 7949cf46c6da..a179d1c5aef1 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -119,3 +119,17 @@ def visit_type_var(self, t: TypeVarType) -> Type: if self.erase_id(t.id): return self.replacement return t + + +def remove_instance_last_known_values(t: Type) -> Type: + return t.accept(LastKnownValueEraser()) + + +class LastKnownValueEraser(TypeTranslator): + """Removes the Literal[...] type that may be associated with any + Instance types.""" + + def visit_instance(self, t: Instance) -> Type: + if t.last_known_value: + return t.copy_modified(final_value=None) + return t diff --git a/mypy/fixup.py b/mypy/fixup.py index 7cc827f386fc..8d62567d24ed 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -155,8 +155,8 @@ def visit_instance(self, inst: Instance) -> None: base.accept(self) for a in inst.args: a.accept(self) - if inst.final_value is not None: - inst.final_value.accept(self) + if inst.last_known_value is not None: + inst.last_known_value.accept(self) def visit_any(self, o: Any) -> None: pass # Nothing to descend into. diff --git a/mypy/newsemanal/typeanal.py b/mypy/newsemanal/typeanal.py index e8c3b8c7ce2c..56b4259a9fa7 100644 --- a/mypy/newsemanal/typeanal.py +++ b/mypy/newsemanal/typeanal.py @@ -700,9 +700,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L elif isinstance(arg, (NoneType, LiteralType)): # Types that we can just add directly to the literal/potential union of literals. return [arg] - elif isinstance(arg, Instance) and arg.final_value is not None: + elif isinstance(arg, Instance) and arg.last_known_value is not None: # Types generated from declarations like "var: Final = 4". - return [arg.final_value] + return [arg.last_known_value] elif isinstance(arg, UnionType): out = [] for union_arg in arg.items: diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index f16fc534a32d..879fcf8dba69 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -125,8 +125,8 @@ def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]: """If this expression is a string literal, or if the corresponding type is something like 'Literal["some string here"]', returns the underlying string value. Otherwise, returns None.""" - if isinstance(typ, Instance) and typ.final_value is not None: - typ = typ.final_value + if isinstance(typ, Instance) and typ.last_known_value is not None: + typ = typ.last_known_value if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str': val = typ.value diff --git a/mypy/sametypes.py b/mypy/sametypes.py index d39ae878b603..0777865421e0 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -79,7 +79,7 @@ def visit_instance(self, left: Instance) -> bool: return (isinstance(self.right, Instance) and left.type == self.right.type and is_same_types(left.args, self.right.args) and - left.final_value == self.right.final_value) + left.last_known_value == self.right.last_known_value) def visit_type_var(self, left: TypeVarType) -> bool: return (isinstance(self.right, TypeVarType) and diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 962400e451fd..d07c7aca5e40 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -284,7 +284,7 @@ def visit_instance(self, typ: Instance) -> SnapshotItem: return ('Instance', typ.type.fullname(), snapshot_types(typ.args), - None if typ.final_value is None else snapshot_type(typ.final_value)) + None if typ.last_known_value is None else snapshot_type(typ.last_known_value)) def visit_type_var(self, typ: TypeVarType) -> SnapshotItem: return ('TypeVar', diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index 9ecbbefa9b48..41b483a56502 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -342,8 +342,8 @@ def visit_instance(self, typ: Instance) -> None: typ.type = self.fixup(typ.type) for arg in typ.args: arg.accept(self) - if typ.final_value: - typ.final_value.accept(self) + if typ.last_known_value: + typ.last_known_value.accept(self) def visit_any(self, typ: AnyType) -> None: pass diff --git a/mypy/server/deps.py b/mypy/server/deps.py index c957ec81a310..2bd20f72a62b 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -882,8 +882,8 @@ def visit_instance(self, typ: Instance) -> List[str]: triggers = [trigger] for arg in typ.args: triggers.extend(self.get_type_triggers(arg)) - if typ.final_value: - triggers.extend(self.get_type_triggers(typ.final_value)) + if typ.last_known_value: + triggers.extend(self.get_type_triggers(typ.last_known_value)) return triggers def visit_any(self, typ: AnyType) -> List[str]: diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index ca1b4062a8eb..f45e8b6849a5 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -164,8 +164,8 @@ def visit_deleted_type(self, t: DeletedType) -> Type: def visit_instance(self, t: Instance) -> Type: final_value = None # type: Optional[LiteralType] - if t.final_value is not None: - raw_final_value = t.final_value.accept(self) + if t.last_known_value is not None: + raw_final_value = t.last_known_value.accept(self) assert isinstance(raw_final_value, LiteralType) final_value = raw_final_value return Instance( @@ -173,7 +173,7 @@ def visit_instance(self, t: Instance) -> Type: args=self.translate_types(t.args), line=t.line, column=t.column, - final_value=final_value, + last_known_value=final_value, ) def visit_type_var(self, t: TypeVarType) -> Type: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 4dcbf0ad7a00..30dadc84656e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -730,9 +730,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L elif isinstance(arg, (NoneType, LiteralType)): # Types that we can just add directly to the literal/potential union of literals. return [arg] - elif isinstance(arg, Instance) and arg.final_value is not None: + elif isinstance(arg, Instance) and arg.last_known_value is not None: # Types generated from declarations like "var: Final = 4". - return [arg.final_value] + return [arg.last_known_value] elif isinstance(arg, UnionType): out = [] for union_arg in arg.items: diff --git a/mypy/types.py b/mypy/types.py index c704101d5bb7..1d86a19d1c04 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -579,11 +579,11 @@ class Instance(Type): The list of type variables may be empty. """ - __slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'final_value') + __slots__ = ('type', 'args', 'erased', 'invalid', 'type_ref', 'last_known_value') def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], line: int = -1, column: int = -1, erased: bool = False, - final_value: Optional['LiteralType'] = None) -> None: + last_known_value: Optional['LiteralType'] = None) -> None: super().__init__(line, column) self.type = typ self.args = args @@ -595,15 +595,31 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], # True if recovered after incorrect number of type arguments error self.invalid = False - # This field keeps track of the underlying Literal[...] value if this instance - # was created via a Final declaration. For example, if we did `x: Final = 3`, x - # would have an instance with a `final_value` of `LiteralType(3, int_fallback)`. + # This field keeps track of the underlying Literal[...] value associated with + # this instance, if one is known. # - # Or more broadly, this field lets this Instance "remember" its original declaration. - # We want this behavior because we want implicit Final declarations to act pretty - # much identically with constants: we should be able to replace any places where we - # use some Final variable with the original value and get the same type-checking - # behavior. For example, we want this program: + # This field is set whenever possible within expressions, but is erased upon + # variable assignment (see erasetype.remove_instance_last_known_values) unless + # the variable is declared to be final. + # + # For example, consider the following program: + # + # a = 1 + # b: Final[int] = 2 + # c: Final = 3 + # print(a + b + c + 4) + # + # The 'Instance' objects associated with the expressions '1', '2', '3', and '4' will + # have last_known_values of type Literal[1], Literal[2], Literal[3], and Literal[4] + # respectively. However, the Instance object assigned to 'a' and 'b' will have their + # last_known_value erased: variable 'a' is mutable; variable 'b' was declared to be + # specifically an int. + # + # Or more broadly, this field lets this Instance "remember" its original declaration + # when applicable. We want this behavior because we want implicit Final declarations + # to act pretty much identically with constants: we should be able to replace any + # places where we use some Final variable with the original value and get the same + # type-checking behavior. For example, we want this program: # # def expects_literal(x: Literal[3]) -> None: pass # var: Final = 3 @@ -617,39 +633,37 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: List[Type], # In order to make this work (especially with literal types), we need var's type # (an Instance) to remember the "original" value. # - # This field is currently set only when we encounter an *implicit* final declaration - # like `x: Final = 3` where the RHS is some literal expression. This field remains 'None' - # when we do things like `x: Final[int] = 3` or `x: Final = foo + bar`. + # Preserving this value within expressions is useful for similar reasons. # # Currently most of mypy will ignore this field and will continue to treat this type like # a regular Instance. We end up using this field only when we are explicitly within a # Literal context. - self.final_value = final_value + self.last_known_value = last_known_value def accept(self, visitor: 'TypeVisitor[T]') -> T: return visitor.visit_instance(self) def __hash__(self) -> int: - return hash((self.type, tuple(self.args), self.final_value)) + return hash((self.type, tuple(self.args), self.last_known_value)) def __eq__(self, other: object) -> bool: if not isinstance(other, Instance): return NotImplemented return (self.type == other.type and self.args == other.args - and self.final_value == other.final_value) + and self.last_known_value == other.last_known_value) def serialize(self) -> Union[JsonDict, str]: assert self.type is not None type_ref = self.type.fullname() - if not self.args and not self.final_value: + if not self.args and not self.last_known_value: return type_ref data = {'.class': 'Instance', } # type: JsonDict data['type_ref'] = type_ref data['args'] = [arg.serialize() for arg in self.args] - if self.final_value is not None: - data['final_value'] = self.final_value.serialize() + if self.last_known_value is not None: + data['last_known_value'] = self.last_known_value.serialize() return data @classmethod @@ -666,8 +680,8 @@ def deserialize(cls, data: Union[JsonDict, str]) -> 'Instance': args = [deserialize_type(arg) for arg in args_list] inst = Instance(NOT_READY, args) inst.type_ref = data['type_ref'] # Will be fixed up by fixup.py later. - if 'final_value' in data: - inst.final_value = LiteralType.deserialize(data['final_value']) + if 'last_known_value' in data: + inst.last_known_value = LiteralType.deserialize(data['last_known_value']) return inst def copy_modified(self, *, @@ -679,7 +693,7 @@ def copy_modified(self, *, self.line, self.column, self.erased, - final_value if final_value is not _dummy else self.final_value, + final_value if final_value is not _dummy else self.last_known_value, ) def has_readable_member(self, name: str) -> bool: diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 05b9cff1fbf3..f34832f45684 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2474,6 +2474,7 @@ from typing_extensions import Final, Literal a: Final = 1 implicit = [a] explicit: List[Literal[1]] = [a] +direct = [1] def force1(x: List[Literal[1]]) -> None: pass def force2(x: Literal[1]) -> None: pass @@ -2487,6 +2488,12 @@ force2(reveal_type(implicit[0])) # E: Argument 1 to "force2" has incompatible ty reveal_type(explicit) # E: Revealed type is 'builtins.list[Literal[1]]' force1(reveal_type(explicit)) # E: Revealed type is 'builtins.list[Literal[1]]' force2(reveal_type(explicit[0])) # E: Revealed type is 'Literal[1]' + +reveal_type(direct) # E: Revealed type is 'builtins.list[builtins.int*]' +force1(reveal_type(direct)) # E: Argument 1 to "force1" has incompatible type "List[int]"; expected "List[Literal[1]]" \ + # E: Revealed type is 'builtins.list[builtins.int*]' +force2(reveal_type(direct[0])) # E: Argument 1 to "force2" has incompatible type "int"; expected "Literal[1]" \ + # E: Revealed type is 'builtins.int*' [builtins fixtures/list.pyi] [out] @@ -2806,3 +2813,42 @@ Alias = Test x: Literal[Alias.FOO] reveal_type(x) # E: Revealed type is 'Literal[__main__.Test.FOO]' [out] + +[case testLiteralWithFinalPropagation] +from typing_extensions import Final, Literal + +a: Final = 3 +b: Final = a +c = a + +def expect_3(x: Literal[3]) -> None: pass +expect_3(a) +expect_3(b) +expect_3(c) # E: Argument 1 to "expect_3" has incompatible type "int"; expected "Literal[3]" +[out] + +[case testLiteralWithFinalPropagationIsNotLeaking] +from typing_extensions import Final, Literal + +final_tuple_direct: Final = (2, 3) +final_tuple_indirect: Final = final_tuple_direct +mutable_tuple = final_tuple_direct +final_list_1: Final = [2] +final_list_2: Final = [2, 2] +final_dict: Final = {"foo": 2} +final_set_1: Final = {2} +final_set_2: Final = {2, 2} + +def expect_2(x: Literal[2]) -> None: pass + +expect_2(final_tuple_direct[0]) +expect_2(final_tuple_indirect[0]) + +expect_2(mutable_tuple[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +expect_2(final_list_1[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +expect_2(final_list_2[0]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +expect_2(final_dict["foo"]) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +expect_2(final_set_1.pop()) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +expect_2(final_set_2.pop()) # E: Argument 1 to "expect_2" has incompatible type "int"; expected "Literal[2]" +[builtins fixtures/isinstancelist.pyi] +[out] diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 232bafc52e59..25675fccbfb4 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -8553,8 +8553,9 @@ from typing_extensions import Literal def expect_3(x: Literal[3]) -> None: pass expect_3(foo) [file mod1.py] +from typing_extensions import Final from mod2 import bar -foo = bar +foo: Final = bar [file mod2.py] from mod3 import qux as bar [file mod3.py] diff --git a/test-data/unit/fixtures/isinstancelist.pyi b/test-data/unit/fixtures/isinstancelist.pyi index 25ff5888a2cf..2730e1663c50 100644 --- a/test-data/unit/fixtures/isinstancelist.pyi +++ b/test-data/unit/fixtures/isinstancelist.pyi @@ -51,3 +51,4 @@ class set(Generic[T]): def add(self, x: T) -> None: pass def discard(self, x: T) -> None: pass def update(self, x: Set[T]) -> None: pass + def pop(self) -> T: pass