Skip to content

Commit

Permalink
Added bidirectional type inference for container types that are param…
Browse files Browse the repository at this point in the history
…eterized by a protocol that matches a module. This addresses microsoft#9988. (microsoft#9989)
  • Loading branch information
erictraut authored Feb 27, 2025
1 parent 1d1b43c commit 861abd9
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 12 deletions.
11 changes: 9 additions & 2 deletions packages/pyright-internal/src/analyzer/tuples.ts
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ export function getTypeOfTupleWithContext(
if (node.d.items.length > maxInferredTupleEntryCount && entryTypeResults.some((result) => result.typeErrors)) {
type = makeTupleObject(evaluator, [{ type: UnknownType.create(), isUnbounded: true }]);
} else {
type = makeTupleObject(evaluator, evaluator.buildTupleTypesList(entryTypeResults, /* stripLiterals */ false));
type = makeTupleObject(
evaluator,
evaluator.buildTupleTypesList(entryTypeResults, /* stripLiterals */ false, /* convertModule */ false)
);
}

return { type, expectedTypeDiagAddendum, isIncomplete };
Expand All @@ -251,7 +254,11 @@ export function getTypeOfTupleInferred(evaluator: TypeEvaluator, node: TupleNode

const type = makeTupleObject(
evaluator,
evaluator.buildTupleTypesList(entryTypeResults, (flags & EvalFlags.StripTupleLiterals) !== 0)
evaluator.buildTupleTypesList(
entryTypeResults,
(flags & EvalFlags.StripTupleLiterals) !== 0,
/* convertModule */ true
)
);

if (isIncomplete) {
Expand Down
18 changes: 9 additions & 9 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8325,7 +8325,11 @@ export function createTypeEvaluator(
return typeResult;
}

function buildTupleTypesList(entryTypeResults: TypeResult[], stripLiterals: boolean): TupleTypeArg[] {
function buildTupleTypesList(
entryTypeResults: TypeResult[],
stripLiterals: boolean,
convertModule: boolean
): TupleTypeArg[] {
const entryTypes: TupleTypeArg[] = [];

for (const typeResult of entryTypeResults) {
Expand Down Expand Up @@ -8355,11 +8359,7 @@ export function createTypeEvaluator(
} else if (isNever(typeResult.type) && typeResult.isIncomplete && !typeResult.unpackedType) {
entryTypes.push({ type: UnknownType.create(/* isIncomplete */ true), isUnbounded: false });
} else {
let entryType = convertSpecialFormToRuntimeValue(
typeResult.type,
EvalFlags.None,
/* convertModule */ true
);
let entryType = convertSpecialFormToRuntimeValue(typeResult.type, EvalFlags.None, convertModule);
entryType = stripLiterals ? stripTypeForm(stripLiteralValue(entryType)) : entryType;
entryTypes.push({ type: entryType, isUnbounded: !!typeResult.unpackedType });
}
Expand Down Expand Up @@ -13956,10 +13956,10 @@ export function createTypeEvaluator(

// Strip any literal values and TypeForm types.
const keyTypes = keyTypeResults.map((t) =>
stripTypeForm(convertSpecialFormToRuntimeValue(stripLiteralValue(t.type), flags, /* convertModule */ true))
stripTypeForm(convertSpecialFormToRuntimeValue(stripLiteralValue(t.type), flags, !hasExpectedType))
);
const valueTypes = valueTypeResults.map((t) =>
stripTypeForm(convertSpecialFormToRuntimeValue(stripLiteralValue(t.type), flags, /* convertModule */ true))
stripTypeForm(convertSpecialFormToRuntimeValue(stripLiteralValue(t.type), flags, !hasExpectedType))
);

if (keyTypes.length > 0) {
Expand Down Expand Up @@ -14523,7 +14523,7 @@ export function createTypeEvaluator(
}

entryTypeResult.type = stripTypeForm(
convertSpecialFormToRuntimeValue(entryTypeResult.type, flags, /* convertModule */ true)
convertSpecialFormToRuntimeValue(entryTypeResult.type, flags, !hasExpectedType)
);

if (entryTypeResult.isIncomplete) {
Expand Down
6 changes: 5 additions & 1 deletion packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,11 @@ export interface TypeEvaluator {
getGetterTypeFromProperty: (propertyClass: ClassType) => Type | undefined;
getTypeOfArg: (arg: Arg, inferenceContext: InferenceContext | undefined) => TypeResult;
convertNodeToArg: (node: ArgumentNode) => ArgWithExpression;
buildTupleTypesList: (entryTypeResults: TypeResult[], stripLiterals: boolean) => TupleTypeArg[];
buildTupleTypesList: (
entryTypeResults: TypeResult[],
stripLiterals: boolean,
convertModules: boolean
) => TupleTypeArg[];
markNamesAccessed: (node: ParseNode, names: string[]) => void;
expandPromotionTypes: (node: ParseNode, type: Type) => Type;
makeTopLevelTypeVarsConcrete: (type: Type, makeParamSpecsConcrete?: boolean) => Type;
Expand Down
15 changes: 15 additions & 0 deletions packages/pyright-internal/src/tests/samples/module2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import typing
import importlib
from typing import Protocol
from types import ModuleType

importlib.reload(typing)
Expand All @@ -23,3 +24,17 @@ def func1(a: ModuleType):

v3 = (importlib, typing)
reveal_type(v3, expected_text="tuple[ModuleType, ModuleType]")


class ModuleProto(Protocol):
def reload(self, module: ModuleType) -> ModuleType: ...


v4: ModuleProto = importlib
reveal_type(v4, expected_text='Module("importlib")')

v5: tuple[ModuleProto] = (importlib,)
reveal_type(v5, expected_text='tuple[Module("importlib")]')

v6: list[ModuleProto] = [importlib]
reveal_type(v6, expected_text="list[ModuleProto]")

0 comments on commit 861abd9

Please sign in to comment.