diff --git a/packages/pyright-internal/src/analyzer/operations.ts b/packages/pyright-internal/src/analyzer/operations.ts index d71c582b19d7..dddd9b7daf2c 100644 --- a/packages/pyright-internal/src/analyzer/operations.ts +++ b/packages/pyright-internal/src/analyzer/operations.ts @@ -44,6 +44,7 @@ import { removeNoneFromUnion, someSubtypes, specializeTupleClass, + specializeWithDefaultTypeArgs, transformPossibleRecursiveTypeAlias, } from './typeUtils'; import { @@ -366,6 +367,14 @@ export function getTypeOfBinaryOperation( } if (isUnionableType([adjustedLeftType, adjustedRightType])) { + if (isInstantiableClass(adjustedLeftType)) { + adjustedLeftType = specializeWithDefaultTypeArgs(adjustedLeftType); + } + + if (isInstantiableClass(adjustedRightType)) { + adjustedRightType = specializeWithDefaultTypeArgs(adjustedRightType); + } + return createUnionType( evaluator, node, diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 6ad73150e8be..2a53674c7e81 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -2094,18 +2094,15 @@ export function getTypeVarArgsRecursive(type: Type, recursionCount = 0): TypeVar // Creates a specialized version of the class, filling in any unspecified // type arguments with Unknown or default value. export function specializeWithDefaultTypeArgs(type: ClassType): ClassType { - if (type.shared.typeParams.length === 0 || type.priv.typeArgs) { + if (type.shared.typeParams.length === 0 || type.priv.typeArgs || !type.shared.typeVarScopeId) { return type; } const solution = new ConstraintSolution(); - const typeParams = ClassType.getTypeParams(type); - typeParams.forEach((typeParam) => { - solution.setType(typeParam, applySolvedTypeVars(typeParam.shared.defaultType, solution)); - }); - - return applySolvedTypeVars(type, solution) as ClassType; + return applySolvedTypeVars(type, solution, { + replaceUnsolved: { scopeIds: [type.shared.typeVarScopeId], tupleClassType: undefined }, + }) as ClassType; } // Builds a mapping between type parameters and their specialized @@ -3651,6 +3648,7 @@ export class TypeVarTransformer { let newTypeArgs: Type[] | undefined; let newTupleTypeArgs: TupleTypeArg[] | undefined; let specializationNeeded = false; + let isTypeArgExplicit = true; // If type args were previously provided, specialize them. @@ -3701,6 +3699,7 @@ export class TypeVarTransformer { const newTypeArgType = this.apply(typeParams[0], recursionCount); newTupleTypeArgs = [{ type: newTypeArgType, isUnbounded: true }]; specializationNeeded = true; + isTypeArgExplicit = false; } } @@ -3713,6 +3712,11 @@ export class TypeVarTransformer { if (!newTypeArgs) { const typeArgs = classType.priv.typeArgs ?? typeParams; + + if (!classType.priv.typeArgs) { + isTypeArgExplicit = false; + } + newTypeArgs = typeArgs.map((oldTypeArgType) => { let newTypeArgType = this.apply(oldTypeArgType, recursionCount); if (newTypeArgType !== oldTypeArgType) { @@ -3736,7 +3740,7 @@ export class TypeVarTransformer { return ClassType.specialize( classType, newTypeArgs, - /* isTypeArgExplicit */ true, + isTypeArgExplicit, /* includeSubclasses */ undefined, newTupleTypeArgs ); diff --git a/packages/pyright-internal/src/tests/samples/unions4.py b/packages/pyright-internal/src/tests/samples/unions4.py index 12d5deab450e..3ecc2e3938f8 100644 --- a/packages/pyright-internal/src/tests/samples/unions4.py +++ b/packages/pyright-internal/src/tests/samples/unions4.py @@ -18,3 +18,13 @@ def func1() -> Union: ... # This should generate an error. var1: Union + + +# This should generate two errors. +def func2(x: (list | set)[int]): + reveal_type(x) + + +# This should generate two errors. +def func3(x: Union[list, set][int]): + reveal_type(x) diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 0e30d6773bfe..e44714228a50 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -548,7 +548,7 @@ test('Unions3', () => { test('Unions4', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['unions4.py']); - TestUtils.validateResults(analysisResults, 3); + TestUtils.validateResults(analysisResults, 7); }); test('Unions5', () => {