Skip to content

Commit

Permalink
Fixed bug that results in a false negative if a | union operator cr…
Browse files Browse the repository at this point in the history
…eates a union of generic types. These types should be specialized with default type arguments. This addresses #9415.
  • Loading branch information
erictraut committed Nov 7, 2024
1 parent 630dfc2 commit a5d8208
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 9 deletions.
9 changes: 9 additions & 0 deletions packages/pyright-internal/src/analyzer/operations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import {
removeNoneFromUnion,
someSubtypes,
specializeTupleClass,
specializeWithDefaultTypeArgs,
transformPossibleRecursiveTypeAlias,
} from './typeUtils';
import {
Expand Down Expand Up @@ -366,6 +367,14 @@ export function getTypeOfBinaryOperation(
}

if (isUnionableType([adjustedLeftType, adjustedRightType])) {
if (isInstantiableClass(adjustedLeftType)) {
adjustedLeftType = specializeWithDefaultTypeArgs(adjustedLeftType);
}

if (isInstantiableClass(adjustedRightType)) {
adjustedRightType = specializeWithDefaultTypeArgs(adjustedRightType);
}

return createUnionType(
evaluator,
node,
Expand Down
20 changes: 12 additions & 8 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2094,18 +2094,15 @@ export function getTypeVarArgsRecursive(type: Type, recursionCount = 0): TypeVar
// Creates a specialized version of the class, filling in any unspecified
// type arguments with Unknown or default value.
export function specializeWithDefaultTypeArgs(type: ClassType): ClassType {
if (type.shared.typeParams.length === 0 || type.priv.typeArgs) {
if (type.shared.typeParams.length === 0 || type.priv.typeArgs || !type.shared.typeVarScopeId) {
return type;
}

const solution = new ConstraintSolution();
const typeParams = ClassType.getTypeParams(type);

typeParams.forEach((typeParam) => {
solution.setType(typeParam, applySolvedTypeVars(typeParam.shared.defaultType, solution));
});

return applySolvedTypeVars(type, solution) as ClassType;
return applySolvedTypeVars(type, solution, {
replaceUnsolved: { scopeIds: [type.shared.typeVarScopeId], tupleClassType: undefined },
}) as ClassType;
}

// Builds a mapping between type parameters and their specialized
Expand Down Expand Up @@ -3651,6 +3648,7 @@ export class TypeVarTransformer {
let newTypeArgs: Type[] | undefined;
let newTupleTypeArgs: TupleTypeArg[] | undefined;
let specializationNeeded = false;
let isTypeArgExplicit = true;

// If type args were previously provided, specialize them.

Expand Down Expand Up @@ -3701,6 +3699,7 @@ export class TypeVarTransformer {
const newTypeArgType = this.apply(typeParams[0], recursionCount);
newTupleTypeArgs = [{ type: newTypeArgType, isUnbounded: true }];
specializationNeeded = true;
isTypeArgExplicit = false;
}
}

Expand All @@ -3713,6 +3712,11 @@ export class TypeVarTransformer {

if (!newTypeArgs) {
const typeArgs = classType.priv.typeArgs ?? typeParams;

if (!classType.priv.typeArgs) {
isTypeArgExplicit = false;
}

newTypeArgs = typeArgs.map((oldTypeArgType) => {
let newTypeArgType = this.apply(oldTypeArgType, recursionCount);
if (newTypeArgType !== oldTypeArgType) {
Expand All @@ -3736,7 +3740,7 @@ export class TypeVarTransformer {
return ClassType.specialize(
classType,
newTypeArgs,
/* isTypeArgExplicit */ true,
isTypeArgExplicit,
/* includeSubclasses */ undefined,
newTupleTypeArgs
);
Expand Down
10 changes: 10 additions & 0 deletions packages/pyright-internal/src/tests/samples/unions4.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,13 @@ def func1() -> Union: ...

# This should generate an error.
var1: Union


# This should generate two errors.
def func2(x: (list | set)[int]):
reveal_type(x)


# This should generate two errors.
def func3(x: Union[list, set][int]):
reveal_type(x)
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator4.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ test('Unions3', () => {
test('Unions4', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['unions4.py']);

TestUtils.validateResults(analysisResults, 3);
TestUtils.validateResults(analysisResults, 7);
});

test('Unions5', () => {
Expand Down

0 comments on commit a5d8208

Please sign in to comment.