From 8549b370369dcc048a5bad602c750cd7de3ee469 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Mon, 30 Sep 2024 21:04:05 -0700 Subject: [PATCH] Fixed several bugs related to the explicit specialization of a generic type alias parameterized by a single ParamSpec. `Concatenate` was being handled incorrectly. This addresses #9088. --- .../src/analyzer/typeEvaluator.ts | 154 ++++++++++-------- .../src/tests/samples/paramSpec13.py | 42 +++-- .../src/tests/typeEvaluator4.test.ts | 2 +- 3 files changed, 113 insertions(+), 85 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index a9b657432204..0d7614406d6e 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -7040,27 +7040,14 @@ export function createTypeEvaluator( inferVarianceForTypeAlias(baseType); const typeParams = aliasInfo.shared.typeParams; - let typeArgs = adjustTypeArgsForTypeVarTuple(getTypeArgs(node, flags), typeParams, node); + let typeArgs: TypeResultWithNode[] | undefined; + typeArgs = adjustTypeArgsForTypeVarTuple(getTypeArgs(node, flags), typeParams, node); let reportedError = false; - // PEP 612 says that if the class has only one type parameter consisting - // of a ParamSpec, the list of arguments does not need to be enclosed in - // a list. We'll handle that case specially here. Presumably this applies to - // type aliases as well. - if (typeParams.length === 1 && isParamSpec(typeParams[0]) && typeArgs) { - if ( - typeArgs.every( - (typeArg) => !isEllipsisType(typeArg.type) && !typeArg.typeList && !isParamSpec(typeArg.type) - ) - ) { - typeArgs = [ - { - type: UnknownType.create(), - node: typeArgs.length > 0 ? typeArgs[0].node : node, - typeList: typeArgs, - }, - ]; - } + typeArgs = transformTypeArgsForParamSpec(typeParams, typeArgs, node); + if (!typeArgs) { + typeArgs = []; + reportedError = true; } let minTypeArgCount = typeParams.length; @@ -7112,12 +7099,18 @@ export function createTypeEvaluator( if (typeList) { const functionType = FunctionType.createSynthesizedInstance('', FunctionTypeFlags.ParamSpecValue); - typeList.forEach((paramType, paramIndex) => { + typeList.forEach((paramTypeResult, paramIndex) => { + let paramType = paramTypeResult.type; + + if (!validateTypeArg(paramTypeResult)) { + paramType = UnknownType.create(); + } + FunctionType.addParam( functionType, FunctionParam.create( ParamCategory.Simple, - convertToInstance(paramType.type), + convertToInstance(paramType), FunctionParamFlags.NameSynthesized | FunctionParamFlags.TypeDeclared, `__p${paramIndex}` ) @@ -20711,53 +20704,9 @@ export function createTypeEvaluator( let typeArgTypes: Type[] = []; const fullTypeParams = ClassType.getTypeParams(classType); - // PEP 612 says that if the class has only one type parameter consisting - // of a ParamSpec, the list of arguments does not need to be enclosed in - // a list. We'll handle that case specially here. - if (fullTypeParams.length === 1 && isParamSpec(fullTypeParams[0]) && typeArgs) { - if ( - typeArgs.every( - (typeArg) => !isEllipsisType(typeArg.type) && !typeArg.typeList && !isParamSpec(typeArg.type) - ) - ) { - if ( - typeArgs.length !== 1 || - !isInstantiableClass(typeArgs[0].type) || - !ClassType.isBuiltIn(typeArgs[0].type, 'Concatenate') - ) { - // Package up the type arguments into a typeList. - typeArgs = - typeArgs.length > 0 - ? [ - { - type: UnknownType.create(), - node: typeArgs[0].node, - typeList: typeArgs, - }, - ] - : []; - } - } else if (typeArgs.length > 1) { - const paramSpecTypeArg = typeArgs.find((typeArg) => isParamSpec(typeArg.type)); - if (paramSpecTypeArg) { - isValidTypeForm = false; - addDiagnostic( - DiagnosticRule.reportInvalidTypeForm, - LocMessage.paramSpecContext(), - paramSpecTypeArg.node - ); - } - - const listTypeArg = typeArgs.find((typeArg) => !!typeArg.typeList); - if (listTypeArg) { - isValidTypeForm = false; - addDiagnostic( - DiagnosticRule.reportInvalidTypeForm, - LocMessage.typeArgListNotAllowed(), - listTypeArg.node - ); - } - } + typeArgs = transformTypeArgsForParamSpec(fullTypeParams, typeArgs, errorNode); + if (!typeArgs) { + isValidTypeForm = false; } const constraints = new ConstraintTracker(); @@ -20906,6 +20855,75 @@ export function createTypeEvaluator( return { type: specializedClass }; } + // PEP 612 says that if the class has only one type parameter consisting + // of a ParamSpec, the list of arguments does not need to be enclosed in + // a list. We'll handle that case specially here. + function transformTypeArgsForParamSpec( + typeParams: TypeVarType[], + typeArgs: TypeResultWithNode[] | undefined, + errorNode: ExpressionNode + ): TypeResultWithNode[] | undefined { + if (typeParams.length !== 1 || !isParamSpec(typeParams[0]) || !typeArgs) { + return typeArgs; + } + + if (typeArgs.length > 1) { + for (const typeArg of typeArgs) { + if (isParamSpec(typeArg.type)) { + addDiagnostic(DiagnosticRule.reportInvalidTypeForm, LocMessage.paramSpecContext(), typeArg.node); + return undefined; + } + + if (isEllipsisType(typeArg.type)) { + addDiagnostic(DiagnosticRule.reportInvalidTypeForm, LocMessage.ellipsisContext(), typeArg.node); + return undefined; + } + + if (isInstantiableClass(typeArg.type) && ClassType.isBuiltIn(typeArg.type, 'Concatenate')) { + addDiagnostic(DiagnosticRule.reportInvalidTypeForm, LocMessage.concatenateContext(), typeArg.node); + return undefined; + } + + if (typeArg.typeList) { + addDiagnostic( + DiagnosticRule.reportInvalidTypeForm, + LocMessage.typeArgListNotAllowed(), + typeArg.node + ); + return undefined; + } + } + } + + if (typeArgs.length === 1) { + // Don't transform a type list. + if (typeArgs[0].typeList) { + return typeArgs; + } + + const typeArgType = typeArgs[0].type; + + // Don't transform a single ParamSpec or ellipsis. + if (isParamSpec(typeArgType) || isEllipsisType(typeArgType)) { + return typeArgs; + } + + // Don't transform a Concatenate. + if (isInstantiableClass(typeArgType) && ClassType.isBuiltIn(typeArgType, 'Concatenate')) { + return typeArgs; + } + } + + // Package up the type arguments into a type list. + return [ + { + type: UnknownType.create(), + node: typeArgs.length > 0 ? typeArgs[0].node : errorNode, + typeList: typeArgs, + }, + ]; + } + function getTypeOfArg(arg: Arg, inferenceContext: InferenceContext | undefined): TypeResult { if (arg.typeResult) { const type = arg.typeResult.type; diff --git a/packages/pyright-internal/src/tests/samples/paramSpec13.py b/packages/pyright-internal/src/tests/samples/paramSpec13.py index b22eb0675c5e..9d160bfb66ea 100644 --- a/packages/pyright-internal/src/tests/samples/paramSpec13.py +++ b/packages/pyright-internal/src/tests/samples/paramSpec13.py @@ -21,12 +21,10 @@ AddIntParam = Callable[Concatenate[int, _P], _T] -def func1(func: Callable[_P, _R]) -> AddIntParam[_P, _R]: - ... +def func1(func: Callable[_P, _R]) -> AddIntParam[_P, _R]: ... -def func2(a: str, b: list[int]) -> str: - ... +def func2(a: str, b: list[int]) -> str: ... v1 = func1(func2) @@ -37,19 +35,15 @@ def func2(a: str, b: list[int]) -> str: X = AddIntParam[int, int] -class RemoteResponse(Generic[_T]): - ... +class RemoteResponse(Generic[_T]): ... class RemoteFunction(Generic[_P, _R]): - def __init__(self, func: Callable[_P, _R]) -> None: - ... + def __init__(self, func: Callable[_P, _R]) -> None: ... - def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: - ... + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... - def remote(self, *args: _P.args, **kwargs: _P.kwargs) -> RemoteResponse[_R]: - ... + def remote(self, *args: _P.args, **kwargs: _P.kwargs) -> RemoteResponse[_R]: ... r1 = RemoteFunction(func2) @@ -75,8 +69,7 @@ def remote(self, *args: _P.args, **kwargs: _P.kwargs) -> RemoteResponse[_R]: A = RemoteFunction[int, int] -def remote(func: Callable[_P, _R]) -> RemoteFunction[_P, _R]: - ... +def remote(func: Callable[_P, _R]) -> RemoteFunction[_P, _R]: ... v4 = remote(func2) @@ -87,8 +80,7 @@ def remote(func: Callable[_P, _R]) -> RemoteFunction[_P, _R]: CoroFunc = Callable[_P, Coro[_T]] -class ClassA: - ... +class ClassA: ... CheckFunc = CoroFunc[Concatenate[ClassA, _P], bool] @@ -117,3 +109,21 @@ async def takes_check_func( # This should generate an error. ta1_2: TA1[()] = lambda x: x + + +TA2: TypeAlias = Callable[Concatenate[int, _P], None] + +TA3: TypeAlias = TA2[int, int] +TA4: TypeAlias = TA2[_P] + +# This should generate an error. +TA5: TypeAlias = TA2[[int, _P]] + +# This should generate an error. +TA6: TypeAlias = TA2[[int, ...]] + +TA7: TypeAlias = TA2[Concatenate[int, _P]] +TA8: TypeAlias = TA2[Concatenate[int, ...]] + +# This should generate two errors. +TA9: TypeAlias = TA2[int, Concatenate[int, _P]] diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 3c19e09e7009..3dde0755ecfb 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -626,7 +626,7 @@ test('ParamSpec12', () => { test('ParamSpec13', () => { const results = TestUtils.typeAnalyzeSampleFiles(['paramSpec13.py']); - TestUtils.validateResults(results, 7); + TestUtils.validateResults(results, 11); }); test('ParamSpec14', () => {