Skip to content

Commit

Permalink
Improved handling of types float and complex, which are special-c…
Browse files Browse the repository at this point in the history
…ased in PEP 484 as "promotion types". The new logic now properly models the runtime behavior for `isinstance` and class pattern matching when used with these promotion types. This addresses #6008. (#6013)

Co-authored-by: Eric Traut <erictr@microsoft.com>
  • Loading branch information
erictraut and msfterictraut authored Sep 24, 2023
1 parent 78a083c commit 90ad51a
Show file tree
Hide file tree
Showing 13 changed files with 141 additions and 23 deletions.
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3510,6 +3510,8 @@ export class Checker extends ParseTreeWalker {
return transformPossibleRecursiveTypeAlias(subtype);
});

arg0Type = this._evaluator.expandPromotionTypes(node, arg0Type);

const arg1Type = this._evaluator.getType(node.arguments[1].valueExpression);
if (!arg1Type) {
return;
Expand Down Expand Up @@ -3663,7 +3665,6 @@ export class Checker extends ParseTreeWalker {
this._evaluator,
varType,
filterType,
filterType,
isInstanceCheck
);

Expand Down
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,7 @@ function narrowTypeBasedOnClassPattern(
// If this is a class (but not a type alias that refers to a class),
// specialize it with Unknown type arguments.
if (isClass(exprType) && !exprType.typeAliasInfo) {
exprType = ClassType.cloneForPromotionType(exprType, /* isTypeArgumentExplicit */ false);
exprType = specializeClassType(exprType);
}

Expand All @@ -676,7 +677,7 @@ function narrowTypeBasedOnClassPattern(
const isPatternMetaclass = isMetaclassInstance(classInstance);

return evaluator.mapSubtypesExpandTypeVars(
type,
evaluator.expandPromotionTypes(pattern, type),
/* conditionFilter */ undefined,
(subjectSubtypeExpanded, subjectSubtypeUnexpanded) => {
// Handle the case where the class pattern references type() or a subtype thereof
Expand Down
66 changes: 56 additions & 10 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3545,6 +3545,41 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
writeTypeCache(target, { type, isIncomplete: isTypeIncomplete }, EvaluatorFlags.None);
}

// If the type includes promotion types, expand these to their constituent types.
function expandPromotionTypes(node: ParseNode, type: Type): Type {
return mapSubtypes(type, (subtype) => {
if (!isClass(subtype) || !subtype.includePromotions) {
return subtype;
}

const typesToCombine: Type[] = [ClassType.cloneForPromotionType(subtype, /* includePromotions */ false)];

const promotionTypeNames = typePromotions.get(subtype.details.fullName);
if (promotionTypeNames) {
for (const promotionTypeName of promotionTypeNames) {
const nameSplit = promotionTypeName.split('.');
let promotionSubtype = getBuiltInType(node, nameSplit[nameSplit.length - 1]);

if (promotionSubtype && isInstantiableClass(promotionSubtype)) {
promotionSubtype = ClassType.cloneForPromotionType(
promotionSubtype,
/* includePromotions */ false
);

if (isClassInstance(subtype)) {
promotionSubtype = ClassType.cloneAsInstance(promotionSubtype);
}

promotionSubtype = addConditionToType(promotionSubtype, subtype.condition);
typesToCombine.push(promotionSubtype);
}
}
}

return combineTypes(typesToCombine);
});
}

// Replaces all of the top-level TypeVars (as opposed to TypeVars
// used as type arguments in other types) with their concrete form.
// If conditionFilter is specified and the TypeVar is a constrained
Expand Down Expand Up @@ -5097,7 +5132,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions

// If the base type was incomplete and unbound, don't proceed
// because false positive errors will be generated.
if (baseTypeResult.isIncomplete && isUnbound(baseTypeResult.type)) {
if (baseTypeResult.isIncomplete && isUnbound(baseType)) {
return { type: UnknownType.create(/* isIncomplete */ true), isIncomplete: true };
}

Expand Down Expand Up @@ -5130,6 +5165,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
baseType = makeTopLevelTypeVarsConcrete(baseType);
}

// Do union expansion for promotion types.
baseType = expandPromotionTypes(node, baseType);

switch (baseType.category) {
case TypeCategory.Any:
case TypeCategory.Unknown:
Expand Down Expand Up @@ -15817,6 +15855,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions

classType.details.typeVarScopeId = ParseTreeUtils.getScopeIdForNode(node);

// Is this a special type that supports type promotions according to PEP 484?
if (typePromotions.has(classType.details.fullName)) {
classType.includePromotions = true;
}

// Some classes refer to themselves within type arguments used within
// base classes. We'll register the partially-constructed class type
// to allow these to be resolved.
Expand Down Expand Up @@ -21575,15 +21618,17 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}

// Handle special-case type promotions.
const promotionList = typePromotions.get(destType.details.fullName);
if (
promotionList &&
promotionList.some((srcName) =>
srcType.details.mro.some((mroClass) => isClass(mroClass) && srcName === mroClass.details.fullName)
)
) {
if ((flags & AssignTypeFlags.EnforceInvariance) === 0) {
return true;
if (destType.includePromotions) {
const promotionList = typePromotions.get(destType.details.fullName);
if (
promotionList &&
promotionList.some((srcName) =>
srcType.details.mro.some((mroClass) => isClass(mroClass) && srcName === mroClass.details.fullName)
)
) {
if ((flags & AssignTypeFlags.EnforceInvariance) === 0) {
return true;
}
}
}

Expand Down Expand Up @@ -26010,6 +26055,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
getGetterTypeFromProperty,
getTypeOfArgument,
markNamesAccessed,
expandPromotionTypes,
makeTopLevelTypeVarsConcrete,
mapSubtypesExpandTypeVars,
isTypeSubsumedByOtherType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ export interface TypeEvaluator {
getGetterTypeFromProperty: (propertyClass: ClassType, inferTypeIfNeeded: boolean) => Type | undefined;
getTypeOfArgument: (arg: FunctionArgument) => TypeResult;
markNamesAccessed: (node: ParseNode, names: string[]) => void;
expandPromotionTypes: (node: ParseNode, type: Type) => Type;
makeTopLevelTypeVarsConcrete: (type: Type, makeParamSpecsConcrete?: boolean) => Type;
mapSubtypesExpandTypeVars: (
type: Type,
Expand Down
6 changes: 3 additions & 3 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,6 @@ export function isIsinstanceFilterSuperclass(
export function isIsinstanceFilterSubclass(
evaluator: TypeEvaluator,
varType: ClassType,
filterType: Type,
concreteFilterType: ClassType,
isInstanceCheck: boolean
) {
Expand Down Expand Up @@ -1304,10 +1303,12 @@ function narrowTypeForIsInstance(
allowIntersections: boolean,
errorNode: ExpressionNode
): Type {
const expandedTypes = mapSubtypes(type, (subtype) => {
let expandedTypes = mapSubtypes(type, (subtype) => {
return transformPossibleRecursiveTypeAlias(subtype);
});

expandedTypes = evaluator.expandPromotionTypes(errorNode, type);

// Filters the varType by the parameters of the isinstance
// and returns the list of types the varType could be after
// applying the filter.
Expand Down Expand Up @@ -1336,7 +1337,6 @@ function narrowTypeForIsInstance(
const filterIsSubclass = isIsinstanceFilterSubclass(
evaluator,
varType,
filterType,
concreteFilterType,
isInstanceCheck
);
Expand Down
11 changes: 11 additions & 0 deletions packages/pyright-internal/src/analyzer/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,11 @@ export interface ClassType extends TypeBase {
// of abstract or protocol classes.
includeSubclasses?: boolean;

// This class type represents the class and any auto-promotion
// types that PEP 484 indicates should be treated as subclasses
// when the type appears within a type annotation.
includePromotions?: boolean;

// Some types can be further constrained to have
// literal types (e.g. true or 'string' or 3).
literalValue?: LiteralValue | undefined;
Expand Down Expand Up @@ -822,6 +827,12 @@ export namespace ClassType {
return newClassType;
}

export function cloneForPromotionType(classType: ClassType, includePromotions: boolean): ClassType {
const newClassType = TypeBase.cloneType(classType);
newClassType.includePromotions = includePromotions;
return newClassType;
}

export function cloneForTypeGuard(
classType: ClassType,
typeGuardType: Type,
Expand Down
13 changes: 10 additions & 3 deletions packages/pyright-internal/src/tests/samples/isinstance2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@ class DbModel(Document):
pass


def foo() -> Union[int, DbModel]:
def func1() -> Union[int, DbModel]:
return DbModel()


# This should not generate an error even though DbModel is
# derived from an unknown base class.
isinstance(foo(), int)
isinstance(func1(), int)


def bar(obj: object, typ: type):
def func2(obj: object, typ: type):
return isinstance(obj, typ)


def func3(obj: float):
if isinstance(obj, float):
reveal_type(obj, expected_text="float")
else:
reveal_type(obj, expected_text="int")
30 changes: 30 additions & 0 deletions packages/pyright-internal/src/tests/samples/matchClass1.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,33 @@ def func15(x: IntPair | None) -> None:
case IntPair((y, z)):
reveal_type(y, expected_text="int")
reveal_type(z, expected_text="int")


def func16(x: str | float | bool | None):
match x:
case str(v) | bool(v) | float(v):
reveal_type(v, expected_text="str | bool | float")
reveal_type(x, expected_text="str | bool | float")
case v:
reveal_type(v, expected_text="int | None")
reveal_type(x, expected_text="int | None")
reveal_type(x, expected_text="str | bool | float | int | None")


def func17(x: str | float | bool | None):
match x:
case str() | float() | bool():
reveal_type(x, expected_text="str | float | bool")
case _:
reveal_type(x, expected_text="int | None")
reveal_type(x, expected_text="str | float | bool | int | None")


def func18(x: str | float | bool | None):
match x:
case str(v) | float(v) | bool(v):
reveal_type(v, expected_text="str | float | bool")
reveal_type(x, expected_text="str | float | bool")
case _:
reveal_type(x, expected_text="int | None")
reveal_type(x, expected_text="str | float | bool | int | None")
3 changes: 2 additions & 1 deletion packages/pyright-internal/src/tests/samples/matchClass2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ class Point:
reveal_type(x, expected_text="int")
reveal_type(y, expected_text="int")
reveal_type(opt, expected_text="int | None")
distance = (x ** 2 + y ** 2) ** 0.5
distance = (x ** 2 + y ** 2) ** 0.5

Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def func10(subj: Color | None = None) -> list[str]:
def func11(subj: int | float | None):
match subj:
case float():
reveal_type(subj, expected_text="int | float")
reveal_type(subj, expected_text="float")
case int():
reveal_type(subj, expected_text="int")
case NoneType():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ def func5(x: int | str | complex):
if isinstance(x, (int, str)):
reveal_type(x, expected_text="int | str")
else:
reveal_type(x, expected_text="complex")
reveal_type(x, expected_text="complex | float")


def func6(x: type[int] | type[str] | type[complex]):
if issubclass(x, (int, str)):
reveal_type(x, expected_text="type[int] | type[str]")
else:
reveal_type(x, expected_text="type[complex]")
reveal_type(x, expected_text="type[complex] | type[float]")


def func7(x: int | SomeTypedDict | None):
Expand Down
20 changes: 20 additions & 0 deletions packages/pyright-internal/src/tests/samples/typePromotions1.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,23 @@ def func3(x: IntSubclass) -> float:

def func4(x: IntNewType) -> float:
return x


def func5(f: float):
# This should generate an error because "hex" isn't
# a valid method for an int.
f.hex()

if isinstance(f, float):
reveal_type(f, expected_text="float")
f.hex()
else:
reveal_type(f, expected_text="int")


def func6(f: complex):
if isinstance(f, float):
reveal_type(f, expected_text="float")
f.hex()
else:
reveal_type(f, expected_text="complex | int")
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator3.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ test('Never2', () => {
test('TypePromotions1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typePromotions1.py']);

TestUtils.validateResults(analysisResults, 0);
TestUtils.validateResults(analysisResults, 1);
});

test('Index1', () => {
Expand Down

0 comments on commit 90ad51a

Please sign in to comment.