From 173a643107986f4ecff7da768199b168865b53f9 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Thu, 24 Oct 2024 11:55:28 -0700 Subject: [PATCH] Fixed bug that leads to a false negative when passing multiple `*args` or `**kwargs` arguments to a callable parameterized by a ParamSpec. This addresses #9319. --- .../src/analyzer/typeEvaluator.ts | 16 +++++++++++++++ .../src/localization/localize.ts | 2 ++ .../src/localization/package.nls.en-us.json | 4 ++++ .../src/tests/samples/paramSpec49.py | 3 +++ .../src/tests/samples/paramSpec8.py | 20 ++++++++++++------- .../src/tests/typeEvaluator4.test.ts | 4 ++-- 6 files changed, 40 insertions(+), 9 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index ae504d03a588..f5dc12135673 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -11792,12 +11792,28 @@ export function createTypeEvaluator( if (paramSpec) { if (argParam.argument.argCategory === ArgCategory.UnpackedList) { if (isParamSpecArgs(paramSpec, argResult.argType)) { + if (sawParamSpecArgs) { + addDiagnostic( + DiagnosticRule.reportCallIssue, + LocMessage.paramSpecArgsKwargsDuplicate().format({ type: printType(paramSpec) }), + argParam.errorNode + ); + } + sawParamSpecArgs = true; } } if (argParam.argument.argCategory === ArgCategory.UnpackedDictionary) { if (isParamSpecKwargs(paramSpec, argResult.argType)) { + if (sawParamSpecKwargs) { + addDiagnostic( + DiagnosticRule.reportCallIssue, + LocMessage.paramSpecArgsKwargsDuplicate().format({ type: printType(paramSpec) }), + argParam.errorNode + ); + } + sawParamSpecKwargs = true; } } diff --git a/packages/pyright-internal/src/localization/localize.ts b/packages/pyright-internal/src/localization/localize.ts index 38ff00ee68fb..9e2058727236 100644 --- a/packages/pyright-internal/src/localization/localize.ts +++ b/packages/pyright-internal/src/localization/localize.ts @@ -775,6 +775,8 @@ export namespace Localizer { new ParameterizedString<{ name: string }>(getRawString('Diagnostic.paramAnnotationMissing')); export const paramNameMissing = () => new ParameterizedString<{ name: string }>(getRawString('Diagnostic.paramNameMissing')); + export const paramSpecArgsKwargsDuplicate = () => + new ParameterizedString<{ type: string }>(getRawString('Diagnostic.paramSpecArgsKwargsDuplicate')); export const paramSpecArgsKwargsUsage = () => getRawString('Diagnostic.paramSpecArgsKwargsUsage'); export const paramSpecArgsMissing = () => new ParameterizedString<{ type: string }>(getRawString('Diagnostic.paramSpecArgsMissing')); diff --git a/packages/pyright-internal/src/localization/package.nls.en-us.json b/packages/pyright-internal/src/localization/package.nls.en-us.json index 0630fb87d023..3b146339581f 100644 --- a/packages/pyright-internal/src/localization/package.nls.en-us.json +++ b/packages/pyright-internal/src/localization/package.nls.en-us.json @@ -920,6 +920,10 @@ "paramAnnotationMissing": "Type annotation is missing for parameter \"{name}\"", "paramAssignmentMismatch": "Expression of type \"{sourceType}\" cannot be assigned to parameter of type \"{paramType}\"", "paramNameMissing": "No parameter named \"{name}\"", + "paramSpecArgsKwargsDuplicate": { + "message": "Arguments for ParamSpec \"{type}\" have already been provided", + "comment": "{Locked='ParamSpec'}" + }, "paramSpecArgsKwargsUsage": { "message": "\"args\" and \"kwargs\" attributes of ParamSpec must both appear within a function signature", "comment": "{Locked='args','kwargs','ParamSpec'}" diff --git a/packages/pyright-internal/src/tests/samples/paramSpec49.py b/packages/pyright-internal/src/tests/samples/paramSpec49.py index 44ce563f50d1..a12574f6c220 100644 --- a/packages/pyright-internal/src/tests/samples/paramSpec49.py +++ b/packages/pyright-internal/src/tests/samples/paramSpec49.py @@ -57,3 +57,6 @@ def inner6(*args: P.args, **kwargs: P.kwargs) -> None: # extra *args argument. self.dispatcher.dispatch(stub, 1, *args, *args, **kwargs) + # This should generate an error because it has an + # extra **kwargs argument. + self.dispatcher.dispatch(stub, 1, *args, **kwargs, **kwargs) diff --git a/packages/pyright-internal/src/tests/samples/paramSpec8.py b/packages/pyright-internal/src/tests/samples/paramSpec8.py index 7f6f0e726a92..c548d8fd4117 100644 --- a/packages/pyright-internal/src/tests/samples/paramSpec8.py +++ b/packages/pyright-internal/src/tests/samples/paramSpec8.py @@ -17,7 +17,7 @@ def func2(*args: P.args, s: str, t: int, **kwargs: P.kwargs) -> None: # Rejecte def remove(f: Callable[Concatenate[int, P], int]) -> Callable[P, None]: - def foo(*args: P.args, **kwargs: P.kwargs) -> None: + def func1(*args: P.args, **kwargs: P.kwargs) -> None: f(1, *args, **kwargs) # Accepted # Should generate an error because positional parameter @@ -28,18 +28,24 @@ def foo(*args: P.args, **kwargs: P.kwargs) -> None: # is missing. f(*args, **kwargs) # Rejected - return foo + return func1 def outer(f: Callable[P, None]) -> Callable[P, None]: - def foo(x: int, *args: P.args, **kwargs: P.kwargs) -> None: + def func1(x: int, *args: P.args, **kwargs: P.kwargs) -> None: f(*args, **kwargs) - def bar(*args: P.args, **kwargs: P.kwargs) -> None: - foo(1, *args, **kwargs) # Accepted + def func2(*args: P.args, **kwargs: P.kwargs) -> None: + func1(1, *args, **kwargs) # Accepted # This should generate an error because keyword parameters # are not allowed in this situation. - foo(x=1, *args, **kwargs) # Rejected + func1(x=1, *args, **kwargs) # Rejected - return bar + # This should generate an error because *args is duplicated. + func1(1, *args, *args, **kwargs) + + # This should generate an error because **kwargs is duplicated. + func1(1, *args, **kwargs, **kwargs) + + return func2 diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index 0648f0257203..0e30d6773bfe 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -607,7 +607,7 @@ test('ParamSpec7', () => { test('ParamSpec8', () => { const results = TestUtils.typeAnalyzeSampleFiles(['paramSpec8.py']); - TestUtils.validateResults(results, 5); + TestUtils.validateResults(results, 7); }); test('ParamSpec9', () => { @@ -812,7 +812,7 @@ test('ParamSpec48', () => { test('ParamSpec49', () => { const results = TestUtils.typeAnalyzeSampleFiles(['paramSpec49.py']); - TestUtils.validateResults(results, 5); + TestUtils.validateResults(results, 7); }); test('ParamSpec50', () => {