Skip to content

Commit

Permalink
Added support for *args: Unpack[T] when T is a type variable with…
Browse files Browse the repository at this point in the history
… an upper bound of a tuple. (#9179)
  • Loading branch information
erictraut authored Oct 9, 2024
1 parent 874daba commit a1d1b8a
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 86 deletions.
28 changes: 17 additions & 11 deletions packages/pyright-internal/src/analyzer/constraintSolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import {
isEffectivelyInstantiable,
isLiteralTypeOrUnion,
isPartlyUnknown,
makeUnpacked,
mapSubtypes,
simplifyFunctionToParamSpec,
sortTypes,
Expand Down Expand Up @@ -138,17 +139,22 @@ export function assignTypeVar(
isAssignable = assignParamSpec(evaluator, destType, srcType, diag, constraints, recursionCount);
} else {
if (isTypeVarTuple(destType) && !destType.priv.isInUnion) {
const tupleClassType = evaluator.getTupleClassType();
if (!isUnpacked(srcType) && tupleClassType) {
// Package up the type into a tuple.
srcType = convertToInstance(
specializeTupleClass(
tupleClassType,
[{ type: srcType, isUnbounded: false }],
/* isTypeArgExplicit */ true,
/* isUnpacked */ true
)
);
if (destType.priv.isUnpacked) {
const tupleClassType = evaluator.getTupleClassType();

if (!isUnpacked(srcType) && tupleClassType) {
// Package up the type into a tuple.
srcType = convertToInstance(
specializeTupleClass(
tupleClassType,
[{ type: srcType, isUnbounded: false }],
/* isTypeArgExplicit */ true,
/* isUnpacked */ true
)
);
}
} else {
srcType = makeUnpacked(srcType);
}
}

Expand Down
7 changes: 5 additions & 2 deletions packages/pyright-internal/src/analyzer/patternMatching.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import {
isTypeSame,
isTypeVarTuple,
isUnknown,
isUnpackedTypeVar,
isUnpackedTypeVarTuple,
} from './types';
import {
Expand Down Expand Up @@ -1387,7 +1388,7 @@ function getSequencePatternInfo(
];

const tupleIndeterminateIndex = typeArgs.findIndex(
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type)
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type)
);

let tupleDeterminateEntryCount = typeArgs.length;
Expand Down Expand Up @@ -1417,7 +1418,9 @@ function getSequencePatternInfo(
const removedEntries = typeArgs.splice(patternStarEntryIndex, entriesToCombine);
typeArgs.splice(patternStarEntryIndex, 0, {
type: combineTypes(removedEntries.map((t) => t.type)),
isUnbounded: removedEntries.every((t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type)),
isUnbounded: removedEntries.every(
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type)
),
});
}

Expand Down
7 changes: 5 additions & 2 deletions packages/pyright-internal/src/analyzer/tuples.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import {
isTypeVar,
isTypeVarTuple,
isUnion,
isUnpackedTypeVar,
isUnpackedTypeVarTuple,
TupleTypeArg,
Type,
Expand Down Expand Up @@ -366,9 +367,11 @@ export function adjustTupleTypeArgs(
srcTypeArgs: TupleTypeArg[],
flags: AssignTypeFlags
): boolean {
const destUnboundedOrVariadicIndex = destTypeArgs.findIndex((t) => t.isUnbounded || isTypeVarTuple(t.type));
const destUnboundedOrVariadicIndex = destTypeArgs.findIndex(
(t) => t.isUnbounded || isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type)
);
const srcUnboundedIndex = srcTypeArgs.findIndex((t) => t.isUnbounded);
const srcVariadicIndex = srcTypeArgs.findIndex((t) => isTypeVarTuple(t.type));
const srcVariadicIndex = srcTypeArgs.findIndex((t) => isUnpackedTypeVarTuple(t.type) || isUnpackedTypeVar(t.type));

if (srcUnboundedIndex >= 0) {
if (isAnyOrUnknown(srcTypeArgs[srcUnboundedIndex].type)) {
Expand Down
73 changes: 43 additions & 30 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ import {
buildSolutionFromSpecializedClass,
ClassMember,
combineSameSizedTuples,
combineTupleTypeArgs,
combineVariances,
computeMroLinearization,
containsAnyOrUnknown,
Expand Down Expand Up @@ -337,6 +338,7 @@ import {
lookUpObjectMember,
makeFunctionTypeVarsBound,
makeInferenceContext,
makePacked,
makeTypeVarsBound,
makeTypeVarsFree,
mapSignatures,
Expand Down Expand Up @@ -1831,11 +1833,7 @@ export function createTypeEvaluator(
// Handle "LiteralString" specially.
if (strClass && isInstantiableClass(strClass)) {
let strInstance = ClassType.cloneAsInstance(strClass);

if (subtype.props?.condition) {
strInstance = TypeBase.cloneForCondition(strInstance, getTypeCondition(subtype));
}

strInstance = TypeBase.cloneForCondition(strInstance, getTypeCondition(subtype));
return strInstance;
}
}
Expand Down Expand Up @@ -4054,6 +4052,10 @@ export function createTypeEvaluator(
});
}

if (subtype.priv.isUnpacked && isClass(boundType)) {
boundType = ClassType.cloneForUnpacked(boundType);
}

boundType = TypeBase.isInstantiable(subtype) ? convertToInstantiable(boundType) : boundType;

return addConditionToType(boundType, [{ typeVar: subtype, constraintIndex: 0 }]);
Expand Down Expand Up @@ -7973,7 +7975,7 @@ export function createTypeEvaluator(
};
} else if (options?.isAnnotatedClass && argIndex > 0) {
// If it's an Annotated[a, b, c], only the first index should be
// treated as a type.The others can be regular(non - type) objects.
// treated as a type. The others can be regular (non-type) objects.
adjFlags =
EvalFlags.NoParamSpec | EvalFlags.NoTypeVarTuple | EvalFlags.NoSpecialize | EvalFlags.NoClassVar;
if (isAnnotationEvaluationPostponed(AnalyzerNodeInfo.getFileInfo(node))) {
Expand Down Expand Up @@ -8076,11 +8078,7 @@ export function createTypeEvaluator(
const upperBound = type.shared.boundType;

if (upperBound && isClassInstance(upperBound) && isTupleClass(upperBound)) {
const concrete = makeTopLevelTypeVarsConcrete(type);

if (isInstantiableClass(concrete)) {
return ClassType.cloneForUnpacked(concrete);
}
return TypeVarType.cloneForUnpacked(type);
}

return undefined;
Expand Down Expand Up @@ -10620,7 +10618,7 @@ export function createTypeEvaluator(
const paramType = paramInfo.type;
const paramName = paramInfo.param.name;

const isParamVariadic = paramInfo.param.category === ParamCategory.ArgsList && isTypeVarTuple(paramType);
const isParamVariadic = paramInfo.param.category === ParamCategory.ArgsList && isUnpacked(paramType);

if (argList[argIndex].argCategory === ArgCategory.UnpackedList) {
let isArgCompatibleWithVariadic = false;
Expand Down Expand Up @@ -10786,7 +10784,7 @@ export function createTypeEvaluator(
effectiveParamType = paramType.priv.tupleTypeArgs[0].type;
}

paramCategory = isTypeVarTuple(effectiveParamType) ? ParamCategory.ArgsList : ParamCategory.Simple;
paramCategory = isUnpacked(effectiveParamType) ? ParamCategory.ArgsList : ParamCategory.Simple;

if (remainingArgCount <= remainingParamCount) {
if (remainingArgCount < remainingParamCount) {
Expand Down Expand Up @@ -11368,7 +11366,7 @@ export function createTypeEvaluator(
const paramType = paramDetails.params[paramDetails.argsIndex].type;
const variadicArgs = validateArgTypeParams.filter((argParam) => argParam.mapsToVarArgList);

if (isTypeVarTuple(paramType) && !paramType.priv.isInUnion) {
if (isUnpacked(paramType) && (!isTypeVarTuple(paramType) || !paramType.priv.isInUnion)) {
const tupleTypeArgs: TupleTypeArg[] = variadicArgs.map((argParam) => {
const argType = getTypeOfArg(argParam.argument, /* inferenceContext */ undefined).type;

Expand Down Expand Up @@ -11401,23 +11399,22 @@ export function createTypeEvaluator(
};
});

let specializedTuple: Type;
if (
tupleTypeArgs.length === 1 &&
!tupleTypeArgs[0].isUnbounded &&
(isUnpackedClass(tupleTypeArgs[0].type) || isTypeVarTuple(tupleTypeArgs[0].type))
) {
// If there is a single unpacked tuple or unpacked variadic type variable
// (including an unpacked TypeVarTuple union) within this tuple,
// simplify the type.
specializedTuple = tupleTypeArgs[0].type;
} else {
specializedTuple = makeTupleObject(evaluatorInterface, tupleTypeArgs, /* isUnpacked */ true);
let specializedTuple: Type | undefined;
if (tupleTypeArgs.length === 1 && !tupleTypeArgs[0].isUnbounded) {
const entryType = tupleTypeArgs[0].type;

if (isUnpacked(entryType)) {
specializedTuple = makePacked(entryType);
}
}

if (!specializedTuple) {
specializedTuple = makeTupleObject(evaluatorInterface, tupleTypeArgs, /* isUnpacked */ false);
}

const combinedArg: ValidateArgTypeParams = {
paramCategory: ParamCategory.ArgsList,
paramType,
paramCategory: ParamCategory.Simple,
paramType: makePacked(paramType),
requiresTypeVarMatching: true,
argument: {
argCategory: ArgCategory.Simple,
Expand Down Expand Up @@ -11879,7 +11876,7 @@ export function createTypeEvaluator(

// If the final return type is an unpacked tuple, turn it into a normal (unpacked) tuple.
if (isUnpackedClass(specializedReturnType)) {
specializedReturnType = ClassType.cloneForUnpacked(specializedReturnType, /* isUnpacked */ false);
specializedReturnType = ClassType.cloneForPacked(specializedReturnType);
}

const liveTypeVarScopes = ParseTreeUtils.getTypeVarScopesForNode(errorNode);
Expand Down Expand Up @@ -18834,7 +18831,7 @@ export function createTypeEvaluator(
}

if (isUnpackedClass(type)) {
return ClassType.cloneForUnpacked(type, /* isUnpacked */ false);
return ClassType.cloneForPacked(type);
}

return makeTupleObject(evaluatorInterface, [{ type, isUnbounded: !isTypeVarTuple(type) }]);
Expand Down Expand Up @@ -24194,6 +24191,22 @@ export function createTypeEvaluator(

let concreteSrcType = makeTopLevelTypeVarsConcrete(srcType);
if (isClass(concreteSrcType) && TypeBase.isInstance(concreteSrcType)) {
// Handle the case where the source is an unpacked tuple.
if (
!destType.priv.isUnpacked &&
concreteSrcType.priv.isUnpacked &&
concreteSrcType.priv.tupleTypeArgs
) {
return assignType(
destType,
combineTupleTypeArgs(concreteSrcType.priv.tupleTypeArgs),
diag,
constraints,
flags,
recursionCount
);
}

// Handle enum literals that are assignable to another (non-Enum) literal.
// This can happen for IntEnum and StrEnum members.
if (
Expand Down
12 changes: 5 additions & 7 deletions packages/pyright-internal/src/analyzer/typePrinter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -615,14 +615,12 @@ function printTypeInternal(
(printTypeFlags & PrintTypeFlags.OmitTypeVarScopes) === 0
);

if (isTypeVarTuple(type)) {
if (type.priv.isUnpacked) {
typeVarName = _printUnpack(typeVarName, printTypeFlags);
}
if (type.priv.isUnpacked) {
typeVarName = _printUnpack(typeVarName, printTypeFlags);
}

if (type.priv.isInUnion) {
typeVarName = `Union[${typeVarName}]`;
}
if (isTypeVarTuple(type) && type.priv.isInUnion) {
typeVarName = `Union[${typeVarName}]`;
}

if (TypeBase.isInstantiable(type)) {
Expand Down
Loading

0 comments on commit a1d1b8a

Please sign in to comment.