diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 1caf25cbc2be..ede5c5c7dbe3 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -2125,7 +2125,7 @@ export class Checker extends ParseTreeWalker { return; } - if (this._isTypeComparable(leftSubtype, rightSubtype)) { + if (this._evaluator.isTypeComparable(leftSubtype, rightSubtype)) { isComparable = true; } @@ -2151,97 +2151,6 @@ export class Checker extends ParseTreeWalker { } } - // Determines whether the two types are potentially comparable -- i.e. - // their types overlap in such a way that it makes sense for them to - // be compared with an == or != operator. - private _isTypeComparable(leftType: Type, rightType: Type) { - if (isAnyOrUnknown(leftType) || isAnyOrUnknown(rightType)) { - return true; - } - - if (isNever(leftType) || isNever(rightType)) { - return false; - } - - if (isModule(leftType) || isModule(rightType)) { - return isTypeSame(leftType, rightType, { ignoreConditions: true }); - } - - const isLeftCallable = isFunction(leftType) || isOverloaded(leftType); - const isRightCallable = isFunction(rightType) || isOverloaded(rightType); - if (isLeftCallable !== isRightCallable) { - return false; - } - - if (isInstantiableClass(leftType) || (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'type'))) { - if ( - isInstantiableClass(rightType) || - (isClassInstance(rightType) && ClassType.isBuiltIn(rightType, 'type')) - ) { - const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined); - const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined); - - if ( - this._evaluator.assignType(genericLeftType, genericRightType) || - this._evaluator.assignType(genericRightType, genericLeftType) - ) { - return true; - } - } - - // Does the class have an operator overload for eq? - const metaclass = leftType.shared.effectiveMetaclass; - if (metaclass && isClass(metaclass)) { - if (lookUpClassMember(metaclass, '__eq__', MemberAccessFlags.SkipObjectBaseClass)) { - return true; - } - } - - return false; - } - - if (isClassInstance(leftType)) { - if (isClass(rightType)) { - const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined); - const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined); - - if ( - this._evaluator.assignType(genericLeftType, genericRightType) || - this._evaluator.assignType(genericRightType, genericLeftType) - ) { - return true; - } - - // Assume that if the types are disjoint and built-in classes that they - // will never be comparable. - if (ClassType.isBuiltIn(leftType) && ClassType.isBuiltIn(rightType) && TypeBase.isInstance(rightType)) { - return false; - } - } - - // Does the class have an operator overload for eq? - const eqMethod = lookUpClassMember( - ClassType.cloneAsInstantiable(leftType), - '__eq__', - MemberAccessFlags.SkipObjectBaseClass - ); - - if (eqMethod) { - // If this is a synthesized method for a dataclass, we can assume - // that other dataclass types will not be comparable. - if (ClassType.isDataClass(leftType) && eqMethod.symbol.getSynthesizedType()) { - return false; - } - - return true; - } - - return false; - } - - return true; - } - // If the function is a generator, validates that its annotated return type // is appropriate for a generator. private _validateGeneratorReturnType(node: FunctionNode, functionType: FunctionType) { diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 25ce7dd27702..f4427db0d6aa 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -25540,6 +25540,120 @@ export function createTypeEvaluator( ); } + // Determines whether the two types are potentially comparable -- i.e. + // their types overlap in such a way that it makes sense for them to + // be compared with an == or != operator. + function isTypeComparable(leftType: Type, rightType: Type) { + if (isAnyOrUnknown(leftType) || isAnyOrUnknown(rightType)) { + return true; + } + + if (isNever(leftType) || isNever(rightType)) { + return false; + } + + if (isModule(leftType) || isModule(rightType)) { + return isTypeSame(leftType, rightType, { ignoreConditions: true }); + } + + const isLeftCallable = isFunction(leftType) || isOverloaded(leftType); + const isRightCallable = isFunction(rightType) || isOverloaded(rightType); + if (isLeftCallable !== isRightCallable) { + return false; + } + + if (isInstantiableClass(leftType) || (isClassInstance(leftType) && ClassType.isBuiltIn(leftType, 'type'))) { + if ( + isInstantiableClass(rightType) || + (isClassInstance(rightType) && ClassType.isBuiltIn(rightType, 'type')) + ) { + const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined); + const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined); + + if (assignType(genericLeftType, genericRightType) || assignType(genericRightType, genericLeftType)) { + return true; + } + } + + // Does the class have an operator overload for eq? + const metaclass = leftType.shared.effectiveMetaclass; + if (metaclass && isClass(metaclass)) { + if (lookUpClassMember(metaclass, '__eq__', MemberAccessFlags.SkipObjectBaseClass)) { + return true; + } + } + + return false; + } + + if (isClassInstance(leftType)) { + if (isClass(rightType)) { + const genericLeftType = ClassType.specialize(leftType, /* typeArgs */ undefined); + const genericRightType = ClassType.specialize(rightType, /* typeArgs */ undefined); + + if (assignType(genericLeftType, genericRightType) || assignType(genericRightType, genericLeftType)) { + return true; + } + + // Assume that if the types are disjoint and built-in classes that they + // will never be comparable. + if (ClassType.isBuiltIn(leftType) && ClassType.isBuiltIn(rightType) && TypeBase.isInstance(rightType)) { + // We need to be careful with bool and int literals because + // they are comparable under certain circumstances. + let boolType: ClassType | undefined; + let intType: ClassType | undefined; + if (ClassType.isBuiltIn(leftType, 'bool') && ClassType.isBuiltIn(rightType, 'int')) { + boolType = leftType; + intType = rightType; + } else if (ClassType.isBuiltIn(rightType, 'bool') && ClassType.isBuiltIn(leftType, 'int')) { + boolType = rightType; + intType = leftType; + } + + if (boolType && intType) { + const intVal = intType.priv?.literalValue as number | BigInt | undefined; + if (intVal === undefined) { + return true; + } + if (intVal !== 0 && intVal !== 1) { + return false; + } + + const boolVal = boolType.priv?.literalValue as boolean | undefined; + if (boolVal === undefined) { + return true; + } + + return boolVal === (intVal === 1); + } + + return false; + } + } + + // Does the class have an operator overload for eq? + const eqMethod = lookUpClassMember( + ClassType.cloneAsInstantiable(leftType), + '__eq__', + MemberAccessFlags.SkipObjectBaseClass + ); + + if (eqMethod) { + // If this is a synthesized method for a dataclass, we can assume + // that other dataclass types will not be comparable. + if (ClassType.isDataClass(leftType) && eqMethod.symbol.getSynthesizedType()) { + return false; + } + + return true; + } + + return false; + } + + return true; + } + function assignToUnionType( destType: UnionType, srcType: Type, @@ -28325,6 +28439,7 @@ export function createTypeEvaluator( getCallSignatureInfo, getAbstractSymbols, narrowConstrainedTypeVar, + isTypeComparable, assignType, validateOverrideMethod, validateCallArgs, diff --git a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts index 3bbbab30cfa9..3c83e4a6574f 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts @@ -768,6 +768,7 @@ export interface TypeEvaluator { getCallSignatureInfo: (node: CallNode, activeIndex: number, activeOrFake: boolean) => CallSignatureInfo | undefined; getAbstractSymbols: (classType: ClassType) => AbstractSymbol[]; narrowConstrainedTypeVar: (node: ParseNode, typeVar: TypeVarType) => Type | undefined; + isTypeComparable: (leftType: Type, rightType: Type) => boolean; assignType: ( destType: Type, diff --git a/packages/pyright-internal/src/analyzer/typeGuards.ts b/packages/pyright-internal/src/analyzer/typeGuards.ts index fd206a20a6b2..3e614fbefd65 100644 --- a/packages/pyright-internal/src/analyzer/typeGuards.ts +++ b/packages/pyright-internal/src/analyzer/typeGuards.ts @@ -2066,15 +2066,40 @@ export function narrowTypeForContainerElementType(evaluator: TypeEvaluator, refe return referenceSubtype; } - if (evaluator.assignType(referenceSubtype, elementSubtype)) { + // If the two types are disjoint (i.e. are not comparable), eliminate this subtype. + if (!evaluator.isTypeComparable(elementSubtype, referenceSubtype)) { + return undefined; + } + + // If one of the two types is a literal, we can narrow to that type. + if ( + isClassInstance(elementSubtype) && + (isLiteralType(elementSubtype) || isNoneInstance(elementSubtype)) && + evaluator.assignType(referenceSubtype, elementSubtype) + ) { return stripTypeForm(addConditionToType(elementSubtype, referenceSubtype.props?.condition)); } - if (evaluator.assignType(elementSubtype, referenceSubtype)) { + if ( + isClassInstance(referenceSubtype) && + (isLiteralType(referenceSubtype) || isNoneInstance(referenceSubtype)) && + evaluator.assignType(elementSubtype, referenceSubtype) + ) { return stripTypeForm(addConditionToType(referenceSubtype, elementSubtype.props?.condition)); } - return undefined; + // If the element type is a known class object that is assignable to + // the reference type, we can narrow to that class object. + if ( + isInstantiableClass(elementSubtype) && + !elementSubtype.priv.includeSubclasses && + evaluator.assignType(referenceSubtype, elementSubtype) + ) { + return stripTypeForm(addConditionToType(elementSubtype, referenceSubtype.props?.condition)); + } + + // It's not safe to narrow. + return referenceSubtype; }); }); } diff --git a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py index 51f4f5363597..37ee5b36b625 100644 --- a/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py +++ b/packages/pyright-internal/src/tests/samples/typeNarrowingIn1.py @@ -85,10 +85,10 @@ def func6(x: type): def func7(x: object | bytes, y: str, z: int): if x in (y, z): - reveal_type(x, expected_text="str | int") + reveal_type(x, expected_text="object") else: reveal_type(x, expected_text="object | bytes") - reveal_type(x, expected_text="str | int | object | bytes") + reveal_type(x, expected_text="object | bytes") def func8(x: object): @@ -127,13 +127,6 @@ class TD2(TypedDict): y: str -def func11(x: dict[str, str]): - if x in (TD1(x="a"), TD2(y="b")): - reveal_type(x, expected_text="TD1 | TD2") - else: - reveal_type(x, expected_text="dict[str, str]") - - T1 = TypeVar("T1", TD1, TD2) @@ -175,4 +168,39 @@ def func14(x: str, y: dict[Any, Any]): def func15(x: Any, y: dict[str, str]): if x in y: - reveal_type(x, expected_text="str") + reveal_type(x, expected_text="Any") + + +def func16(x: int, y: list[Literal[0, 1]]): + if x in y: + reveal_type(x, expected_text="Literal[0, 1]") + + +def func17(x: Literal[-1, 0], y: list[Literal[0, 1]]): + if x in y: + reveal_type(x, expected_text="Literal[0]") + + +def func18(x: Literal[0, 1, 2], y: list[Literal[0, 1]]): + if x in y: + reveal_type(x, expected_text="Literal[0, 1]") + + +def func19(x: float, y: list[int]): + if x in y: + reveal_type(x, expected_text="float") + + +def func20(x: float, y: list[Literal[0, 1]]): + if x in y: + reveal_type(x, expected_text="Literal[0, 1]") + + +def func21(x: int, y: list[Literal[0, True]]): + if x in y: + reveal_type(x, expected_text="Literal[0, True]") + + +def func22(x: bool, y: list[Literal[0, 1]]): + if x in y: + reveal_type(x, expected_text="bool") diff --git a/packages/pyright-internal/src/tests/samples/unnecessaryContains1.py b/packages/pyright-internal/src/tests/samples/unnecessaryContains1.py index e2437c07ba4d..940ba37f6a43 100644 --- a/packages/pyright-internal/src/tests/samples/unnecessaryContains1.py +++ b/packages/pyright-internal/src/tests/samples/unnecessaryContains1.py @@ -34,7 +34,7 @@ def func3(x: list[str]): return # This should generate an error if "reportUnnecessaryContains" is enabled. - if x not in ([1, 2], [3]): + if x not in ((1, 2), (3,)): pass