Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use a generic's bound when narrowing with isinstance if it's covariant, or Never if it's contravariant #745

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3794,7 +3794,7 @@ export class Checker extends ParseTreeWalker {
return;
}

const classTypeList = getIsInstanceClassTypes(this._evaluator, arg1Type);
const classTypeList = getIsInstanceClassTypes(this._evaluator, arg1Type, arg0Type);
if (!classTypeList) {
return;
}
Expand Down
8 changes: 7 additions & 1 deletion packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ import {
mapSubtypes,
partiallySpecializeType,
preserveUnknown,
shouldUseVarianceForSpecialization,
specializeTupleClass,
specializeWithUnknownTypeArgs,
transformPossibleRecursiveTypeAlias,
Expand Down Expand Up @@ -757,7 +758,12 @@ function narrowTypeBasedOnClassPattern(
// specialize it with Unknown type arguments.
if (isClass(exprType) && !exprType.props?.typeAliasInfo) {
exprType = ClassType.cloneRemoveTypePromotions(exprType);
exprType = specializeWithUnknownTypeArgs(exprType, evaluator.getTupleClassType());
evaluator.inferVarianceForClass(exprType);
exprType = specializeWithUnknownTypeArgs(
exprType,
evaluator.getTupleClassType(),
shouldUseVarianceForSpecialization(type) ? evaluator.getObjectType() : undefined
);
}

// Are there any positional arguments? If so, try to get the mappings for
Expand Down
19 changes: 15 additions & 4 deletions packages/pyright-internal/src/analyzer/typeGuards.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ import {
makeTypeVarsFree,
mapSubtypes,
MemberAccessFlags,
shouldUseVarianceForSpecialization,
specializeTupleClass,
specializeWithUnknownTypeArgs,
stripTypeForm,
Expand Down Expand Up @@ -623,7 +624,11 @@ export function getTypeNarrowingCallback(
const arg1TypeResult = evaluator.getTypeOfExpression(arg1Expr, EvalFlags.IsInstanceArgDefaults);
const arg1Type = arg1TypeResult.type;

const classTypeList = getIsInstanceClassTypes(evaluator, arg1Type);
const classTypeList = getIsInstanceClassTypes(
evaluator,
arg1Type,
evaluator.getTypeOfExpression(arg0Expr).type
);
const isIncomplete = !!callTypeResult.isIncomplete || !!arg1TypeResult.isIncomplete;

if (classTypeList) {
Expand Down Expand Up @@ -1125,17 +1130,23 @@ function narrowTypeForIsEllipsis(evaluator: TypeEvaluator, node: ExpressionNode,
// which form and returns a list of classes or undefined.
export function getIsInstanceClassTypes(
evaluator: TypeEvaluator,
argType: Type
argType: Type,
typeToNarrow: Type
): (ClassType | TypeVarType | FunctionType)[] | undefined {
let foundNonClassType = false;
const classTypeList: (ClassType | TypeVarType | FunctionType)[] = [];

const useVarianceForSpecialization = shouldUseVarianceForSpecialization(typeToNarrow);
// Create a helper function that returns a list of class types or
// undefined if any of the types are not valid.
const addClassTypesToList = (types: Type[]) => {
types.forEach((subtype) => {
if (isClass(subtype)) {
subtype = specializeWithUnknownTypeArgs(subtype, evaluator.getTupleClassType());
evaluator.inferVarianceForClass(subtype);
subtype = specializeWithUnknownTypeArgs(
subtype,
evaluator.getTupleClassType(),
useVarianceForSpecialization ? evaluator.getObjectType() : undefined
);

if (isInstantiableClass(subtype) && ClassType.isBuiltIn(subtype, 'Callable')) {
subtype = convertToInstantiable(getUnknownTypeForCallable());
Expand Down
68 changes: 61 additions & 7 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1053,9 +1053,35 @@ export function getTypeVarScopeIds(type: Type): TypeVarScopeId[] {
return scopeIds;
}

// Specializes the class with "Unknown" type args (or the equivalent for ParamSpecs
// or TypeVarTuples).
export function specializeWithUnknownTypeArgs(type: ClassType, tupleClassType?: ClassType): ClassType {
/**
* if the type we're narrowing is Any or Unknown, we don't want to specialize using the
* variance/bound for compatibility with less strictly typed code (cringe)
*/
export const shouldUseVarianceForSpecialization = (type: Type) =>
!isAnyOrUnknown(type) &&
!isPartlyUnknown(type) &&
// TODO: this logic should probably be moved into `isAny`/`isUnknown` or something,
// to fix issues like https://github.com/DetachHead/basedpyright/issues/746
(type.category !== TypeCategory.TypeVar || !type.shared.isSynthesized);

/**
* Specializes the class with "Unknown" type args (or the equivalent for ParamSpecs or TypeVarTuples), or its
* widest possible type if its variance is known and {@link objectTypeForVarianceCheck} is provided (`object` if
* the bound if covariant, `Never` if contravariant). see docstring on {@link getUnknownForTypeVar} for more info
*
* @param tupleClassType the builtin `tuple` type for special-casing tuples. needs to be passed so that this
* module doesn't depend on `typeEvaluator.ts`
* @param objectTypeForVarianceCheck the builtin `object` type to be returned if the type var is covariant.
* passing this parameter enables the variance check which allows it to return a better result than just "Unknown"
* in cases where the variance is known (ie. `object` or its bound if it's covariant, and `Never` if it's
* contravariant). needs to be passed so that this module doesn't depend on `typeEvaluator.ts`. note that
* `evaluator.inferVarianceForClass` needs to be called on {@link type} first if passing this parameter
*/
export function specializeWithUnknownTypeArgs(
type: ClassType,
tupleClassType?: ClassType,
objectTypeForVarianceCheck?: Type
): ClassType {
if (type.shared.typeParams.length === 0) {
return type;
}
Expand All @@ -1073,22 +1099,50 @@ export function specializeWithUnknownTypeArgs(type: ClassType, tupleClassType?:

return ClassType.specialize(
type,
type.shared.typeParams.map((param) => getUnknownForTypeVar(param, tupleClassType)),
type.shared.typeParams.map((param) => getUnknownForTypeVar(param, tupleClassType, objectTypeForVarianceCheck)),
/* isTypeArgExplicit */ false,
/* includeSubclasses */ type.priv.includeSubclasses
);
}

// Returns "Unknown" for simple TypeVars or the equivalent for a ParamSpec.
export function getUnknownForTypeVar(typeVar: TypeVarType, tupleClassType?: ClassType): Type {
/**
* Returns "Unknown" for simple TypeVars or the equivalent for a ParamSpec, or the widest allowed type if
* {@link objectTypeForVarianceCheck} is provided.
*
* ideally it would always do the variance check, but doing so interferes with bidirectional type inference
* in some edge cases. see https://github.com/microsoft/pyright/issues/5404#issuecomment-1639667443
*
* @param tupleClassType the builtin `tuple` type for special-casing tuples. needs to be passed so that this
* module doesn't depend on `typeEvaluator.ts`
* @param objectTypeForVarianceCheck the builtin `object` type to be returned if the type var is covariant.
* passing this parameter enables the variance check which allows it to return a better result than just "Unknown"
* in cases where the variance is known (ie. `object` or its bound if it's covariant, and `Never` if it's
* contravariant). needs to be passed so that this module doesn't depend on `typeEvaluator.ts`. note that
* `evaluator.inferVarianceForClass` needs to be called on {@link type} first if passing this parameter
*/
export function getUnknownForTypeVar(
typeVar: TypeVarType,
tupleClassType?: ClassType,
objectTypeForVarianceCheck?: Type
): Type {
if (isParamSpec(typeVar)) {
return ParamSpecType.getUnknown();
}

if (isTypeVarTuple(typeVar) && tupleClassType) {
return getUnknownForTypeVarTuple(tupleClassType);
}

if (objectTypeForVarianceCheck) {
// if there are no usages of the TypeVar on the class and its variance isn't explicitly specified, it won't be
// known yet. https://github.com/DetachHead/basedpyright/issues/744
const variance = TypeVarType.getVariance(typeVar);
if (variance === Variance.Covariant) {
return typeVar.shared.boundType ?? objectTypeForVarianceCheck;
}
if (variance === Variance.Contravariant) {
return NeverType.createNever();
}
}
return UnknownType.create();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ChildB1(ParentB[_T2]):

def func4(var: ParentB[int]):
if isinstance(var, ChildB1):
reveal_type(var, expected_text="ChildB1[int]")
reveal_type(var, expected_text="<subclass of ParentB[int] and ChildB1[float]>")
Comment on lines 52 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem like a good change. If we already have a narrower type var, it seems like it shouldn't widen it.
Do you know an example where narrowing to ChildB1[int] would be unsafe?

If it's just an issue of being difficult to implement, I think this change probably isn't bad enough to be a blocker.



def func5(var: ParentB[Any]):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import Any, Never, assert_type, Iterable, Iterator, MutableMapping, Reversible


class Covariant[T]:
def foo(self, other: object):
if isinstance(other, Covariant):
assert_type(other, Covariant[object])

def bar(self) -> T: ...

class CovariantByDefault[T]:
"""by default if there are no usages of a type var on a class, it's treated as covariant.
imo this should be an error. see https://github.com/DetachHead/basedpyright/issues/744"""
def foo(self, other: object):
if isinstance(other, CovariantByDefault):
assert_type(other, CovariantByDefault[object])

class CovariantWithBound[T: int | str]:
def foo(self, other: object):
if isinstance(other, CovariantWithBound):
assert_type(other, CovariantWithBound[int | str])

def bar(self) -> T: ...

class Contravariant[T]:
def foo(self, other: object):
if isinstance(other, Contravariant):
assert_type(other, Contravariant[Never])

def bar(self, other: T): ...

class ContravariantWithBound[T: int | str]:
def foo(self, other: object):
if isinstance(other, ContravariantWithBound):
assert_type(other, ContravariantWithBound[Never])

def bar(self, other: T): ...


def foo(value: object):
match value:
case Iterable():
assert_type(value, Iterable[object])

class AnyOrUnknown:
"""for backwards compatibility with badly typed code we keep the old functionality when narrowing `Any`/Unknown"""
def __init__(self, value):
"""arguments in `__init__` get turned into fake type vars if they're untyped, so we need to handle this case.
see https://github.com/DetachHead/basedpyright/issues/746"""
if isinstance(value, Iterable):
assert_type(value, Iterable[Any])

def any(self, value: Any):
if isinstance(value, Iterable):
assert_type(value, Iterable[Any])

def match_case(self, value: Any):
match value:
case Iterable():
assert_type(value, Iterable[Any])

def unknown(self, value):
if isinstance(value, Iterable):
assert_type(value, Iterable[Any])

def partially_unknown(self, value=None):
if isinstance(value, Iterable):
assert_type(value, Iterable[Any])

def foo[KT,VT](self: MutableMapping[KT, VT]) -> Iterator[KT]:
assert isinstance(self, Reversible) # fail
return reversed(self)
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,12 @@ test('subscript context manager types on 3.8', () => {
],
});
});

test('narrowing type vars using their bounds', () => {
const configOptions = new ConfigOptions(Uri.empty());
configOptions.diagnosticRuleSet.reportUnusedParameter = 'none';
const analysisResults = typeAnalyzeSampleFiles(['typeNarrowingUsingBounds.py'], configOptions);
validateResultsButBased(analysisResults, {
errors: [],
});
});
Loading